module dtree.decision;

import std.stdio : writefln;

import mir.math : sum;
import mir.ndslice : ndarray, sliced, map, ipack, universal, slice, repeat, shape;
import numir : ones, zeros, zeros_like, Ndim, unsqueeze;


/// information for a left/right decision used in Decision and Node
struct DecisionInfo {
    /// indices (into xs/ys) of the samples routed to this side
    size_t[] index;
    /// per-output mean target (Regression) or normalized class frequencies (Classification)
    double[] prediction;
    /// sum of squared deviations from the mean (Regression) or
    /// ImpurityFun(prediction) * index.length (Classification)
    double impurity;
}


/**
Regression decision implementation used at Node.

Splits the samples selected by `index` on feature `fid` at `x[fid]`:
a sample goes right when `xs[i][fid] > threshold` and left otherwise.
Each side predicts the per-output mean of its targets, and its impurity
is the sum of squared deviations from that mean.
*/
struct Regression {
    DecisionInfo left, right;
    double threshold;

    /**
    Params:
        x      = sample whose `fid`-th feature becomes the threshold
        xs     = all input samples, accessed as `xs[i][fid]`
        ys     = all targets; each row `ys[i]` holds `nPreds` outputs
        index  = indices into `xs`/`ys` of the samples at this node (may be a subset)
        fid    = id of the feature to split on
        nPreds = number of outputs per target row
    */
    void fit(X, Xs, Ys, I)(X x, Xs xs, Ys ys, I index, size_t fid, size_t nPreds) {
        // reset so that a reused instance does not accumulate indices from an earlier fit
        this.left = DecisionInfo.init;
        this.right = DecisionInfo.init;
        this.threshold = x[fid];

        // Scratch buffer sized by ys.length (an upper bound on the number of
        // indexed samples): left rows fill it from the front, right rows from
        // the back. When index is a subset, the rows in between stay unused
        // and must be excluded from both sides' impurity.
        auto lrys = zeros!double(ys.length, nPreds);
        auto lmean = zeros!double(nPreds);
        auto rmean = zeros!double(nPreds);
        size_t nl = 0, nr = 0;

        // TODO use mir-algorithm
        foreach (i; index) {
            if (xs[i][fid] > this.threshold) {
                this.right.index ~= i;
                ++nr;
                lrys[lrys.length - nr][] = ys[i];
                rmean[] += ys[i];
            } else {
                this.left.index ~= i;
                lrys[nl][] = ys[i];
                ++nl;
                lmean[] += ys[i];
            }
        }

        if (nl > 0) lmean[] /= nl;
        if (nr > 0) rmean[] /= nr;
        auto lys = lrys[0 .. nl];
        // only the rows actually written for the right side (not lrys[nl .. $])
        auto rys = lrys[lrys.length - nr .. $];
        // center each side on its own mean, one output column at a time
        foreach (k; 0 .. nPreds) {
            lys[0 .. $, k] -= lmean[k];
            rys[0 .. $, k] -= rmean[k];
        }
        this.left.prediction = lmean.ndarray;
        this.right.prediction = rmean.ndarray;
        // FIXME dmd requires .slice while ldc2 does not
        this.left.impurity = (lys ^^ 2.0).slice.sum!"fast";
        this.right.impurity = (rys ^^ 2.0).slice.sum!"fast";
    }
}

///
unittest {
    import mir.ndslice;
    import numir;
    import dtree.impurity;
    import std.stdio;

    Regression c;
    auto xs = [-2.0, -1.0, 0.0, 1.0].sliced.unsqueeze!1;
    auto ys = [-1.0, 0.0, 1.0, 2.0].sliced.unsqueeze!1;
    c.fit(xs[2], xs, ys, iota(4), 0, 1);

    assert(c.left.index == [0, 1, 2]);
    assert(c.right.index == [3]);
    assert(c.threshold == xs[2][0]);
    assert(c.left.prediction == [0]); // mean(-1, 0, 1)
    assert(c.right.prediction == [2]); // mean(2)
    assert(c.left.impurity == 2.0); // (-1-0)^2 + (0-0)^2 + (1-0)^2
    assert(c.right.impurity == 0.0); // (2-2)^2
}

