module dtree.forest;

// import dtree.tree : ClassificationTree;
import mir.ndslice : iota, ndarray;
import numir : permutation, zeros;
import dtree.impurity : entropy;


struct RandomForest(Tree) {
    Tree initTree;
    size_t nTree = 5;
    bool bootstrap = true;
    Tree[] trees;

    void fit(Xs, Ys)(Xs xs, Ys ys) in {
        assert(xs.length == ys.length);
    } body {
        // TODO implement boot strap sampling
        // 1. sampling sample-id from multinomial dist
        // 2. sampling feature-id fromm multinomial dist
        import std.math : ceil;
        import std.algorithm : min;
        import std.conv : to;
        auto ps = ys.length.permutation;
        const stride = ys.length / this.nTree;
        for (size_t t = 0; t < this.nTree; ++t) {
            auto tree = initTree;
            auto a = t * stride;
            auto b = min(a + stride, ps.length);
            auto ab = ps[a .. b];
            tree.fit(xs[ab], ys[ab]);
            this.trees ~= [tree];
        }
    }

    auto predict(X)(X x) {
        import mir.ndslice : map, iota, sliced;
        import mir.ndslice.allocation : ndarray;
        import mir.math : sum;
        auto result = zeros(this.initTree.nOutput);
        foreach (t; this.trees) {
            result[] += t.predict(x).sliced;
        }
        result[] /= this.nTree; // result.sum!"fast";
        return result.ndarray;
    }
}


auto toRandomForest(Tree)(Tree tree, size_t nTree) {
    return RandomForest!(Tree)(tree, nTree);
}