1 module dtree.node;
2 
3 import dtree.decision : DecisionInfo;
4 import mir.random : rand;
5 
/// One node of a binary decision tree.
///
/// The node embeds a `DecisionInfo` and forwards to it through `alias this`,
/// so `index`, `prediction` and `impurity` used below are members of `info`.
struct Node {
    /// Level of this node; children made by `born` get `depth + 1`.
    size_t depth = 0;
    /// Decision data of this node (forwarded via `alias this`).
    DecisionInfo info;
    alias info this;

    /// Feature position compared against `bestThreshold` in `predict`.
    size_t bestFeatId = 0;
    /// Element of `index` whose sample produced the split selected by `fit`.
    size_t bestSampleId = 0;
    /// Threshold of the split selected by `fit`.
    double bestThreshold = 0;

    /// Child nodes; both are null until `fit` has run.
    Node* left;
    Node* right; /// ditto

    /// Number of entries in this node's `index`.
    @property
    const nElement() pure {
        return this.index.length;
    }

    /// True when the node has neither a left nor a right child.
    @property
    const isLeaf() pure {
        return this.left is null && this.right is null;
    }

    /// Allocates a child node one level deeper that carries `info`.
    auto born(DecisionInfo info) {
        auto child = new Node(this.depth + 1, info);
        return child;
    }

    /// Descends from this node to a leaf and returns that leaf's `prediction`.
    ///
    /// At every inner node the right child is taken when
    /// `x[bestFeatId] > bestThreshold`, otherwise the left child.
    auto predict(X)(X x) pure {
        if (this.isLeaf) {
            return this.prediction;
        }
        // Iterative walk instead of recursing into the chosen child.
        auto node = x[this.bestFeatId] > this.bestThreshold ? this.right : this.left;
        while (!node.isLeaf) {
            node = x[node.bestFeatId] > node.bestThreshold ? node.right : node.left;
        }
        return node.prediction;
    }

    /// Searches every (feature, sample) pair for the split with the lowest
    /// summed child impurity, records it, and creates both children.
    ///
    /// Params:
    ///     Decision = split type; must provide `fit`, `left`, `right` and `threshold`
    ///     xs = samples, addressed by the entries of `index`
    ///     ys = targets, same length as `xs`
    ///
    /// The chosen split is also printed to stdout.
    void fit(Decision, Xs, Ys)(Xs xs, Ys ys) in {
        assert(this.index.length > 0);
        assert(xs.length == ys.length);
    } out {
        import std.algorithm : sort;
        import std.array : array;
        // The children must partition exactly this node's (sorted) index.
        auto merged = this.left.index ~ this.right.index;
        assert(sort(merged).array == this.index);
    } do {
        import std.math : isNaN;
        import std.stdio : writefln;

        Decision chosen;
        double lowest = double.nan; // NaN means "no candidate accepted yet"

        // TODO support discrete feat
        foreach (size_t fid; 0 .. xs[0].length) {
            foreach (sid; this.index) {
                auto pivot = xs[sid];
                Decision candidate;
                candidate.fit(pivot, xs, ys, this.index, fid, this.prediction.length);
                auto score = candidate.left.impurity + candidate.right.impurity;

                // Accept the first candidate, any strict improvement, or a
                // tie that wins a coin flip (the coin is only drawn on ties).
                const accept = lowest.isNaN
                    || score < lowest
                    || (score == lowest && rand!bool());
                if (!accept) {
                    continue;
                }

                chosen = candidate;
                lowest = score;
                this.bestSampleId = sid;
                this.bestFeatId = fid;
                this.bestThreshold = candidate.threshold;
            }
        }

        writefln("depth: %d, impurity: %f, threshold %f, predict: %s, L: %d, R: %d, All: %d",
                 this.depth, this.impurity, this.bestThreshold, this.prediction,
                 chosen.left.index.length, chosen.right.index.length, this.index.length);
        this.left = this.born(chosen.left);
        this.right = this.born(chosen.right);
    }
}