module dtree.tree;

import dtree.impurity : entropy;
import dtree.traits : isDecisionPolicy;
import dtree.node : Node;
import dtree.decision : Regression, Classification, DecisionInfo;

/**
 * Decision tree parameterised by a `DecisionPolicy`.
 *
 * The policy must satisfy `isDecisionPolicy`; it is forwarded to
 * `Node.fit` for every node that gets split. The tree is grown
 * recursively through each node's `left` and `right` children.
 */
struct DecisionTree(DecisionPolicy) {
    static assert(isDecisionPolicy!DecisionPolicy);
    /// Length of the NaN-filled vector handed to the root's `DecisionInfo`.
    size_t nOutput = 2;
    /// A node whose `depth` is at or beyond this value is not split further.
    size_t maxDepth = 5;
    /// A node with `nElement` at or below this value is not split further.
    size_t minElement = 0;
    /// Root of the tree; null until `fit` has been called.
    Node* root;

    /**
     * Builds the tree from training data, replacing any previous root.
     *
     * Params:
     *   xs = training inputs, forwarded unchanged to `Node.fit`
     *   ys = training targets; `ys.length` determines how many sample
     *        indices are assigned to the root node
     */
    auto fit(Xs, Ys)(Xs xs, Ys ys) {
        // Depth-first growth: fit the node, then descend into both children.
        void fitrec(Node* node) {
            // A child may be absent if the node was not actually split
            // (NOTE(review): depends on Node.fit, not visible here); guard
            // so that we never dereference a null pointer below.
            if (node is null) return;
            if (node.depth >= this.maxDepth || node.nElement <= this.minElement) return;
            node.fit!DecisionPolicy(xs, ys);
            fitrec(node.left);
            fitrec(node.right);
        }

        import std.range : iota;
        import std.array : array;
        import numir : ones;
        import mir.ndslice : ndarray;

        // The root initially owns every sample index 0 .. ys.length.
        auto points = iota(ys.length).array;
        // Placeholder output vector: nOutput entries, all NaN.
        auto initProb = ones(this.nOutput) * double.nan;
        // First Node argument is 0 (presumably the root depth — confirm in dtree.node).
        this.root = new Node(0, DecisionInfo(points, initProb.ndarray, double.nan));
        fitrec(this.root);
    }

    /**
     * Delegates prediction for a single input to the root node.
     *
     * Params:
     *   x = input forwarded to `Node.predict`
     * Returns: whatever `Node.predict` returns for `x`.
     */
    auto predict(X)(X x) pure {
        // root stays null until fit() runs; fail with a clear message
        // instead of a bare null dereference.
        assert(this.root !is null, "DecisionTree.predict called before fit");
        return this.root.predict(x);
    }
}

/// `DecisionTree` using the `Classification` policy; the impurity function
/// defaults to `entropy` when no template argument is given.
template ClassificationTree(alias ImpurityFun = entropy) {
    alias ClassificationTree = DecisionTree!(Classification!ImpurityFun);
}

/// `DecisionTree` using the `Regression` policy.
alias RegressionTree = DecisionTree!(Regression);