module dtree.node;

import dtree.decision : DecisionInfo;
import mir.random : rand;

/// One node of a binary decision tree.
///
/// The node embeds a `DecisionInfo` and forwards to it through `alias this`,
/// so members such as `index`, `prediction` and `impurity` used below are
/// resolved on `info`.
struct Node {
    /// Distance from the root; children created by `born` get `depth + 1`.
    size_t depth = 0;
    /// Per-node decision data (sample indices, prediction, impurity, ...).
    DecisionInfo info;
    alias info this;

    /// Feature column selected by the last call to `fit`.
    size_t bestFeatId = 0;
    /// Sample index whose candidate split was selected by the last `fit`.
    size_t bestSampleId = 0;
    /// Threshold compared against `x[bestFeatId]` in `predict`.
    double bestThreshold = 0;

    /// Child nodes; both stay null until `fit` has run on this node.
    Node* left, right;

    /// Number of sample indices held by this node.
    @property
    const nElement() pure {
        return this.index.length;
    }

    /// True when the node has neither a left nor a right child.
    @property
    const isLeaf() pure {
        return !(left !is null || right !is null);
    }

    /// Allocates a child node one level deeper that carries `info`.
    auto born(DecisionInfo info) {
        return new Node(this.depth + 1, info);
    }

    /// Walks down the tree for the sample `x` and returns the prediction
    /// stored in the leaf that is reached.
    ///
    /// At every inner node the right child is taken when
    /// `x[bestFeatId] > bestThreshold`, otherwise the left child.
    auto predict(X)(X x) pure {
        // The leaf return comes first so the inferred return type is known
        // before the recursive call below.
        if (this.isLeaf) {
            return this.prediction;
        }
        auto child = this.left;
        if (x[this.bestFeatId] > this.bestThreshold) {
            child = this.right;
        }
        return child.predict(x);
    }

    /// Searches for the best split of this node's samples and creates the
    /// two children from it.
    ///
    /// Every (feature, sample-in-`index`) pair is tried: a fresh `Decision`
    /// is fitted for the pair and scored by the sum of its left and right
    /// impurities. The lowest score wins; an exact tie replaces the current
    /// best with the outcome of `rand!bool()`. A summary line is printed to
    /// stdout before the children are attached.
    ///
    /// Params:
    ///   xs = indexable collection of samples; `xs[0].length` is used as the
    ///        feature count
    ///   ys = targets, same length as `xs`
    void fit(Decision, Xs, Ys)(Xs xs, Ys ys) in {
        assert(this.index.length > 0);
        assert(xs.length == ys.length);
    } out {
        import std.array : array;
        import std.algorithm : sort;
        // The children together must hold exactly this node's indices.
        auto childIndices = sort(this.left.index ~ this.right.index);
        assert(childIndices.array == this.index);
    } do {
        import std.math : isNaN;
        import std.stdio : writefln;

        // NaN marks "no candidate accepted yet".
        auto lowestImpurity = double.nan;
        Decision bestDecision;
        // TODO support discrete feat
        foreach (size_t fid; 0 .. xs[0].length) {
            foreach (sid; this.index) {
                auto sample = xs[sid];
                Decision candidate;
                candidate.fit(sample, xs, ys, this.index, fid, this.prediction.length);
                auto score = candidate.left.impurity + candidate.right.impurity;
                // Short-circuiting keeps the random draw limited to exact ties:
                // when `score == lowestImpurity` holds, the two earlier
                // operands are necessarily false.
                if (lowestImpurity.isNaN
                        || score < lowestImpurity
                        || (score == lowestImpurity && rand!bool())) {
                    this.bestFeatId = fid;
                    this.bestSampleId = sid;
                    this.bestThreshold = candidate.threshold;
                    lowestImpurity = score;
                    bestDecision = candidate;
                }
            }
        }

        // `this.impurity` / `this.prediction` / `this.index` are this node's
        // own values; the L/R counts come from the winning decision.
        writefln("depth: %d, impurity: %f, threshold %f, predict: %s, L: %d, R: %d, All: %d",
                 this.depth, this.impurity, this.bestThreshold, this.prediction,
                 bestDecision.left.index.length, bestDecision.right.index.length,
                 this.index.length);
        this.left = this.born(bestDecision.left);
        this.right = this.born(bestDecision.right);
    }
}