1 module dtree.decision;
2 
3 import std.stdio : writefln;
4 
5 import mir.math : sum;
6 import mir.ndslice : ndarray, sliced, map, ipack, universal, slice, repeat, shape;
7 import numir : ones, zeros, zeros_like, Ndim, unsqueeze;
8 
9 
/// Information about one side (left or right) of a split, used in the
/// decision structs (Regression, Classification) and in Node.
struct DecisionInfo {
    /// Indices of the samples routed to this side by `fit`.
    size_t[] index;

    /// Predicted output for this side: the per-output target mean for
    /// Regression, the normalized per-class counts for Classification.
    double[] prediction;

    /// Impurity of this side as computed by the decision's `fit`.
    double impurity;
}
16 
17 
/// Regression decision implementation used at Node.
///
/// Splits samples on a single feature at a threshold and records, for each
/// side, the sample indices, the per-output mean of the targets
/// (`prediction`) and the sum of squared deviations from that mean
/// (`impurity`).
struct Regression {
    DecisionInfo left, right;
    double threshold;

    /++
    Computes the split of the samples listed in `index` at `x[fid]`.

    Samples with `xs[i][fid] > threshold` go to `right`, all others to
    `left`. Any result of a previous call is discarded first.

    Params:
        x = sample whose value at `fid` becomes the threshold
        xs = feature rows, read as `xs[i][fid]`
        ys = target rows; every `ys[i]` is copied into a row of `nPreds` values
        index = indices into `xs`/`ys` of the samples to split; may be a
                subset of all rows
        fid = index of the feature to split on
        nPreds = number of target values per sample
    +/
    void fit(X, Xs, Ys, I)(X x, Xs xs, Ys ys, I index, size_t fid, size_t nPreds) {
        // start from a clean state so that calling fit twice on the same
        // instance does not accumulate indices from the previous call
        this.left = DecisionInfo.init;
        this.right = DecisionInfo.init;
        this.threshold = x[fid];

        // scratch buffer: left rows are packed from the front, right rows
        // from the back; rows in between stay unused when `index` does not
        // cover every sample
        size_t nRows = ys.length;
        auto lrys = zeros!double(nRows, nPreds);
        auto lmean = zeros!double(nPreds);
        auto rmean = zeros!double(nPreds);
        size_t nLeft = 0, nRight = 0;

        // TODO use mir-algorithm
        foreach (i; index) {
            if (xs[i][fid] > this.threshold) {
                this.right.index ~= [i];
                lrys[nRows - 1 - nRight][] = ys[i];
                ++nRight;
                rmean[] += ys[i];
            } else {
                this.left.index ~= [i];
                lrys[nLeft][] = ys[i];
                ++nLeft;
                lmean[] += ys[i];
            }
        }

        // leave the mean at zero for an empty side instead of dividing by 0
        if (nLeft > 0) lmean[] /= nLeft;
        if (nRight > 0) rmean[] /= nRight;

        // take exactly the rows that were written: slicing the right side as
        // [nLeft .. $] would also pick up the unused zero rows and inflate
        // the right impurity whenever `index` is a subset of the samples
        auto lys = lrys[0 .. nLeft];
        auto rys = lrys[nRows - nRight .. nRows];
        for (size_t i = 0; i < nPreds; ++i) {
            lys[0 .. $, i] -= lmean[i];
            rys[0 .. $, i] -= rmean[i];
        }
        this.left.prediction = lmean.ndarray;
        this.right.prediction = rmean.ndarray;
        // FIXME dmd requires .slice while ldc2 does not
        this.left.impurity = (lys ^^ 2.0).slice.sum!"fast";
        this.right.impurity = (rys ^^ 2.0).slice.sum!"fast";
    }
}

