DTreePlain
Currently, this API is not supported in Python.
-
class DTreePlain : public helayers::PlainModel
A class representing a plain decision tree model.
A decision tree is a binary tree whose internal nodes hold comparison conditions of the form “feature < threshold” and whose leaves hold output values. An input sample starts at the tree’s root and moves to the left child whenever a node’s condition is true and to the right child whenever it is false. The output of the decision tree for a sample is the output value of the leaf that the sample reaches. Currently, only decision trees for binary classification problems are supported. Input shape: [batch, numFeatures]. Output shape: [1, batch], where the i-th element is the output value of the leaf reached by the i-th input sample.
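The traversal rule above can be sketched in standalone C++. This is a minimal illustration under stated assumptions, not the helayers implementation: `evalSample` and `evalBatch` are hypothetical helpers, the tree is stored in the same parallel-array layout accepted by `init()`, and node 0 is assumed to be the root (this page does not say which index the root has).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of helayers): evaluates one sample against a
// tree stored as parallel arrays, following the rules documented for init().
// ASSUMPTION: node 0 is the root.
double evalSample(const std::vector<int>& leftChildren,
                  const std::vector<int>& rightChildren,
                  const std::vector<int>& splitIndices,
                  const std::vector<double>& splitConditions,
                  const std::vector<double>& sample) {
  std::size_t node = 0;  // assumed root index
  // A node is internal while it has a left child (leaves store -1).
  while (leftChildren.at(node) != -1) {
    const double feature =
        sample.at(static_cast<std::size_t>(splitIndices.at(node)));
    // Condition "feature < threshold": true -> left child, false -> right.
    const int next = feature < splitConditions.at(node) ? leftChildren.at(node)
                                                        : rightChildren.at(node);
    node = static_cast<std::size_t>(next);  // .at() throws if next was -1
  }
  return splitConditions.at(node);  // a leaf stores its output value here
}

// Hypothetical helper: input has shape [batch, numFeatures]; the returned
// vector holds the [1, batch] output, element i belonging to sample i.
std::vector<double> evalBatch(const std::vector<int>& leftChildren,
                              const std::vector<int>& rightChildren,
                              const std::vector<int>& splitIndices,
                              const std::vector<double>& splitConditions,
                              const std::vector<std::vector<double>>& batch) {
  std::vector<double> out;
  out.reserve(batch.size());
  for (const std::vector<double>& sample : batch) {
    out.push_back(evalSample(leftChildren, rightChildren, splitIndices,
                             splitConditions, sample));
  }
  return out;
}
```

For a one-split tree whose root compares feature 1 against 0.5, a sample with feature 1 equal to 0.5 goes right, because “0.5 < 0.5” is false.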
Public Functions
-
inline DTreePlain()
A constructor.
-
void init(const PlainModelHyperParams &hyperParams, const std::vector<int> &leftChildren, const std::vector<int> &rightChildren, const std::vector<int> &splitIndices, const std::vector<double> &splitConditions)
Initializes this DTreePlain with the given parameters.
- Parameters:
hyperParams – Model hyperparameters.
leftChildren – An array whose i-th element contains the index of the left child of the i-th node. If the i-th node is a leaf, then leftChildren[i] should be -1.
rightChildren – An array whose i-th element contains the index of the right child of the i-th node. If the i-th node is a leaf, then rightChildren[i] should be -1.
splitIndices – An array whose i-th element contains the index of the feature compared in the i-th node. If the i-th node is a leaf, then splitIndices[i] should be -1.
splitConditions – If the i-th node is an internal node (i.e., not a leaf), then splitConditions[i] should contain the threshold used in the i-th node. Otherwise, splitConditions[i] should contain the output value of the i-th node.
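The conventions above can be checked mechanically, for example before passing the arrays to `init()`. The following is a minimal sketch: `validateTreeArrays` and its `numFeatures` argument are hypothetical and not part of helayers. It checks array lengths, the -1 leaf markers and index ranges, but does not detect cycles or unreachable nodes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of helayers): returns true if the four
// parallel arrays follow the conventions documented for DTreePlain::init().
bool validateTreeArrays(const std::vector<int>& leftChildren,
                        const std::vector<int>& rightChildren,
                        const std::vector<int>& splitIndices,
                        const std::vector<double>& splitConditions,
                        int numFeatures) {
  const std::size_t n = leftChildren.size();
  // Every array must have exactly one entry per node.
  if (rightChildren.size() != n || splitIndices.size() != n ||
      splitConditions.size() != n) {
    return false;
  }
  const int numNodes = static_cast<int>(n);
  for (std::size_t i = 0; i < n; ++i) {
    const bool noLeft = leftChildren[i] == -1;
    const bool noRight = rightChildren[i] == -1;
    const bool noSplit = splitIndices[i] == -1;
    if (noLeft || noRight || noSplit) {
      // A leaf must use -1 in all three index arrays; a mix is invalid.
      if (!(noLeft && noRight && noSplit)) return false;
      continue;  // splitConditions[i] is the leaf's output value
    }
    // Internal node: child indices and the feature index must be in range.
    if (leftChildren[i] < 0 || leftChildren[i] >= numNodes) return false;
    if (rightChildren[i] < 0 || rightChildren[i] >= numNodes) return false;
    if (splitIndices[i] < 0 || splitIndices[i] >= numFeatures) return false;
  }
  return true;
}
```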
-
virtual std::shared_ptr<HeModel> getEmptyHeModel(const HeContext &he) const override
Returns an empty (uninitialized) HE model of the HE model type that corresponds to this plain model.
- Parameters:
he – The HE context.
-
inline virtual std::string getClassName() const override
Returns the name of this class.
-
virtual void debugPrint(const std::string &title = "", Verbosity verbosity = VERBOSITY_REGULAR, std::ostream &out = std::cout) const override
Prints the content of this object.
- Parameters:
title – Title text to include in the printout.
verbosity – Verbosity level.
out – Output stream.
-
virtual std::vector<PlainTensorMetadata> getInputsPlainTensorMetadata() const override
Returns a vector of PlainTensorMetadata objects.
The i-th element of this vector contains metadata about the i-th input of this PlainModel (such as its shape and batch dimension). If this PlainModel is initialized for prediction, the returned vector describes the inputs of the predict() method. If this PlainModel is initialized for fitting, the returned vector describes the inputs of the fit() method.
-
Stores the histogram internally and records this tree’s nodes in it.
- Parameters:
plainHist – The histogram
-
std::string getNodeCode(int index) const
Returns the code associated with the given node, according to the plain cache.
Returns “” if no histogram is attached to this tree or if no code is associated with the node.
- Parameters:
index – Index of node in tree.