From 786bf6cd656d0d67e56bf50047ff116bae884b9e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Wed, 12 Jul 2017 07:37:31 -0700 Subject: [PATCH] Refactor some of TensorForest V4 to make the tree model valid during training time, instead of only after FinalizeTreeOp. PiperOrigin-RevId: 161663317 --- .../tensor_forest/kernels/model_ops.cc | 183 ++++++++++++++++-- .../tensor_forest/kernels/model_ops_test.cc | 27 +++ .../tensor_forest/kernels/stats_ops.cc | 101 ++++------ .../tensor_forest/kernels/stats_ops_test.cc | 2 +- .../kernels/v4/decision-tree-resource.cc | 9 +- .../kernels/v4/decision-tree-resource.h | 8 +- .../kernels/v4/fertile-stats-resource.cc | 26 +-- .../kernels/v4/fertile-stats-resource.h | 11 +- .../kernels/v4/leaf_model_operators.cc | 33 ++-- .../kernels/v4/leaf_model_operators.h | 23 ++- .../kernels/v4/leaf_model_operators_test.cc | 13 +- .../kernels/v4/split_collection_operators.cc | 10 +- .../kernels/v4/split_collection_operators.h | 8 + .../contrib/tensor_forest/ops/model_ops.cc | 52 +++++ .../contrib/tensor_forest/ops/stats_ops.cc | 2 + .../tensor_forest/python/ops/model_ops.py | 11 +- .../tensor_forest/python/tensor_forest_v4.py | 27 ++- 17 files changed, 374 insertions(+), 172 deletions(-) diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index 0f92f05e2c0..221f8d969bc 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= +#include <functional> #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" #include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h" #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" @@ -26,6 +27,7 @@ #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace tensorforest { @@ -46,7 +48,7 @@ class CreateTreeVariableOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()), errors::InvalidArgument("Tree config must be a scalar.")); - auto* result = new DecisionTreeResource(); + auto* result = new DecisionTreeResource(param_proto_); if (!ParseProtoUnlimited(result->mutable_decision_tree(), tree_config_t->scalar<string>()())) { result->Unref(); @@ -142,6 +144,16 @@ class TreeSizeOp : public OpKernel { } }; +void TraverseTree(const DecisionTreeResource* tree_resource, + const std::unique_ptr<TensorDataSet>& data, int32 start, + int32 end, + const std::function<void(int32, int32)>& set_leaf_id) { + for (int i = start; i < end; ++i) { + const int32 id = tree_resource->TraverseTree(data, i, nullptr); + set_leaf_id(i, id); + } +} + // Op for tree inference. 
class TreePredictionsV4Op : public OpKernel { public: @@ -176,22 +188,49 @@ class TreePredictionsV4Op : public OpKernel { mutex_lock l(*decision_tree_resource->get_mutex()); core::ScopedUnref unref_me(decision_tree_resource); + const int num_data = data_set_->NumItems(); + const int32 num_outputs = param_proto_.num_outputs(); + Tensor* output_predictions = nullptr; TensorShape output_shape; - output_shape.AddDim(data_set_->NumItems()); - output_shape.AddDim(param_proto_.num_outputs()); + output_shape.AddDim(num_data); + output_shape.AddDim(num_outputs); OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_predictions)); + TTypes<float, 2>::Tensor out = output_predictions->tensor<float, 2>(); - auto out = output_predictions->tensor<float, 2>(); - for (int i = 0; i < data_set_->NumItems(); ++i) { - const int32 leaf_id = - decision_tree_resource->TraverseTree(data_set_, i, nullptr); - const decision_trees::Leaf& leaf = - decision_tree_resource->get_leaf(leaf_id); + auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); + int num_threads = worker_threads->num_threads; + const int64 costPerTraverse = 500; + auto traverse = [this, &out, decision_tree_resource, num_data](int64 start, + int64 end) { + CHECK(start <= end); + CHECK(end <= num_data); + TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start), + static_cast<int32>(end), + std::bind(&TreePredictionsV4Op::set_output_value, this, + std::placeholders::_1, std::placeholders::_2, + decision_tree_resource, &out)); + }; + Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, + traverse); + } + + void set_output_value(int32 i, int32 id, + DecisionTreeResource* decision_tree_resource, + TTypes<float, 2>::Tensor* out) { + const decision_trees::Leaf& leaf = decision_tree_resource->get_leaf(id); + + float sum = 0; + for (int j = 0; j < param_proto_.num_outputs(); ++j) { + const float count = model_op_->GetOutputValue(leaf, j); + (*out)(i, j) = count; 
+ sum += count; + } + + if (!param_proto_.is_regression() && sum > 0 && sum != 1) { for (int j = 0; j < param_proto_.num_outputs(); ++j) { - const float count = model_op_->GetOutputValue(leaf, j); - out(i, j) = count; + (*out)(i, j) /= sum; } } } @@ -203,6 +242,122 @@ class TreePredictionsV4Op : public OpKernel { TensorForestParams param_proto_; }; +// Outputs leaf ids for the given examples. +class TraverseTreeV4Op : public OpKernel { + public: + explicit TraverseTreeV4Op(OpKernelConstruction* context) : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(&param_proto_, serialized_params); + + string serialized_proto; + OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); + input_spec_.ParseFromString(serialized_proto); + + data_set_ = + std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_data = context->input(1); + const Tensor& sparse_input_indices = context->input(2); + const Tensor& sparse_input_values = context->input(3); + const Tensor& sparse_input_shape = context->input(4); + + data_set_->set_input_tensors(input_data, sparse_input_indices, + sparse_input_values, sparse_input_shape); + + DecisionTreeResource* decision_tree_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &decision_tree_resource)); + mutex_lock l(*decision_tree_resource->get_mutex()); + core::ScopedUnref unref_me(decision_tree_resource); + + const int num_data = data_set_->NumItems(); + + Tensor* output_predictions = nullptr; + TensorShape output_shape; + output_shape.AddDim(num_data); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, + &output_predictions)); + + auto leaf_ids = output_predictions->tensor<int32, 1>(); + + auto set_leaf_ids = [&leaf_ids](int32 i, int32 id) { leaf_ids(i) = id; }; + + auto worker_threads = 
context->device()->tensorflow_cpu_worker_threads(); + int num_threads = worker_threads->num_threads; + const int64 costPerTraverse = 500; + auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data]( + int64 start, int64 end) { + CHECK(start <= end); + CHECK(end <= num_data); + TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start), + static_cast<int32>(end), set_leaf_ids); + }; + Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, + traverse); + } + + private: + tensorforest::TensorForestDataSpec input_spec_; + std::unique_ptr<TensorDataSet> data_set_; + TensorForestParams param_proto_; +}; + +// Update the given leaf models using the batch of labels. +class UpdateModelV4Op : public OpKernel { + public: + explicit UpdateModelV4Op(OpKernelConstruction* context) : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(&param_proto_, serialized_params); + + model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_); + } + + void Compute(OpKernelContext* context) override { + const Tensor& leaf_ids = context->input(1); + const Tensor& input_labels = context->input(2); + const Tensor& input_weights = context->input(3); + + DecisionTreeResource* decision_tree_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &decision_tree_resource)); + mutex_lock l(*decision_tree_resource->get_mutex()); + core::ScopedUnref unref_me(decision_tree_resource); + + const int num_data = input_labels.shape().dim_size(0); + const int32 label_dim = + input_labels.shape().dims() <= 1 + ? 0 + : static_cast<int>(input_labels.shape().dim_size(1)); + const int32 num_targets = + param_proto_.is_regression() ? (std::max(1, label_dim)) : 1; + + TensorInputTarget target(input_labels, input_weights, num_targets); + + // TODO(gilberth): Make this thread safe and multi-thread. 
+ UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource); + } + + void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target, + int32 start, int32 end, + DecisionTreeResource* decision_tree_resource) { + const auto leaves = leaf_ids.unaligned_flat<int32>(); + for (int i = start; i < end; ++i) { + model_op_->UpdateModel( + decision_tree_resource->get_mutable_tree_node(leaves(i)) + ->mutable_leaf(), + &target, i); + } + } + + private: + std::unique_ptr<LeafModelOperator> model_op_; + TensorForestParams param_proto_; +}; + // Op for getting feature usage counts. class FeatureUsageCountsOp : public OpKernel { public: @@ -286,8 +441,14 @@ REGISTER_KERNEL_BUILDER(Name("TreeSize").Device(DEVICE_CPU), TreeSizeOp); REGISTER_KERNEL_BUILDER(Name("TreePredictionsV4").Device(DEVICE_CPU), TreePredictionsV4Op); +REGISTER_KERNEL_BUILDER(Name("TraverseTreeV4").Device(DEVICE_CPU), + TraverseTreeV4Op); + REGISTER_KERNEL_BUILDER(Name("FeatureUsageCounts").Device(DEVICE_CPU), FeatureUsageCountsOp); +REGISTER_KERNEL_BUILDER(Name("UpdateModelV4").Device(DEVICE_CPU), + UpdateModelV4Op); + } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc index cece61a54c5..0fdab8e6e0b 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc @@ -64,6 +64,33 @@ TEST(ModelOpsTest, TreePredictionsV4_ShapeFn) { INFER_OK(op, "?;?;?;?;[10,11]", "[?,?]"); } +TEST(ModelOpsTest, TraverseTreeV4_ShapeFn) { + ShapeInferenceTestOp op("TraverseTreeV4"); + TF_ASSERT_OK(NodeDefBuilder("test", "TraverseTreeV4") + .Input("a", 0, DT_RESOURCE) + .Input("b", 1, DT_FLOAT) + .Input("c", 2, DT_INT64) + .Input("d", 3, DT_FLOAT) + .Input("e", 5, DT_INT64) + .Attr("input_spec", "") + .Attr("params", "") + .Finalize(&op.node_def)); + + // num_points = 2, sparse shape not known + INFER_OK(op, 
"?;[2,3];?;?;?", "[d1_0]"); + + // num_points = 2, sparse and dense shape rank known and > 1 + INFER_OK(op, "?;[2,3];?;?;[10,11]", "[d1_0]"); + + // num_points = 2, sparse shape rank known and > 1 + INFER_OK(op, "?;?;?;?;[10,11]", "[?]"); +} + +TEST(ModelOpsTest, UpdateModelV4_ShapeFn) { + ShapeInferenceTestOp op("UpdateModelV4"); + INFER_OK(op, "[1];?;?;?", ""); +} + TEST(ModelOpsTest, FeatureUsageCounts_ShapeFn) { ShapeInferenceTestOp op("FeatureUsageCounts"); INFER_OK(op, "[1]", "[?]"); diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index 260e03df262..b6d57ef9527 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -141,18 +141,6 @@ class FertileStatsDeserializeOp : public OpKernel { TensorForestParams param_proto_; }; -void TraverseTree(const DecisionTreeResource* tree_resource, - const std::unique_ptr<TensorDataSet>& data, int32 start, - int32 end, std::vector<int32>* leaf_ids, - std::vector<int32>* leaf_depths) { - for (int i = start; i < end; ++i) { - int32 depth; - const int32 leaf_id = tree_resource->TraverseTree(data, i, &depth); - (*leaf_ids)[i] = leaf_id; - (*leaf_depths)[i] = depth; - } -} - // Try to update a leaf's stats by acquiring its lock. If it can't be // acquired, put it in a waiting queue to come back to later and try the next // one. 
Once all leaf_ids have been visited, cycle through the waiting ids @@ -160,28 +148,27 @@ void TraverseTree(const DecisionTreeResource* tree_resource, void UpdateStats(FertileStatsResource* fertile_stats_resource, const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target, int num_targets, - const std::vector<int32>& leaf_ids, - const std::vector<int32>& leaf_depths, + const Tensor& leaf_ids_tensor, std::unordered_map<int32, std::unique_ptr<mutex>>* locks, mutex* set_lock, int32 start, int32 end, std::unordered_set<int32>* ready_to_split) { + const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>(); + // Stores leaf_id, leaf_depth, example_id for examples that are waiting // on another to finish. - std::queue<std::tuple<int32, int32, int32>> waiting; + std::queue<std::tuple<int32, int32>> waiting; int32 i = start; while (i < end || !waiting.empty()) { int32 leaf_id; - int32 leaf_depth; int32 example_id; bool was_waiting = false; if (i >= end) { - std::tie(leaf_id, leaf_depth, example_id) = waiting.front(); + std::tie(leaf_id, example_id) = waiting.front(); waiting.pop(); was_waiting = true; } else { - leaf_id = leaf_ids[i]; - leaf_depth = leaf_depths[i]; + leaf_id = leaf_ids(i); example_id = i; ++i; } @@ -190,14 +177,14 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource, leaf_lock->lock(); } else { if (!leaf_lock->try_lock()) { - waiting.emplace(leaf_id, leaf_depth, example_id); + waiting.emplace(leaf_id, example_id); continue; } } bool is_finished; fertile_stats_resource->AddExampleToStatsAndInitialize( - data, &target, {example_id}, leaf_id, leaf_depth, &is_finished); + data, &target, {example_id}, leaf_id, &is_finished); leaf_lock->unlock(); if (is_finished) { set_lock->lock(); @@ -214,8 +201,8 @@ void UpdateStatsCollated( const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target, int num_targets, const std::unordered_map<int32, std::vector<int>>& leaf_examples, - const std::vector<int32>& leaf_depths, mutex* 
set_lock, int32 start, - int32 end, std::unordered_set<int32>* ready_to_split) { + mutex* set_lock, int32 start, int32 end, + std::unordered_set<int32>* ready_to_split) { auto it = leaf_examples.begin(); std::advance(it, start); auto end_it = leaf_examples.begin(); @@ -224,8 +211,7 @@ void UpdateStatsCollated( int32 leaf_id = it->first; bool is_finished; fertile_stats_resource->AddExampleToStatsAndInitialize( - data, &target, it->second, leaf_id, leaf_depths[it->second[0]], - &is_finished); + data, &target, it->second, leaf_id, &is_finished); if (is_finished) { set_lock->lock(); ready_to_split->insert(leaf_id); @@ -261,6 +247,7 @@ class ProcessInputOp : public OpKernel { const Tensor& sparse_input_shape = context->input(5); const Tensor& input_labels = context->input(6); const Tensor& input_weights = context->input(7); + const Tensor& leaf_ids_tensor = context->input(8); data_set_->set_input_tensors(input_data, sparse_input_indices, sparse_input_values, sparse_input_shape); @@ -281,22 +268,7 @@ class ProcessInputOp : public OpKernel { auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; - // First find the leaf ids for each example. - std::vector<int32> leaf_ids(num_data); - - // The depth of the leaf for example i. - std::vector<int32> leaf_depths(num_data); - - const int64 costPerTraverse = 500; - auto traverse = [this, &leaf_ids, &leaf_depths, tree_resource, num_data]( - int64 start, int64 end) { - CHECK(start <= end); - CHECK(end <= num_data); - TraverseTree(tree_resource, data_set_, static_cast<int32>(start), - static_cast<int32>(end), &leaf_ids, &leaf_depths); - }; - Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, - traverse); + const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>(); // Create one mutex per leaf. 
We need to protect access to leaf pointers, // so instead of grouping examples by leaf, we spread examples out among @@ -306,10 +278,11 @@ class ProcessInputOp : public OpKernel { std::unordered_map<int32, std::vector<int>> leaf_examples; if (param_proto_.collate_examples()) { for (int i = 0; i < num_data; ++i) { - leaf_examples[leaf_ids[i]].push_back(i); + leaf_examples[leaf_ids(i)].push_back(i); } } else { - for (const int32 id : leaf_ids) { + for (int i = 0; i < num_data; ++i) { + const int32 id = leaf_ids(i); if (FindOrNull(locks, id) == nullptr) { // TODO(gilberth): Consider using a memory pool for these. locks[id] = std::unique_ptr<mutex>(new mutex); @@ -335,27 +308,26 @@ class ProcessInputOp : public OpKernel { // from a digits run on local desktop. Heuristics might be necessary // if it really matters that much. const int64 costPerUpdate = 1000; - auto update = [this, &target, &leaf_ids, &leaf_depths, &num_targets, + auto update = [this, &target, &leaf_ids_tensor, &num_targets, fertile_stats_resource, &locks, &set_lock, &ready_to_split, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); UpdateStats(fertile_stats_resource, data_set_, target, num_targets, - leaf_ids, leaf_depths, &locks, &set_lock, - static_cast<int32>(start), static_cast<int32>(end), - &ready_to_split); + leaf_ids_tensor, &locks, &set_lock, static_cast<int32>(start), + static_cast<int32>(end), &ready_to_split); }; - auto update_collated = [this, &target, &leaf_ids, &num_targets, - &leaf_depths, fertile_stats_resource, tree_resource, - &leaf_examples, &set_lock, &ready_to_split, + auto update_collated = [this, &target, &num_targets, fertile_stats_resource, + tree_resource, &leaf_examples, &set_lock, + &ready_to_split, num_leaves](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_leaves); UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set_, - target, num_targets, leaf_examples, leaf_depths, - &set_lock, static_cast<int32>(start), - 
static_cast<int32>(end), &ready_to_split); + target, num_targets, leaf_examples, &set_lock, + static_cast<int32>(start), static_cast<int32>(end), + &ready_to_split); }; if (param_proto_.collate_examples()) { @@ -411,7 +383,8 @@ class GrowTreeOp : public OpKernel { const int32 num_nodes = static_cast<int32>(finished_nodes.shape().dim_size(0)); - // TODO(gilberth): distribute this work over a number of threads. + // This op takes so little of the time for one batch that it isn't worth + // threading this. for (int i = 0; i < num_nodes && tree_resource->decision_tree().decision_tree().nodes_size() < @@ -420,16 +393,14 @@ class GrowTreeOp : public OpKernel { const int32 node = finished(i); std::unique_ptr<SplitCandidate> best(new SplitCandidate); int32 parent_depth; + // TODO(gilberth): Pushing these to an output would allow the complete + // decoupling of tree from resource. bool found = fertile_stats_resource->BestSplit(node, best.get(), &parent_depth); if (found) { std::vector<int32> new_children; tree_resource->SplitNode(node, best.get(), &new_children); fertile_stats_resource->Allocate(parent_depth, new_children); - fertile_stats_resource->set_leaf_stat(best->left_stats(), - new_children[0]); - fertile_stats_resource->set_leaf_stat(best->right_stats(), - new_children[1]); // We are done with best, so it is now safe to clear node. fertile_stats_resource->Clear(node); CHECK(tree_resource->get_mutable_tree_node(node)->has_leaf() == false); @@ -444,20 +415,17 @@ class GrowTreeOp : public OpKernel { TensorForestParams param_proto_; }; -void FinalizeLeaf(const LeafStat& leaf_stats, bool is_regression, - bool drop_final_class, +void FinalizeLeaf(bool is_regression, bool drop_final_class, const std::unique_ptr<LeafModelOperator>& leaf_op, decision_trees::Leaf* leaf) { - leaf_op->ExportModel(leaf_stats, leaf); - - // TODO(thomaswc): Move the rest of this into ExportModel. - // regression models are already stored in leaf in normalized form. 
if (is_regression) { return; } - float sum = leaf_stats.weight_sum(); + // TODO(gilberth): Calculate the leaf's sum. + float sum = 0; + LOG(FATAL) << "FinalizeTreeOp is disabled for now."; if (sum <= 0.0) { LOG(WARNING) << "Leaf with sum " << sum << " has stats " << leaf->ShortDebugString(); @@ -517,8 +485,7 @@ class FinalizeTreeOp : public OpKernel { ->mutable_decision_tree() ->mutable_nodes(i); if (node->has_leaf()) { - const auto& leaf_stats = fertile_stats_resource->leaf_stat(i); - FinalizeLeaf(leaf_stats, param_proto_.is_regression(), + FinalizeLeaf(param_proto_.is_regression(), param_proto_.drop_final_class(), model_op_, node->mutable_leaf()); } diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc index e5b86b05206..b3aa3a96f43 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc @@ -45,7 +45,7 @@ TEST(StatsOpsTest, GrowTreeV4_ShapeFn) { TEST(StatsOpsTest, ProcessInputV4_ShapeFn) { ShapeInferenceTestOp op("ProcessInputV4"); - INFER_OK(op, "[1];[1];?;?;?;?;?;?", "[?]"); + INFER_OK(op, "[1];[1];?;?;?;?;?;?;?", "[?]"); } TEST(StatsOpsTest, FinalizeTree_ShapeFn) { diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc index 165685ca53b..881e4339a75 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc @@ -18,6 +18,7 @@ namespace tensorflow { namespace tensorforest { using decision_trees::DecisionTree; +using decision_trees::Leaf; using decision_trees::TreeNode; int32 DecisionTreeResource::TraverseTree( @@ -51,13 +52,15 @@ void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best, new_children->push_back(newid); TreeNode* new_left = tree->add_nodes(); 
new_left->mutable_node_id()->set_value(newid++); - new_left->mutable_leaf(); + Leaf* left_leaf = new_left->mutable_leaf(); + model_op_->ExportModel(best->left_stats(), left_leaf); // right new_children->push_back(newid); TreeNode* new_right = tree->add_nodes(); new_right->mutable_node_id()->set_value(newid); - new_right->mutable_leaf(); + Leaf* right_leaf = new_right->mutable_leaf(); + model_op_->ExportModel(best->right_stats(), right_leaf); node->clear_leaf(); node->mutable_binary_node()->Swap(best->mutable_split()); @@ -72,7 +75,7 @@ void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best, void DecisionTreeResource::MaybeInitialize() { DecisionTree* tree = decision_tree_->mutable_decision_tree(); if (tree->nodes_size() == 0) { - tree->add_nodes()->mutable_leaf(); + model_op_->InitModel(tree->add_nodes()->mutable_leaf()); } else if (node_evaluators_.empty()) { // reconstruct evaluators for (const auto& node : tree->nodes()) { if (node.has_leaf()) { diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h index c8f09d8e075..438d3d817c4 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h @@ -31,8 +31,10 @@ namespace tensorforest { class DecisionTreeResource : public ResourceBase { public: // Constructor. 
- explicit DecisionTreeResource() - : decision_tree_(new decision_trees::Model()) {} + explicit DecisionTreeResource(const TensorForestParams& params) + : params_(params), decision_tree_(new decision_trees::Model()) { + model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_); + } string DebugString() override { return strings::StrCat("DecisionTree[size=", @@ -79,7 +81,9 @@ class DecisionTreeResource : public ResourceBase { private: mutex mu_; + const TensorForestParams params_; std::unique_ptr<decision_trees::Model> decision_tree_; + std::shared_ptr<LeafModelOperator> model_op_; std::vector<std::unique_ptr<DecisionNodeEvaluator>> node_evaluators_; }; diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc index 5c1b7454ae6..7f914aac319 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc @@ -20,14 +20,8 @@ namespace tensorflow { namespace tensorforest { void FertileStatsResource::AddExampleToStatsAndInitialize( - const std::unique_ptr<TensorDataSet>& input_data, - const InputTarget* target, const std::vector<int>& examples, - int32 node_id, int32 node_depth, bool* is_finished) { - // Set leaf's counts for calculating probabilities. - for (int example : examples) { - model_op_->UpdateModel(&leaf_stats_[node_id], target, example); - } - + const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target, + const std::vector<int>& examples, int32 node_id, bool* is_finished) { // Update stats or initialize if needed. 
if (collection_op_->IsInitialized(node_id)) { collection_op_->AddExample(input_data, target, examples, node_id); @@ -47,8 +41,6 @@ void FertileStatsResource::AddExampleToStatsAndInitialize( } void FertileStatsResource::AllocateNode(int32 node_id, int32 depth) { - leaf_stats_[node_id] = LeafStat(); - model_op_->InitModel(&leaf_stats_[node_id]); collection_op_->InitializeSlot(node_id, depth); } @@ -62,7 +54,6 @@ void FertileStatsResource::Allocate(int32 parent_depth, void FertileStatsResource::Clear(int32 node) { collection_op_->ClearSlot(node); - leaf_stats_.erase(node); } bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best, @@ -71,27 +62,16 @@ bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best, } void FertileStatsResource::MaybeInitialize() { - if (leaf_stats_.empty()) { - AllocateNode(0, 0); - } + collection_op_->MaybeInitialize(); } void FertileStatsResource::ExtractFromProto(const FertileStats& stats) { collection_op_ = SplitCollectionOperatorFactory::CreateSplitCollectionOperator(params_); collection_op_->ExtractFromProto(stats); - for (int i = 0; i < stats.node_to_slot_size(); ++i) { - const auto& slot = stats.node_to_slot(i); - leaf_stats_[slot.node_id()] = slot.leaf_stats(); - } } void FertileStatsResource::PackToProto(FertileStats* stats) const { - for (const auto& entry : leaf_stats_) { - auto* slot = stats->add_node_to_slot(); - *slot->mutable_leaf_stats() = entry.second; - slot->set_node_id(entry.first); - } collection_op_->PackToProto(stats); } } // namespace tensorforest diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h index 34ec945e846..dacf033d990 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h @@ -51,7 +51,6 @@ class FertileStatsResource : public ResourceBase { // Resets the resource and frees the 
proto. // Caller needs to hold the mutex lock while calling this. void Reset() { - leaf_stats_.clear(); } // Reset the stats for a node, but leave the leaf_stats intact. @@ -71,7 +70,7 @@ class FertileStatsResource : public ResourceBase { void AddExampleToStatsAndInitialize( const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target, const std::vector<int>& examples, - int32 node_id, int32 node_depth, bool* is_finished); + int32 node_id, bool* is_finished); // Allocate a fertile slot for each ready node, then new children up to // max_fertile_nodes_. @@ -85,19 +84,11 @@ class FertileStatsResource : public ResourceBase { // was found. bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth); - const LeafStat& leaf_stat(int32 node_id) { - return leaf_stats_[node_id]; - } - - void set_leaf_stat(const LeafStat& stat, int32 node_id) { - leaf_stats_[node_id] = stat; - } private: mutex mu_; std::shared_ptr<LeafModelOperator> model_op_; std::unique_ptr<SplitCollectionOperator> collection_op_; - std::unordered_map<int32, LeafStat> leaf_stats_; const TensorForestParams params_; void AllocateNode(int32 node_id, int32 depth); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc index 49e425642d1..d43c068e462 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc @@ -17,6 +17,8 @@ namespace tensorflow { namespace tensorforest { +using decision_trees::Leaf; + std::unique_ptr<LeafModelOperator> LeafModelOperatorFactory::CreateLeafModelOperator( const TensorForestParams& params) { @@ -50,24 +52,21 @@ float DenseClassificationLeafModelOperator::GetOutputValue( } void DenseClassificationLeafModelOperator::UpdateModel( - LeafStat* leaf, const InputTarget* target, - int example) const { + Leaf* leaf, const InputTarget* target, int example) const { const int32 
int_label = target->GetTargetAsClassIndex(example, 0); QCHECK_LT(int_label, params_.num_outputs()) << "Got label greater than indicated number of classes. Is " "params.num_classes set correctly?"; QCHECK_GE(int_label, 0); - auto* val = leaf->mutable_classification()->mutable_dense_counts() - ->mutable_value(int_label); + auto* val = leaf->mutable_vector()->mutable_value(int_label); + float weight = target->GetTargetWeight(example); val->set_float_value(val->float_value() + weight); - leaf->set_weight_sum(leaf->weight_sum() + weight); } -void DenseClassificationLeafModelOperator::InitModel( - LeafStat* leaf) const { +void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const { for (int i = 0; i < params_.num_outputs(); ++i) { - leaf->mutable_classification()->mutable_dense_counts()->add_value(); + leaf->mutable_vector()->add_value(); } } @@ -88,17 +87,15 @@ float SparseClassificationLeafModelOperator::GetOutputValue( } void SparseClassificationLeafModelOperator::UpdateModel( - LeafStat* leaf, const InputTarget* target, - int example) const { + Leaf* leaf, const InputTarget* target, int example) const { const int32 int_label = target->GetTargetAsClassIndex(example, 0); QCHECK_LT(int_label, params_.num_outputs()) << "Got label greater than indicated number of classes. 
Is " "params.num_classes set correctly?"; QCHECK_GE(int_label, 0); const float weight = target->GetTargetWeight(example); - leaf->set_weight_sum(leaf->weight_sum() + weight); - auto value_map = leaf->mutable_classification()->mutable_sparse_counts() - ->mutable_sparse_value(); + + auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value(); auto it = value_map->find(int_label); if (it == value_map->end()) { (*value_map)[int_label].set_float_value(weight); @@ -123,8 +120,8 @@ float SparseOrDenseClassificationLeafModelOperator::GetOutputValue( } void SparseOrDenseClassificationLeafModelOperator::UpdateModel( - LeafStat* leaf, const InputTarget* target, int example) const { - if (leaf->classification().has_dense_counts()) { + Leaf* leaf, const InputTarget* target, int example) const { + if (leaf->has_vector()) { return dense_->UpdateModel(leaf, target, example); } else { return sparse_->UpdateModel(leaf, target, example); @@ -146,15 +143,15 @@ float RegressionLeafModelOperator::GetOutputValue( return leaf.vector().value(o).float_value(); } -void RegressionLeafModelOperator::InitModel( - LeafStat* leaf) const { +void RegressionLeafModelOperator::InitModel(Leaf* leaf) const { for (int i = 0; i < params_.num_outputs(); ++i) { - leaf->mutable_regression()->mutable_mean_output()->add_value(); + leaf->mutable_vector()->add_value(); } } void RegressionLeafModelOperator::ExportModel( const LeafStat& stat, decision_trees::Leaf* leaf) const { + leaf->clear_vector(); for (int i = 0; i < params_.num_outputs(); ++i) { const float new_val = stat.regression().mean_output().value(i).float_value() / diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h index 8aadefc4033..946a648f22f 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h @@ -42,12 +42,11 @@ class LeafModelOperator { 
int32 o) const = 0; // Update the given Leaf's model with the given example. - virtual void UpdateModel(LeafStat* leaf, - const InputTarget* target, - int example) const = 0; + virtual void UpdateModel(decision_trees::Leaf* leaf, + const InputTarget* target, int example) const = 0; // Initialize an empty Leaf model. - virtual void InitModel(LeafStat* leaf) const = 0; + virtual void InitModel(decision_trees::Leaf* leaf) const = 0; virtual void ExportModel(const LeafStat& stat, decision_trees::Leaf* leaf) const = 0; @@ -65,10 +64,10 @@ class DenseClassificationLeafModelOperator : public LeafModelOperator { float GetOutputValue(const decision_trees::Leaf& leaf, int32 o) const override; - void UpdateModel(LeafStat* leaf, const InputTarget* target, + void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target, int example) const override; - void InitModel(LeafStat* leaf) const override; + void InitModel(decision_trees::Leaf* leaf) const override; void ExportModel(const LeafStat& stat, decision_trees::Leaf* leaf) const override; @@ -84,10 +83,10 @@ class SparseClassificationLeafModelOperator : public LeafModelOperator { float GetOutputValue(const decision_trees::Leaf& leaf, int32 o) const override; - void UpdateModel(LeafStat* leaf, const InputTarget* target, + void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target, int example) const override; - void InitModel(LeafStat* leaf) const override {} + void InitModel(decision_trees::Leaf* leaf) const override {} void ExportModel(const LeafStat& stat, decision_trees::Leaf* leaf) const override; @@ -103,10 +102,10 @@ class SparseOrDenseClassificationLeafModelOperator : public LeafModelOperator { float GetOutputValue(const decision_trees::Leaf& leaf, int32 o) const override; - void UpdateModel(LeafStat* leaf, const InputTarget* target, + void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target, int example) const override; - void InitModel(LeafStat* leaf) const override {} + void 
InitModel(decision_trees::Leaf* leaf) const override {} void ExportModel(const LeafStat& stat, decision_trees::Leaf* leaf) const override; @@ -129,10 +128,10 @@ class RegressionLeafModelOperator : public LeafModelOperator { // updating model and just using the seeded values. Can add this in // with additional_data, though protobuf::Any is slow. Maybe make it // optional. Maybe make any update optional. - void UpdateModel(LeafStat* leaf, const InputTarget* target, + void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target, int example) const override {} - void InitModel(LeafStat* leaf) const override; + void InitModel(decision_trees::Leaf* leaf) const override; void ExportModel(const LeafStat& stat, decision_trees::Leaf* leaf) const override; diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc index 35268d15d3e..ffd92c01f9a 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc @@ -63,12 +63,8 @@ constexpr char kRegressionStatProto[] = "}"; void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) { - std::unique_ptr<LeafStat> leaf(new LeafStat); - op->InitModel(leaf.get()); - Leaf l; - op->ExportModel(*leaf, &l); - + op->InitModel(&l); // Make sure it was initialized correctly. for (int i = 0; i < kNumClasses; ++i) { EXPECT_EQ(op->GetOutputValue(l, i), 0); @@ -80,11 +76,10 @@ void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) { new TestableInputTarget(labels, weights, 1)); // Update and check value. 
- op->UpdateModel(leaf.get(), target.get(), 0); - op->UpdateModel(leaf.get(), target.get(), 1); - op->UpdateModel(leaf.get(), target.get(), 2); + op->UpdateModel(&l, target.get(), 0); + op->UpdateModel(&l, target.get(), 1); + op->UpdateModel(&l, target.get(), 2); - op->ExportModel(*leaf, &l); EXPECT_FLOAT_EQ(op->GetOutputValue(l, 1), 3.4); } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc index 632408fd718..ccc412600c7 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc @@ -71,13 +71,13 @@ void SplitCollectionOperator::ExtractFromProto( } void SplitCollectionOperator::PackToProto(FertileStats* stats_proto) const { - for (int i = 0; i < stats_proto->node_to_slot_size(); ++i) { - auto* new_slot = stats_proto->mutable_node_to_slot(i); - const auto& stats = stats_.at(new_slot->node_id()); + for (const auto& pair : stats_) { + auto* new_slot = stats_proto->add_node_to_slot(); + new_slot->set_node_id(pair.first); if (params_.checkpoint_stats()) { - stats->PackToProto(new_slot); + pair.second->PackToProto(new_slot); } - new_slot->set_depth(stats->depth()); + new_slot->set_depth(pair.second->depth()); } } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h index 6990e82678b..6c21c0bd344 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h @@ -62,6 +62,14 @@ class SplitCollectionOperator { // Create a new GrowStats for the given node id and initialize it. virtual void InitializeSlot(int32 node_id, int32 depth); + // Called when the resource is deserialized, possibly needing an + // initialization. 
+ virtual void MaybeInitialize() { + if (stats_.empty()) { + InitializeSlot(0, 0); + } + } + // Perform any necessary cleanup for any tracked state for the slot. virtual void ClearSlot(int32 node_id) { stats_.erase(node_id); diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc index 168f079f523..1227a70a2e9 100644 --- a/tensorflow/contrib/tensor_forest/ops/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc @@ -115,6 +115,58 @@ sparse_input_shape: The shape tensor from the SparseTensor input. predictions: `predictions[i][j]` is the probability that input i is class j. )doc"); +REGISTER_OP("TraverseTreeV4") + .Attr("input_spec: string") + .Attr("params: string") + .Input("tree_handle: resource") + .Input("input_data: float") + .Input("sparse_input_indices: int64") + .Input("sparse_input_values: float") + .Input("sparse_input_shape: int64") + .Output("leaf_ids: int32") + .SetShapeFn([](InferenceContext* c) { + DimensionHandle num_points = c->UnknownDim(); + + if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0 && + c->Value(c->Dim(c->input(1), 0)) > 0) { + num_points = c->Dim(c->input(1), 0); + } + + c->set_output(0, c->Vector(num_points)); + return Status::OK(); + }) + .Doc(R"doc( +Outputs the leaf ids for the given input data. + +params: A serialized TensorForestParams proto. +tree_handle: The handle to the tree. +input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` + gives the j-th feature of the i-th input. +sparse_input_indices: The indices tensor from the SparseTensor input. +sparse_input_values: The values tensor from the SparseTensor input. +sparse_input_shape: The shape tensor from the SparseTensor input. +leaf_ids: `leaf_ids[i]` is the leaf id for input i. 
+)doc"); + +REGISTER_OP("UpdateModelV4") + .Attr("params: string") + .Input("tree_handle: resource") + .Input("leaf_ids: int32") + .Input("input_labels: float") + .Input("input_weights: float") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc(R"doc( +Updates the given leaves for each example with the new labels. + +params: A serialized TensorForestParams proto. +tree_handle: The handle to the tree. +leaf_ids: `leaf_ids[i]` is the leaf id for input i. +input_labels: The training batch's labels as a 1 or 2-d tensor. + 'input_labels[i][j]' gives the j-th label/target for the i-th input. +input_weights: The training batch's eample weights as a 1-d tensor. + 'input_weights[i]' gives the weight for the i-th input. +)doc"); + REGISTER_OP("FeatureUsageCounts") .Attr("params: string") .Input("tree_handle: resource") diff --git a/tensorflow/contrib/tensor_forest/ops/stats_ops.cc b/tensorflow/contrib/tensor_forest/ops/stats_ops.cc index 96527497689..e8b5c5d8a6e 100644 --- a/tensorflow/contrib/tensor_forest/ops/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/ops/stats_ops.cc @@ -98,6 +98,7 @@ REGISTER_OP("ProcessInputV4") .Input("sparse_input_shape: int64") .Input("input_labels: float") .Input("input_weights: float") + .Input("leaf_ids: int32") .Output("finished_nodes: int32") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(c->UnknownDim())); @@ -122,6 +123,7 @@ input_weights: The training batch's eample weights as a 1-d tensor. 'input_weights[i]' gives the weight for the i-th input. finished_nodes: A 1-d tensor of node ids that have finished and are ready to grow. +leaf_ids: `leaf_ids[i]` is the leaf id for input i. 
)doc"); REGISTER_OP("FinalizeTree") diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py index 4c7218305b5..d240e2f6dec 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py @@ -18,12 +18,13 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.tensor_forest.python.ops import gen_model_ops -from tensorflow.contrib.tensor_forest.python.ops import stats_ops # pylint: disable=unused-import from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import feature_usage_counts +from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import traverse_tree_v4 from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_predictions_v4 from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_size +from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import update_model_v4 # pylint: enable=unused-import from tensorflow.contrib.util import loader @@ -59,13 +60,7 @@ class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject): name: the name to save the tree variable under. """ self.params = params - deps = [] - if stats_handle is not None: - deps.append(stats_ops.finalize_tree( - tree_handle, stats_handle, - params=params.serialized_params_proto)) - with ops.control_dependencies(deps): - tensor = gen_model_ops.tree_serialize(tree_handle) + tensor = gen_model_ops.tree_serialize(tree_handle) # slice_spec is useful for saving a slice from a variable. # It's not meaningful the tree variable. So we just pass an empty value. 
slice_spec = "" diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py index 7e6f00a13df..8198c228dd6 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py @@ -27,6 +27,7 @@ from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.contrib.tensor_forest.python.ops import model_ops from tensorflow.contrib.tensor_forest.python.ops import stats_ops from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import tf_logging as logging @@ -240,6 +241,22 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs): if input_data is None: input_data = [] + leaf_ids = model_ops.traverse_tree_v4( + self.variables.tree, + input_data, + sparse_indices, + sparse_values, + sparse_shape, + input_spec=data_spec.SerializeToString(), + params=self.params.serialized_params_proto) + + update_model = model_ops.update_model_v4( + self.variables.tree, + leaf_ids, + input_labels, + input_weights, + params=self.params.serialized_params_proto) + finished_nodes = stats_ops.process_input_v4( self.variables.tree, self.variables.stats, @@ -249,13 +266,17 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs): sparse_shape, input_labels, input_weights, + leaf_ids, input_spec=data_spec.SerializeToString(), random_seed=random_seed, params=self.params.serialized_params_proto) - return stats_ops.grow_tree_v4(self.variables.tree, self.variables.stats, - finished_nodes, - params=self.params.serialized_params_proto) + with ops.control_dependencies([update_model]): + return stats_ops.grow_tree_v4( + self.variables.tree, + self.variables.stats, + finished_nodes, + params=self.params.serialized_params_proto) def inference_graph(self, input_data, data_spec, sparse_features=None): sparse_indices = 
[]