import mir.ndslice : sliced, iota;
import numir : unsqueeze;
import dtree.impurity : gini, entropy;

// Example 1: split chosen with the entropy impurity measure.
{
    Classification!entropy model;

    // Four scalar samples, given a trailing axis by unsqueeze!1.
    auto samples = [-2.0, -1.0, 0.0, 1.0].sliced.unsqueeze!1;
    // dmd requires the label element type to be long or size_t.
    auto labels = [0, 0, 0, 1].sliced!long;

    // Fit using sample #2 as the candidate split point.
    // NOTE(review): the trailing `0, 2` look like a feature index and a
    // class count -- confirm against Classification.fit.
    model.fit(samples[2], samples, labels, iota(labels.length), 0, 2);

    // Samples 0..2 go to the left child, sample 3 to the right child.
    assert(model.left.index == [0, 1, 2]);
    assert(model.right.index == [3]);

    // The threshold equals the value of the candidate sample.
    assert(model.threshold == samples[2][0]);

    // Each child holds a single label, so its impurity is zero.
    assert(model.left.prediction == [1, 0]);
    assert(model.right.prediction == [0, 1]);
    assert(model.left.impurity == 0.0);
    assert(model.right.impurity == 0.0);
}

// Example 2: same data and expectations with the gini impurity measure.
{
    Classification!gini model;

    auto samples = [-2.0, -1.0, 0.0, 1.0].sliced.unsqueeze!1;
    auto labels = [0, 0, 0, 1].sliced!long;

    model.fit(samples[2], samples, labels, iota(labels.length), 0, 2);

    assert(model.left.index == [0, 1, 2]);
    assert(model.right.index == [3]);

    assert(model.threshold == samples[2][0]);

    assert(model.left.prediction == [1, 0]);
    assert(model.right.prediction == [0, 1]);
    assert(model.left.impurity == 0.0);
    assert(model.right.impurity == 0.0);
}
Classification decision implementation used by Node.