// Example: fit a Regression split on four (x, y) pairs and check the
// resulting left/right partitions.
// NOTE(review): these statements are assumed to live inside a `unittest`
// block or a function body -- confirm against the enclosing module.
import mir.ndslice;
import numir;
import dtree.impurity;
import std.stdio;

Regression c;
auto xs = [-2.0, -1.0, 0.0, 1.0].sliced.unsqueeze!1;
auto ys = [-1.0, 0.0, 1.0, 2.0].sliced.unsqueeze!1;
// Use sample 2 (x == 0.0) as the split point over all four sample indices.
// NOTE(review): the meaning of the trailing `0, 1` arguments is not visible
// here -- verify against the signature of Regression.fit.
c.fit(xs[2], xs, ys, iota(4), 0, 1);

// Samples 0..2 (x = -2, -1, 0) are expected on the left, sample 3 (x = 1) on the right.
assert(c.left.index == [0, 1, 2]);
assert(c.right.index == [3]);
assert(c.threshold == xs[2][0]);
assert(c.left.prediction == [0]);  // mean(-1, 0, 1)
assert(c.right.prediction == [2]); // mean(2)
assert(c.left.impurity == 2.0);    // (-1-0)^2 + (0-0)^2 + (1-0)^2
assert(c.right.impurity == 0.0);   // (2-2)^2
regression decision implementation used at Node