module dtree.boosting;

// TODO add this to numir
/**
Arithmetic mean over every element of `xs`, regardless of its rank.

Params:
    Result = type each element is converted to before summation; also the result type
    xs = ndslice to average

Returns: the sum of all elements (as `Result`) divided by the total element count.
*/
@nogc auto mean(Result=double, Xs)(Xs xs) pure {
    import mir.math.sum : sum;
    import mir.ndslice.topology : as;
    import numir : size;
    return xs.as!Result.sum!"fast" / xs.size;
}

/**
Mean of `xs` along dimension `axis`.

Params:
    axis = dimension to reduce
    Result = element type of the returned slice
    xs = ndslice to average

Returns: a newly allocated slice holding the shape of `xs` with dimension
`axis` removed. The remaining dimensions keep their original order.
*/
auto mean(ptrdiff_t axis, Result=double, Xs)(Xs xs) pure {
    import mir.ndslice : ipack, transposed, shape, each;
    import numir : zeros;
    // Move `axis` to the front while keeping the other dimensions in their
    // original order. (`swapped!(0, axis)` would permute them when rank >= 3.)
    auto xt = xs.transposed!axis;
    auto ret = zeros!Result(xt[0].shape);
    // Each outer element of xt is one sub-slice taken at a fixed index along `axis`.
    xt.ipack!1.each!((x) {
        ret[] += x;
    });
    ret[] /= xs.length!axis;
    return ret;
}

pure @safe
unittest {
    import mir.ndslice : as, iota;
    /*
    [[0,1,2],
     [3,4,5]]
    */
    assert(iota(2, 3).mean == (5.0 / 2.0));
    assert(iota(2, 3).mean!0 == [(0.0+3.0)/2.0, (1.0+4.0)/2.0, (2.0+5.0)/2.0]);
    assert(iota(2, 3).mean!1 == [(0.0+1.0+2.0)/3.0, (3.0+4.0+5.0)/3.0]);

    // rank 3: iota(2, 3, 4)[i, j, k] == 12*i + 4*j + k
    auto m2 = iota(2, 3, 4).mean!2;
    assert(m2.shape == [2, 3]);
    assert(m2[0, 1] == 5.5);
    assert(m2[1, 0] == 13.5);
    auto m1 = iota(2, 3, 4).mean!1;
    assert(m1.shape == [2, 4]);
    assert(m1[1, 2] == 18.0);
}


/**
Residual `target - pred`.

For the squared loss `0.5 * (target - pred)^2` this is the NEGATIVE gradient
with respect to `pred`. `GradientBoosting.fit` negates it again, so its
`grad` buffer holds the true gradient `pred - target`.
*/
auto mseGrad(T, P)(T target, P pred) {
    return target - pred;
}

/**
Gradient boosting over copies of a base learner.

Params:
    Model = base learner. It must provide `fit(xs, grad)` and a `predict(x)`
            whose result supports `.sliced` (for example a `double[]`).
    gradient = callable `(target, pred)`. The code negates its result, so it
               is expected to return the negative loss gradient (see `mseGrad`).
*/
struct GradientBoosting(Model, alias gradient=mseGrad) {
    /// Template learner copied once per boosting round.
    Model bluePrint;
    /// Number of boosting rounds, i.e. the number of fitted models.
    size_t nBoost;
    /// Learners fitted by `fit`, in boosting order.
    Model[] models;
    /// Initial prediction. Currently a single value: the mean of all targets.
    double[] initPred;
    /// Multiplier applied to every model output. It is negative because the
    /// models are fitted to the gradient `pred - target`, not its negation.
    double stepSize = -1e-2;

    /**
    Fits `nBoost` models, each one to the current loss gradient.

    `ys` must be at least 2-D: `pred[i][]` needs each row to be a slice.
    Presumably the layout is (samples, outputs); TODO confirm.
    */
    auto fit(Xs, Ys)(Xs xs, Ys ys) {
        import numir : zeros_like;
        import std.range : enumerate;
        import mir.ndslice : sliced;
        auto pred = zeros_like(ys);
        auto grad = zeros_like(ys);
        // TODO support multi dim
        this.initPred = [ys.mean];
        pred[] = this.initPred[0];
        grad[] = -gradient(ys, pred);
        models.length = nBoost;
        foreach (ref m; models) {
            // NOTE(review): this is an independent copy only when Model is a
            // value type. If Model is a class, every entry aliases bluePrint.
            m = bluePrint; // copy
            m.fit(xs, grad);
            // TODO implement line-search to find the best stepSize
            foreach (i, x; xs.enumerate) {
                pred[i][] += this.stepSize * m.predict(x).sliced;
            }
            grad[] = -gradient(ys, pred);
        }
    }

    /**
    Predicts for a single sample `x`:
    `initPred + stepSize * sum(m.predict(x) for m in models)`.

    Returns: a freshly allocated slice. The model itself is not modified.
    */
    auto predict(X)(X x) {
        import mir.ndslice : sliced;
        // `sliced` only wraps an array and does not copy it. Duplicate
        // `initPred` first so the in-place accumulation below cannot
        // overwrite the fitted state.
        auto pred = this.initPred.dup.sliced;
        foreach (m; models) {
            pred[] += this.stepSize * m.predict(x).sliced;
        }
        // TODO support multi dim
        return pred;
    }
}

unittest {
    static struct AlwaysOne {
        double[] predict(int) { return [1.0]; }
    }
    GradientBoosting!AlwaysOne gb;
    gb.initPred = [2.0];
    gb.models = [AlwaysOne(), AlwaysOne()];
    gb.stepSize = 0.5;
    assert(gb.predict(0)[0] == 3.0);
    // predict must leave the fitted state untouched, so repeated calls agree
    assert(gb.initPred == [2.0]);
    assert(gb.predict(0)[0] == 3.0);
}