// a subset index must only contribute its own rows, and a reused
// instance must not keep indices from a previous fit
unittest {
    import mir.ndslice : sliced, iota;
    import numir : unsqueeze;

    Regression c;
    auto xs = [-2.0, -1.0, 0.0, 1.0].sliced.unsqueeze!1;
    auto ys = [-1.0, 0.0, 1.0, 2.0].sliced.unsqueeze!1;
    size_t[] subset = [0, 3];
    c.fit(xs[0], xs, ys, subset, 0, 1);

    assert(c.left.index == [0]);
    assert(c.right.index == [3]);
    assert(c.left.prediction == [-1]);
    assert(c.right.prediction == [2]);
    assert(c.left.impurity == 0.0);
    assert(c.right.impurity == 0.0); // unused buffer rows are not counted

    c.fit(xs[2], xs, ys, iota(4), 0, 1);
    assert(c.left.index == [0, 1, 2]);
    assert(c.right.index == [3]);
    assert(c.left.impurity == 2.0);
    assert(c.right.impurity == 0.0);
}

/**
Divide `probs` by its sum so that the entries sum to one.
When the sum is not positive (e.g. every count is zero) `probs` is
returned divided by 1.0 — presumably so both branches of the ternary
have the same lazy expression type.
*/
auto normalizeProb(T)(T probs) pure {
    import mir.math : sum;
    auto psum = probs.sum!"fast";
    return psum > 0.0 ? probs / psum : probs / 1.0;
}


/**
Classification decision implementation used at Node.

Splits the samples selected by `index` on feature `fid` at `x[fid]`
(right when `xs[i][fid] > threshold`). Each side predicts its normalized
class frequencies; its impurity is `ImpurityFun` of those frequencies
weighted by the side's sample count.
*/
struct Classification(alias ImpurityFun) {
    DecisionInfo left, right;
    double threshold;

    /**
    Params:
        x      = sample whose `fid`-th feature becomes the threshold
        xs     = all input samples, accessed as `xs[i][fid]`
        ys     = integer class labels; each `ys[i]` must lie in `[0, nPreds)`
        index  = indices into `xs`/`ys` of the samples at this node
        fid    = id of the feature to split on
        nPreds = number of classes
    */
    void fit(X, Xs, Ys, I)(X x, Xs xs, Ys ys, I index, size_t fid, size_t nPreds) {
        // reset so that a reused instance does not accumulate indices from an earlier fit
        left = DecisionInfo.init;
        right = DecisionInfo.init;
        threshold = x[fid];

        // per-class sample counts for each side
        auto lpreds = zeros!double(nPreds);
        auto rpreds = zeros!double(nPreds);
        // TODO use mir-algorithm
        foreach (i; index) {
            if (xs[i][fid] > threshold) {
                right.index ~= i;
                ++rpreds[ys[i]];
            } else {
                left.index ~= i;
                ++lpreds[ys[i]];
            }
        }
        // counts -> class frequencies (left unchanged when a side is empty)
        lpreds[] = lpreds.normalizeProb;
        rpreds[] = rpreds.normalizeProb;
        left.prediction = lpreds.ndarray;
        right.prediction = rpreds.ndarray;
        left.impurity = ImpurityFun(lpreds) * left.index.length;
        right.impurity = ImpurityFun(rpreds) * right.index.length;
    }
}

///
unittest {
    import mir.ndslice : sliced, iota;
    import numir : unsqueeze;
    import dtree.impurity : gini, entropy;

    {
        Classification!entropy c;
        auto xs = [-2.0, -1.0, 0.0, 1.0].sliced.unsqueeze!1;
        // for dmd ys needs to be long or size_t
        auto ys = [0, 0, 0, 1].sliced!long;
        c.fit(xs[2], xs, ys, iota(ys.length), 0, 2);

        assert(c.left.index == [0, 1, 2]);
        assert(c.right.index == [3]);
        assert(c.threshold == xs[2][0]);
        assert(c.left.prediction == [1, 0]);
        assert(c.right.prediction == [0, 1]);
        assert(c.left.impurity == 0.0);
        assert(c.right.impurity == 0.0);
    }
    {
        Classification!gini c;
        auto xs = [-2.0, -1.0, 0.0, 1.0].sliced.unsqueeze!1;
        auto ys = [0, 0, 0, 1].sliced!long;
        c.fit(xs[2], xs, ys, iota(ys.length), 0, 2);

        assert(c.left.index == [0, 1, 2]);
        assert(c.right.index == [3]);
        assert(c.threshold == xs[2][0]);
        assert(c.left.prediction == [1, 0]);
        assert(c.right.prediction == [0, 1]);
        assert(c.left.impurity == 0.0);
        assert(c.right.impurity == 0.0);
    }
}