///
unittest {
    import mir.ndslice : sliced, iota;
    import numir : unsqueeze;

    Regression c;
    auto xs = [-2.0, -1.0, 0.0, 1.0].sliced.unsqueeze!1;
    auto ys = [-1.0, 0.0, 1.0, 2.0].sliced.unsqueeze!1;
    c.fit(xs[2], xs, ys, iota(4), 0, 1);

    assert(c.left.index == [0, 1, 2]);
    assert(c.right.index == [3]);
    assert(c.threshold == xs[2][0]);
    assert(c.left.prediction == [0]);  // mean(-1, 0, 1)
    assert(c.right.prediction == [2]); // mean(2)
    assert(c.left.impurity == 2.0);  // (-1-0)^2 + (0-0)^2 + (1-0)^2
    assert(c.right.impurity == 0.0); // (2-2)^2

    // splitting only a subset of the samples must ignore the other rows
    Regression sub;
    size_t[] subset = [1, 3];
    sub.fit(xs[2], xs, ys, subset, 0, 1);
    // fitting again must replace, not extend, the previous result
    sub.fit(xs[2], xs, ys, subset, 0, 1);

    assert(sub.left.index == [1]);
    assert(sub.right.index == [3]);
    assert(sub.left.prediction == [0]);  // mean(0)
    assert(sub.right.prediction == [2]); // mean(2)
    assert(sub.left.impurity == 0.0);  // (0-0)^2
    assert(sub.right.impurity == 0.0); // (2-2)^2
}
80 
/// Divides `probs` by the sum of its elements.
///
/// When that sum is not greater than zero the divisor is 1.0 instead, so the
/// values are returned unscaled rather than divided by zero.
auto normalizeProb(T)(T probs) pure {
    import mir.math : sum;
    auto total = probs.sum!"fast";
    auto divisor = total > 0.0 ? total : 1.0;
    return probs / divisor;
}
86 
87 
/// Classification decision implementation used at Node.
///
/// Splits samples on a single feature at a threshold and records, for each
/// side, the sample indices, the normalized per-class counts (`prediction`)
/// and `ImpurityFun` of those normalized counts multiplied by the number of
/// samples on that side (`impurity`).
struct Classification(alias ImpurityFun) {
    DecisionInfo left, right;
    double threshold;

    /++
    Computes the split of the samples listed in `index` at `x[fid]`.

    Samples with `xs[i][fid] > threshold` go to `right`, all others to
    `left`. Any result of a previous call is discarded first.

    Params:
        x = sample whose value at `fid` becomes the threshold
        xs = feature rows, read as `xs[i][fid]`
        ys = labels; `ys[i]` is used directly as an index into a count
             vector of length `nPreds`
        index = indices into `xs`/`ys` of the samples to split
        fid = index of the feature to split on
        nPreds = length of the per-class count vector
    +/
    void fit(X, Xs, Ys, I)(X x, Xs xs, Ys ys, I index, size_t fid, size_t nPreds) {
        // start from a clean state so that calling fit twice on the same
        // instance does not accumulate indices (and with them the sample
        // counts that weight the impurity) from the previous call
        left = DecisionInfo.init;
        right = DecisionInfo.init;
        threshold = x[fid];

        auto lpreds = zeros!double(nPreds);
        auto rpreds = zeros!double(nPreds);
        // TODO use mir-algorithm
        foreach (i; index) {
            if (xs[i][fid] > threshold) {
                right.index ~= [i];
                ++rpreds[ys[i]];
            } else {
                left.index ~= [i];
                ++lpreds[ys[i]];
            }
        }
        // turn the per-class counts into proportions in place
        lpreds[] = lpreds.normalizeProb;
        rpreds[] = rpreds.normalizeProb;
        left.prediction = lpreds.ndarray;
        right.prediction = rpreds.ndarray;
        // weight each side's impurity by the number of samples it received
        left.impurity = ImpurityFun(lpreds) * left.index.length;
        right.impurity = ImpurityFun(rpreds) * right.index.length;
    }
}
115 
///
unittest {
    import mir.ndslice : sliced, iota;
    import numir : unsqueeze;
    import dtree.impurity : gini, entropy;

    // Runs the same perfectly separable split with the given impurity
    // measure and checks indices, threshold, predictions and impurities.
    void checkPureSplit(alias impurityFun)() {
        Classification!impurityFun decision;
        auto features = [-2.0, -1.0, 0.0, 1.0].sliced.unsqueeze!1;
        // for dmd the labels need to be long or size_t
        auto labels = [0, 0, 0, 1].sliced!long;
        decision.fit(features[2], features, labels, iota(labels.length), 0, 2);

        assert(decision.left.index == [0, 1, 2]);
        assert(decision.right.index == [3]);
        assert(decision.threshold == features[2][0]);
        assert(decision.left.prediction == [1, 0]);
        assert(decision.right.prediction == [0, 1]);
        assert(decision.left.impurity == 0.0);
        assert(decision.right.impurity == 0.0);
    }

    checkPureSplit!entropy();
    checkPureSplit!gini();
}