From 786bf6cd656d0d67e56bf50047ff116bae884b9e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 12 Jul 2017 07:37:31 -0700
Subject: [PATCH] Refactor some of TensorForest V4 to make the tree model valid
 during training time, instead of only after FinalizeTreeOp.

PiperOrigin-RevId: 161663317
---
 .../tensor_forest/kernels/model_ops.cc        | 183 ++++++++++++++++--
 .../tensor_forest/kernels/model_ops_test.cc   |  27 +++
 .../tensor_forest/kernels/stats_ops.cc        | 101 ++++------
 .../tensor_forest/kernels/stats_ops_test.cc   |   2 +-
 .../kernels/v4/decision-tree-resource.cc      |   9 +-
 .../kernels/v4/decision-tree-resource.h       |   8 +-
 .../kernels/v4/fertile-stats-resource.cc      |  26 +--
 .../kernels/v4/fertile-stats-resource.h       |  11 +-
 .../kernels/v4/leaf_model_operators.cc        |  33 ++--
 .../kernels/v4/leaf_model_operators.h         |  23 ++-
 .../kernels/v4/leaf_model_operators_test.cc   |  13 +-
 .../kernels/v4/split_collection_operators.cc  |  10 +-
 .../kernels/v4/split_collection_operators.h   |   8 +
 .../contrib/tensor_forest/ops/model_ops.cc    |  52 +++++
 .../contrib/tensor_forest/ops/stats_ops.cc    |   2 +
 .../tensor_forest/python/ops/model_ops.py     |  11 +-
 .../tensor_forest/python/tensor_forest_v4.py  |  27 ++-
 17 files changed, 374 insertions(+), 172 deletions(-)

diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
index 0f92f05e2c0..221f8d969bc 100644
--- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // =============================================================================
+#include <functional>
 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
 #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
@@ -26,6 +27,7 @@
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
 
 namespace tensorflow {
 namespace tensorforest {
@@ -46,7 +48,7 @@ class CreateTreeVariableOp : public OpKernel {
     OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()),
                 errors::InvalidArgument("Tree config must be a scalar."));
 
-    auto* result = new DecisionTreeResource();
+    auto* result = new DecisionTreeResource(param_proto_);
     if (!ParseProtoUnlimited(result->mutable_decision_tree(),
                              tree_config_t->scalar<string>()())) {
       result->Unref();
@@ -142,6 +144,16 @@ class TreeSizeOp : public OpKernel {
   }
 };
 
+void TraverseTree(const DecisionTreeResource* tree_resource,
+                  const std::unique_ptr<TensorDataSet>& data, int32 start,
+                  int32 end,
+                  const std::function<void(int32, int32)>& set_leaf_id) {
+  for (int i = start; i < end; ++i) {
+    const int32 id = tree_resource->TraverseTree(data, i, nullptr);
+    set_leaf_id(i, id);
+  }
+}
+
 // Op for tree inference.
 class TreePredictionsV4Op : public OpKernel {
  public:
@@ -176,22 +188,49 @@ class TreePredictionsV4Op : public OpKernel {
     mutex_lock l(*decision_tree_resource->get_mutex());
     core::ScopedUnref unref_me(decision_tree_resource);
 
+    const int num_data = data_set_->NumItems();
+    const int32 num_outputs = param_proto_.num_outputs();
+
     Tensor* output_predictions = nullptr;
     TensorShape output_shape;
-    output_shape.AddDim(data_set_->NumItems());
-    output_shape.AddDim(param_proto_.num_outputs());
+    output_shape.AddDim(num_data);
+    output_shape.AddDim(num_outputs);
     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
                                                      &output_predictions));
+    TTypes<float, 2>::Tensor out = output_predictions->tensor<float, 2>();
 
-    auto out = output_predictions->tensor<float, 2>();
-    for (int i = 0; i < data_set_->NumItems(); ++i) {
-      const int32 leaf_id =
-          decision_tree_resource->TraverseTree(data_set_, i, nullptr);
-      const decision_trees::Leaf& leaf =
-          decision_tree_resource->get_leaf(leaf_id);
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    int num_threads = worker_threads->num_threads;
+    const int64 costPerTraverse = 500;
+    auto traverse = [this, &out, decision_tree_resource, num_data](int64 start,
+                                                                   int64 end) {
+      CHECK(start <= end);
+      CHECK(end <= num_data);
+      TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
+                   static_cast<int32>(end),
+                   std::bind(&TreePredictionsV4Op::set_output_value, this,
+                             std::placeholders::_1, std::placeholders::_2,
+                             decision_tree_resource, &out));
+    };
+    Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
+          traverse);
+  }
+
+  void set_output_value(int32 i, int32 id,
+                        DecisionTreeResource* decision_tree_resource,
+                        TTypes<float, 2>::Tensor* out) {
+    const decision_trees::Leaf& leaf = decision_tree_resource->get_leaf(id);
+
+    float sum = 0;
+    for (int j = 0; j < param_proto_.num_outputs(); ++j) {
+      const float count = model_op_->GetOutputValue(leaf, j);
+      (*out)(i, j) = count;
+      sum += count;
+    }
+
+    if (!param_proto_.is_regression() && sum > 0 && sum != 1) {
       for (int j = 0; j < param_proto_.num_outputs(); ++j) {
-        const float count = model_op_->GetOutputValue(leaf, j);
-        out(i, j) = count;
+        (*out)(i, j) /= sum;
       }
     }
   }
@@ -203,6 +242,122 @@ class TreePredictionsV4Op : public OpKernel {
   TensorForestParams param_proto_;
 };
 
+// Outputs leaf ids for the given examples.
+class TraverseTreeV4Op : public OpKernel {
+ public:
+  explicit TraverseTreeV4Op(OpKernelConstruction* context) : OpKernel(context) {
+    string serialized_params;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+    ParseProtoUnlimited(&param_proto_, serialized_params);
+
+    string serialized_proto;
+    OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
+    input_spec_.ParseFromString(serialized_proto);
+
+    data_set_ =
+        std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_data = context->input(1);
+    const Tensor& sparse_input_indices = context->input(2);
+    const Tensor& sparse_input_values = context->input(3);
+    const Tensor& sparse_input_shape = context->input(4);
+
+    data_set_->set_input_tensors(input_data, sparse_input_indices,
+                                 sparse_input_values, sparse_input_shape);
+
+    DecisionTreeResource* decision_tree_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &decision_tree_resource));
+    mutex_lock l(*decision_tree_resource->get_mutex());
+    core::ScopedUnref unref_me(decision_tree_resource);
+
+    const int num_data = data_set_->NumItems();
+
+    Tensor* output_predictions = nullptr;
+    TensorShape output_shape;
+    output_shape.AddDim(num_data);
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
+                                                     &output_predictions));
+
+    auto leaf_ids = output_predictions->tensor<int32, 1>();
+
+    auto set_leaf_ids = [&leaf_ids](int32 i, int32 id) { leaf_ids(i) = id; };
+
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    int num_threads = worker_threads->num_threads;
+    const int64 costPerTraverse = 500;
+    auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data](
+                        int64 start, int64 end) {
+      CHECK(start <= end);
+      CHECK(end <= num_data);
+      TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
+                   static_cast<int32>(end), set_leaf_ids);
+    };
+    Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
+          traverse);
+  }
+
+ private:
+  tensorforest::TensorForestDataSpec input_spec_;
+  std::unique_ptr<TensorDataSet> data_set_;
+  TensorForestParams param_proto_;
+};
+
+// Update the given leaf models using the batch of labels.
+class UpdateModelV4Op : public OpKernel {
+ public:
+  explicit UpdateModelV4Op(OpKernelConstruction* context) : OpKernel(context) {
+    string serialized_params;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+    ParseProtoUnlimited(&param_proto_, serialized_params);
+
+    model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& leaf_ids = context->input(1);
+    const Tensor& input_labels = context->input(2);
+    const Tensor& input_weights = context->input(3);
+
+    DecisionTreeResource* decision_tree_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &decision_tree_resource));
+    mutex_lock l(*decision_tree_resource->get_mutex());
+    core::ScopedUnref unref_me(decision_tree_resource);
+
+    const int num_data = input_labels.shape().dim_size(0);
+    const int32 label_dim =
+        input_labels.shape().dims() <= 1
+            ? 0
+            : static_cast<int>(input_labels.shape().dim_size(1));
+    const int32 num_targets =
+        param_proto_.is_regression() ? (std::max(1, label_dim)) : 1;
+
+    TensorInputTarget target(input_labels, input_weights, num_targets);
+
+    // TODO(gilberth): Make this thread safe and multi-thread.
+    UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
+  }
+
+  void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target,
+                   int32 start, int32 end,
+                   DecisionTreeResource* decision_tree_resource) {
+    const auto leaves = leaf_ids.unaligned_flat<int32>();
+    for (int i = start; i < end; ++i) {
+      model_op_->UpdateModel(
+          decision_tree_resource->get_mutable_tree_node(leaves(i))
+              ->mutable_leaf(),
+          &target, i);
+    }
+  }
+
+ private:
+  std::unique_ptr<LeafModelOperator> model_op_;
+  TensorForestParams param_proto_;
+};
+
 // Op for getting feature usage counts.
 class FeatureUsageCountsOp : public OpKernel {
  public:
@@ -286,8 +441,14 @@ REGISTER_KERNEL_BUILDER(Name("TreeSize").Device(DEVICE_CPU), TreeSizeOp);
 REGISTER_KERNEL_BUILDER(Name("TreePredictionsV4").Device(DEVICE_CPU),
                         TreePredictionsV4Op);
 
+REGISTER_KERNEL_BUILDER(Name("TraverseTreeV4").Device(DEVICE_CPU),
+                        TraverseTreeV4Op);
+
 REGISTER_KERNEL_BUILDER(Name("FeatureUsageCounts").Device(DEVICE_CPU),
                         FeatureUsageCountsOp);
 
+REGISTER_KERNEL_BUILDER(Name("UpdateModelV4").Device(DEVICE_CPU),
+                        UpdateModelV4Op);
+
 }  // namespace tensorforest
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc
index cece61a54c5..0fdab8e6e0b 100644
--- a/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc
@@ -64,6 +64,33 @@ TEST(ModelOpsTest, TreePredictionsV4_ShapeFn) {
   INFER_OK(op, "?;?;?;?;[10,11]", "[?,?]");
 }
 
+TEST(ModelOpsTest, TraverseTreeV4_ShapeFn) {
+  ShapeInferenceTestOp op("TraverseTreeV4");
+  TF_ASSERT_OK(NodeDefBuilder("test", "TraverseTreeV4")
+                   .Input("a", 0, DT_RESOURCE)
+                   .Input("b", 1, DT_FLOAT)
+                   .Input("c", 2, DT_INT64)
+                   .Input("d", 3, DT_FLOAT)
+                   .Input("e", 5, DT_INT64)
+                   .Attr("input_spec", "")
+                   .Attr("params", "")
+                   .Finalize(&op.node_def));
+
+  // num_points = 2, sparse shape not known
+  INFER_OK(op, "?;[2,3];?;?;?", "[d1_0]");
+
+  // num_points = 2, sparse and dense shape rank known and > 1
+  INFER_OK(op, "?;[2,3];?;?;[10,11]", "[d1_0]");
+
+  // num_points = 2, sparse shape rank known and > 1
+  INFER_OK(op, "?;?;?;?;[10,11]", "[?]");
+}
+
+TEST(ModelOpsTest, UpdateModelV4_ShapeFn) {
+  ShapeInferenceTestOp op("UpdateModelV4");
+  INFER_OK(op, "[1];?;?;?", "");
+}
+
 TEST(ModelOpsTest, FeatureUsageCounts_ShapeFn) {
   ShapeInferenceTestOp op("FeatureUsageCounts");
   INFER_OK(op, "[1]", "[?]");
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
index 260e03df262..b6d57ef9527 100644
--- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
@@ -141,18 +141,6 @@ class FertileStatsDeserializeOp : public OpKernel {
   TensorForestParams param_proto_;
 };
 
-void TraverseTree(const DecisionTreeResource* tree_resource,
-                  const std::unique_ptr<TensorDataSet>& data, int32 start,
-                  int32 end, std::vector<int32>* leaf_ids,
-                  std::vector<int32>* leaf_depths) {
-  for (int i = start; i < end; ++i) {
-    int32 depth;
-    const int32 leaf_id = tree_resource->TraverseTree(data, i, &depth);
-    (*leaf_ids)[i] = leaf_id;
-    (*leaf_depths)[i] = depth;
-  }
-}
-
 // Try to update a leaf's stats by acquiring its lock.  If it can't be
 // acquired, put it in a waiting queue to come back to later and try the next
 // one.  Once all leaf_ids have been visited, cycle through the waiting ids
@@ -160,28 +148,27 @@ void TraverseTree(const DecisionTreeResource* tree_resource,
 void UpdateStats(FertileStatsResource* fertile_stats_resource,
                  const std::unique_ptr<TensorDataSet>& data,
                  const TensorInputTarget& target, int num_targets,
-                 const std::vector<int32>& leaf_ids,
-                 const std::vector<int32>& leaf_depths,
+                 const Tensor& leaf_ids_tensor,
                  std::unordered_map<int32, std::unique_ptr<mutex>>* locks,
                  mutex* set_lock, int32 start, int32 end,
                  std::unordered_set<int32>* ready_to_split) {
+  const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
+
   // Stores leaf_id, leaf_depth, example_id for examples that are waiting
   // on another to finish.
-  std::queue<std::tuple<int32, int32, int32>> waiting;
+  std::queue<std::tuple<int32, int32>> waiting;
 
   int32 i = start;
   while (i < end || !waiting.empty()) {
     int32 leaf_id;
-    int32 leaf_depth;
     int32 example_id;
     bool was_waiting = false;
     if (i >= end) {
-      std::tie(leaf_id, leaf_depth, example_id) = waiting.front();
+      std::tie(leaf_id, example_id) = waiting.front();
       waiting.pop();
       was_waiting = true;
     } else {
-      leaf_id = leaf_ids[i];
-      leaf_depth = leaf_depths[i];
+      leaf_id = leaf_ids(i);
       example_id = i;
       ++i;
     }
@@ -190,14 +177,14 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
       leaf_lock->lock();
     } else {
       if (!leaf_lock->try_lock()) {
-        waiting.emplace(leaf_id, leaf_depth, example_id);
+        waiting.emplace(leaf_id, example_id);
         continue;
       }
     }
 
     bool is_finished;
     fertile_stats_resource->AddExampleToStatsAndInitialize(
-        data, &target, {example_id}, leaf_id, leaf_depth, &is_finished);
+        data, &target, {example_id}, leaf_id, &is_finished);
     leaf_lock->unlock();
     if (is_finished) {
       set_lock->lock();
@@ -214,8 +201,8 @@ void UpdateStatsCollated(
     const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
     int num_targets,
     const std::unordered_map<int32, std::vector<int>>& leaf_examples,
-    const std::vector<int32>& leaf_depths, mutex* set_lock, int32 start,
-    int32 end, std::unordered_set<int32>* ready_to_split) {
+    mutex* set_lock, int32 start, int32 end,
+    std::unordered_set<int32>* ready_to_split) {
   auto it = leaf_examples.begin();
   std::advance(it, start);
   auto end_it = leaf_examples.begin();
@@ -224,8 +211,7 @@ void UpdateStatsCollated(
     int32 leaf_id = it->first;
     bool is_finished;
     fertile_stats_resource->AddExampleToStatsAndInitialize(
-        data, &target, it->second, leaf_id, leaf_depths[it->second[0]],
-        &is_finished);
+        data, &target, it->second, leaf_id, &is_finished);
     if (is_finished) {
       set_lock->lock();
       ready_to_split->insert(leaf_id);
@@ -261,6 +247,7 @@ class ProcessInputOp : public OpKernel {
     const Tensor& sparse_input_shape = context->input(5);
     const Tensor& input_labels = context->input(6);
     const Tensor& input_weights = context->input(7);
+    const Tensor& leaf_ids_tensor = context->input(8);
 
     data_set_->set_input_tensors(input_data, sparse_input_indices,
                                  sparse_input_values, sparse_input_shape);
@@ -281,22 +268,7 @@ class ProcessInputOp : public OpKernel {
     auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
     int num_threads = worker_threads->num_threads;
 
-    // First find the leaf ids for each example.
-    std::vector<int32> leaf_ids(num_data);
-
-    // The depth of the leaf for example i.
-    std::vector<int32> leaf_depths(num_data);
-
-    const int64 costPerTraverse = 500;
-    auto traverse = [this, &leaf_ids, &leaf_depths, tree_resource, num_data](
-                        int64 start, int64 end) {
-      CHECK(start <= end);
-      CHECK(end <= num_data);
-      TraverseTree(tree_resource, data_set_, static_cast<int32>(start),
-                   static_cast<int32>(end), &leaf_ids, &leaf_depths);
-    };
-    Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
-          traverse);
+    const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
 
     // Create one mutex per leaf. We need to protect access to leaf pointers,
     // so instead of grouping examples by leaf, we spread examples out among
@@ -306,10 +278,11 @@ class ProcessInputOp : public OpKernel {
     std::unordered_map<int32, std::vector<int>> leaf_examples;
     if (param_proto_.collate_examples()) {
       for (int i = 0; i < num_data; ++i) {
-        leaf_examples[leaf_ids[i]].push_back(i);
+        leaf_examples[leaf_ids(i)].push_back(i);
       }
     } else {
-      for (const int32 id : leaf_ids) {
+      for (int i = 0; i < num_data; ++i) {
+        const int32 id = leaf_ids(i);
         if (FindOrNull(locks, id) == nullptr) {
           // TODO(gilberth): Consider using a memory pool for these.
           locks[id] = std::unique_ptr<mutex>(new mutex);
@@ -335,27 +308,26 @@ class ProcessInputOp : public OpKernel {
     // from a digits run on local desktop.  Heuristics might be necessary
     // if it really matters that much.
     const int64 costPerUpdate = 1000;
-    auto update = [this, &target, &leaf_ids, &leaf_depths, &num_targets,
+    auto update = [this, &target, &leaf_ids_tensor, &num_targets,
                    fertile_stats_resource, &locks, &set_lock, &ready_to_split,
                    num_data](int64 start, int64 end) {
       CHECK(start <= end);
       CHECK(end <= num_data);
       UpdateStats(fertile_stats_resource, data_set_, target, num_targets,
-                  leaf_ids, leaf_depths, &locks, &set_lock,
-                  static_cast<int32>(start), static_cast<int32>(end),
-                  &ready_to_split);
+                  leaf_ids_tensor, &locks, &set_lock, static_cast<int32>(start),
+                  static_cast<int32>(end), &ready_to_split);
     };
 
-    auto update_collated = [this, &target, &leaf_ids, &num_targets,
-                            &leaf_depths, fertile_stats_resource, tree_resource,
-                            &leaf_examples, &set_lock, &ready_to_split,
+    auto update_collated = [this, &target, &num_targets, fertile_stats_resource,
+                            tree_resource, &leaf_examples, &set_lock,
+                            &ready_to_split,
                             num_leaves](int64 start, int64 end) {
       CHECK(start <= end);
       CHECK(end <= num_leaves);
       UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set_,
-                          target, num_targets, leaf_examples, leaf_depths,
-                          &set_lock, static_cast<int32>(start),
-                          static_cast<int32>(end), &ready_to_split);
+                          target, num_targets, leaf_examples, &set_lock,
+                          static_cast<int32>(start), static_cast<int32>(end),
+                          &ready_to_split);
     };
 
     if (param_proto_.collate_examples()) {
@@ -411,7 +383,8 @@ class GrowTreeOp : public OpKernel {
     const int32 num_nodes =
         static_cast<int32>(finished_nodes.shape().dim_size(0));
 
-    // TODO(gilberth): distribute this work over a number of threads.
+    // This op takes so little of the time for one batch that it isn't worth
+    // threading this.
     for (int i = 0;
          i < num_nodes &&
          tree_resource->decision_tree().decision_tree().nodes_size() <
@@ -420,16 +393,14 @@ class GrowTreeOp : public OpKernel {
       const int32 node = finished(i);
       std::unique_ptr<SplitCandidate> best(new SplitCandidate);
       int32 parent_depth;
+      // TODO(gilberth): Pushing these to an output would allow the complete
+      // decoupling of tree from resource.
       bool found =
           fertile_stats_resource->BestSplit(node, best.get(), &parent_depth);
       if (found) {
         std::vector<int32> new_children;
         tree_resource->SplitNode(node, best.get(), &new_children);
         fertile_stats_resource->Allocate(parent_depth, new_children);
-        fertile_stats_resource->set_leaf_stat(best->left_stats(),
-                                              new_children[0]);
-        fertile_stats_resource->set_leaf_stat(best->right_stats(),
-                                              new_children[1]);
         // We are done with best, so it is now safe to clear node.
         fertile_stats_resource->Clear(node);
         CHECK(tree_resource->get_mutable_tree_node(node)->has_leaf() == false);
@@ -444,20 +415,17 @@ class GrowTreeOp : public OpKernel {
   TensorForestParams param_proto_;
 };
 
-void FinalizeLeaf(const LeafStat& leaf_stats, bool is_regression,
-                  bool drop_final_class,
+void FinalizeLeaf(bool is_regression, bool drop_final_class,
                   const std::unique_ptr<LeafModelOperator>& leaf_op,
                   decision_trees::Leaf* leaf) {
-  leaf_op->ExportModel(leaf_stats, leaf);
-
-  // TODO(thomaswc): Move the rest of this into ExportModel.
-
   // regression models are already stored in leaf in normalized form.
   if (is_regression) {
     return;
   }
 
-  float sum = leaf_stats.weight_sum();
+  // TODO(gilberth): Calculate the leaf's sum.
+  float sum = 0;
+  LOG(FATAL) << "FinalizeTreeOp is disabled for now.";
   if (sum <= 0.0) {
     LOG(WARNING) << "Leaf with sum " << sum << " has stats "
                  << leaf->ShortDebugString();
@@ -517,8 +485,7 @@ class FinalizeTreeOp : public OpKernel {
                        ->mutable_decision_tree()
                        ->mutable_nodes(i);
       if (node->has_leaf()) {
-        const auto& leaf_stats = fertile_stats_resource->leaf_stat(i);
-        FinalizeLeaf(leaf_stats, param_proto_.is_regression(),
+        FinalizeLeaf(param_proto_.is_regression(),
                      param_proto_.drop_final_class(), model_op_,
                      node->mutable_leaf());
       }
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc
index e5b86b05206..b3aa3a96f43 100644
--- a/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc
@@ -45,7 +45,7 @@ TEST(StatsOpsTest, GrowTreeV4_ShapeFn) {
 
 TEST(StatsOpsTest, ProcessInputV4_ShapeFn) {
   ShapeInferenceTestOp op("ProcessInputV4");
-  INFER_OK(op, "[1];[1];?;?;?;?;?;?", "[?]");
+  INFER_OK(op, "[1];[1];?;?;?;?;?;?;?", "[?]");
 }
 
 TEST(StatsOpsTest, FinalizeTree_ShapeFn) {
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc
index 165685ca53b..881e4339a75 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc
@@ -18,6 +18,7 @@ namespace tensorflow {
 namespace tensorforest {
 
 using decision_trees::DecisionTree;
+using decision_trees::Leaf;
 using decision_trees::TreeNode;
 
 int32 DecisionTreeResource::TraverseTree(
@@ -51,13 +52,15 @@ void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best,
   new_children->push_back(newid);
   TreeNode* new_left = tree->add_nodes();
   new_left->mutable_node_id()->set_value(newid++);
-  new_left->mutable_leaf();
+  Leaf* left_leaf = new_left->mutable_leaf();
+  model_op_->ExportModel(best->left_stats(), left_leaf);
 
   // right
   new_children->push_back(newid);
   TreeNode* new_right = tree->add_nodes();
   new_right->mutable_node_id()->set_value(newid);
-  new_right->mutable_leaf();
+  Leaf* right_leaf = new_right->mutable_leaf();
+  model_op_->ExportModel(best->right_stats(), right_leaf);
 
   node->clear_leaf();
   node->mutable_binary_node()->Swap(best->mutable_split());
@@ -72,7 +75,7 @@ void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best,
 void DecisionTreeResource::MaybeInitialize() {
   DecisionTree* tree = decision_tree_->mutable_decision_tree();
   if (tree->nodes_size() == 0) {
-    tree->add_nodes()->mutable_leaf();
+    model_op_->InitModel(tree->add_nodes()->mutable_leaf());
   } else if (node_evaluators_.empty()) {  // reconstruct evaluators
     for (const auto& node : tree->nodes()) {
       if (node.has_leaf()) {
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
index c8f09d8e075..438d3d817c4 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
@@ -31,8 +31,10 @@ namespace tensorforest {
 class DecisionTreeResource : public ResourceBase {
  public:
   // Constructor.
-  explicit DecisionTreeResource()
-      : decision_tree_(new decision_trees::Model()) {}
+  explicit DecisionTreeResource(const TensorForestParams& params)
+      : params_(params), decision_tree_(new decision_trees::Model()) {
+    model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_);
+  }
 
   string DebugString() override {
     return strings::StrCat("DecisionTree[size=",
@@ -79,7 +81,9 @@ class DecisionTreeResource : public ResourceBase {
 
  private:
   mutex mu_;
+  const TensorForestParams params_;
   std::unique_ptr<decision_trees::Model> decision_tree_;
+  std::shared_ptr<LeafModelOperator> model_op_;
   std::vector<std::unique_ptr<DecisionNodeEvaluator>> node_evaluators_;
 };
 
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc
index 5c1b7454ae6..7f914aac319 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc
@@ -20,14 +20,8 @@ namespace tensorflow {
 namespace tensorforest {
 
 void FertileStatsResource::AddExampleToStatsAndInitialize(
-    const std::unique_ptr<TensorDataSet>& input_data,
-    const InputTarget* target, const std::vector<int>& examples,
-    int32 node_id, int32 node_depth, bool* is_finished) {
-  // Set leaf's counts for calculating probabilities.
-  for (int example : examples) {
-    model_op_->UpdateModel(&leaf_stats_[node_id], target, example);
-  }
-
+    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
+    const std::vector<int>& examples, int32 node_id, bool* is_finished) {
   // Update stats or initialize if needed.
   if (collection_op_->IsInitialized(node_id)) {
     collection_op_->AddExample(input_data, target, examples, node_id);
@@ -47,8 +41,6 @@ void FertileStatsResource::AddExampleToStatsAndInitialize(
 }
 
 void FertileStatsResource::AllocateNode(int32 node_id, int32 depth) {
-  leaf_stats_[node_id] = LeafStat();
-  model_op_->InitModel(&leaf_stats_[node_id]);
   collection_op_->InitializeSlot(node_id, depth);
 }
 
@@ -62,7 +54,6 @@ void FertileStatsResource::Allocate(int32 parent_depth,
 
 void FertileStatsResource::Clear(int32 node) {
   collection_op_->ClearSlot(node);
-  leaf_stats_.erase(node);
 }
 
 bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best,
@@ -71,27 +62,16 @@ bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best,
 }
 
 void FertileStatsResource::MaybeInitialize() {
-  if (leaf_stats_.empty()) {
-    AllocateNode(0, 0);
-  }
+  collection_op_->MaybeInitialize();
 }
 
 void FertileStatsResource::ExtractFromProto(const FertileStats& stats) {
   collection_op_ =
       SplitCollectionOperatorFactory::CreateSplitCollectionOperator(params_);
   collection_op_->ExtractFromProto(stats);
-  for (int i = 0; i < stats.node_to_slot_size(); ++i) {
-    const auto& slot = stats.node_to_slot(i);
-    leaf_stats_[slot.node_id()] = slot.leaf_stats();
-  }
 }
 
 void FertileStatsResource::PackToProto(FertileStats* stats) const {
-  for (const auto& entry : leaf_stats_) {
-    auto* slot = stats->add_node_to_slot();
-    *slot->mutable_leaf_stats() = entry.second;
-    slot->set_node_id(entry.first);
-  }
   collection_op_->PackToProto(stats);
 }
 }  // namespace tensorforest
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
index 34ec945e846..dacf033d990 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
@@ -51,7 +51,6 @@ class FertileStatsResource : public ResourceBase {
   // Resets the resource and frees the proto.
   // Caller needs to hold the mutex lock while calling this.
   void Reset() {
-    leaf_stats_.clear();
   }
 
   // Reset the stats for a node, but leave the leaf_stats intact.
@@ -71,7 +70,7 @@ class FertileStatsResource : public ResourceBase {
   void AddExampleToStatsAndInitialize(
       const std::unique_ptr<TensorDataSet>& input_data,
       const InputTarget* target, const std::vector<int>& examples,
-      int32 node_id, int32 node_depth, bool* is_finished);
+      int32 node_id, bool* is_finished);
 
   // Allocate a fertile slot for each ready node, then new children up to
   // max_fertile_nodes_.
@@ -85,19 +84,11 @@ class FertileStatsResource : public ResourceBase {
   // was found.
   bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth);
 
-  const LeafStat& leaf_stat(int32 node_id) {
-    return leaf_stats_[node_id];
-  }
-
-  void set_leaf_stat(const LeafStat& stat, int32 node_id) {
-    leaf_stats_[node_id] = stat;
-  }
 
  private:
   mutex mu_;
   std::shared_ptr<LeafModelOperator> model_op_;
   std::unique_ptr<SplitCollectionOperator> collection_op_;
-  std::unordered_map<int32, LeafStat> leaf_stats_;
   const TensorForestParams params_;
 
   void AllocateNode(int32 node_id, int32 depth);
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
index 49e425642d1..d43c068e462 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
@@ -17,6 +17,8 @@
 namespace tensorflow {
 namespace tensorforest {
 
+using decision_trees::Leaf;
+
 std::unique_ptr<LeafModelOperator>
 LeafModelOperatorFactory::CreateLeafModelOperator(
     const TensorForestParams& params) {
@@ -50,24 +52,21 @@ float DenseClassificationLeafModelOperator::GetOutputValue(
 }
 
 void DenseClassificationLeafModelOperator::UpdateModel(
-    LeafStat* leaf, const InputTarget* target,
-    int example) const {
+    Leaf* leaf, const InputTarget* target, int example) const {
   const int32 int_label = target->GetTargetAsClassIndex(example, 0);
   QCHECK_LT(int_label, params_.num_outputs())
       << "Got label greater than indicated number of classes. Is "
          "params.num_classes set correctly?";
   QCHECK_GE(int_label, 0);
-  auto* val = leaf->mutable_classification()->mutable_dense_counts()
-      ->mutable_value(int_label);
+  auto* val = leaf->mutable_vector()->mutable_value(int_label);
+
   float weight = target->GetTargetWeight(example);
   val->set_float_value(val->float_value() + weight);
-  leaf->set_weight_sum(leaf->weight_sum() + weight);
 }
 
-void DenseClassificationLeafModelOperator::InitModel(
-    LeafStat* leaf) const {
+void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const {
   for (int i = 0; i < params_.num_outputs(); ++i) {
-    leaf->mutable_classification()->mutable_dense_counts()->add_value();
+    leaf->mutable_vector()->add_value();
   }
 }
 
@@ -88,17 +87,15 @@ float SparseClassificationLeafModelOperator::GetOutputValue(
 }
 
 void SparseClassificationLeafModelOperator::UpdateModel(
-    LeafStat* leaf, const InputTarget* target,
-    int example) const {
+    Leaf* leaf, const InputTarget* target, int example) const {
   const int32 int_label = target->GetTargetAsClassIndex(example, 0);
   QCHECK_LT(int_label, params_.num_outputs())
       << "Got label greater than indicated number of classes. Is "
          "params.num_classes set correctly?";
   QCHECK_GE(int_label, 0);
   const float weight = target->GetTargetWeight(example);
-  leaf->set_weight_sum(leaf->weight_sum() + weight);
-  auto value_map = leaf->mutable_classification()->mutable_sparse_counts()
-      ->mutable_sparse_value();
+
+  auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value();
   auto it = value_map->find(int_label);
   if (it == value_map->end()) {
     (*value_map)[int_label].set_float_value(weight);
@@ -123,8 +120,8 @@ float SparseOrDenseClassificationLeafModelOperator::GetOutputValue(
 }
 
 void SparseOrDenseClassificationLeafModelOperator::UpdateModel(
-    LeafStat* leaf, const InputTarget* target, int example) const {
-  if (leaf->classification().has_dense_counts()) {
+    Leaf* leaf, const InputTarget* target, int example) const {
+  if (leaf->has_vector()) {
     return dense_->UpdateModel(leaf, target, example);
   } else {
     return sparse_->UpdateModel(leaf, target, example);
@@ -146,15 +143,15 @@ float RegressionLeafModelOperator::GetOutputValue(
   return leaf.vector().value(o).float_value();
 }
 
-void RegressionLeafModelOperator::InitModel(
-    LeafStat* leaf) const {
+void RegressionLeafModelOperator::InitModel(Leaf* leaf) const {
   for (int i = 0; i < params_.num_outputs(); ++i) {
-    leaf->mutable_regression()->mutable_mean_output()->add_value();
+    leaf->mutable_vector()->add_value();
   }
 }
 
 void RegressionLeafModelOperator::ExportModel(
     const LeafStat& stat, decision_trees::Leaf* leaf) const {
+  leaf->clear_vector();
   for (int i = 0; i < params_.num_outputs(); ++i) {
     const float new_val =
         stat.regression().mean_output().value(i).float_value() /
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h
index 8aadefc4033..946a648f22f 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h
@@ -42,12 +42,11 @@ class LeafModelOperator {
                                int32 o) const = 0;
 
   // Update the given Leaf's model with the given example.
-  virtual void UpdateModel(LeafStat* leaf,
-                           const InputTarget* target,
-                           int example) const = 0;
+  virtual void UpdateModel(decision_trees::Leaf* leaf,
+                           const InputTarget* target, int example) const = 0;
 
   // Initialize an empty Leaf model.
-  virtual void InitModel(LeafStat* leaf) const = 0;
+  virtual void InitModel(decision_trees::Leaf* leaf) const = 0;
 
   virtual void ExportModel(const LeafStat& stat,
                            decision_trees::Leaf* leaf) const = 0;
@@ -65,10 +64,10 @@ class DenseClassificationLeafModelOperator : public LeafModelOperator {
   float GetOutputValue(const decision_trees::Leaf& leaf,
                        int32 o) const override;
 
-  void UpdateModel(LeafStat* leaf, const InputTarget* target,
+  void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
                    int example) const override;
 
-  void InitModel(LeafStat* leaf) const override;
+  void InitModel(decision_trees::Leaf* leaf) const override;
 
   void ExportModel(const LeafStat& stat,
                    decision_trees::Leaf* leaf) const override;
@@ -84,10 +83,10 @@ class SparseClassificationLeafModelOperator : public LeafModelOperator {
   float GetOutputValue(const decision_trees::Leaf& leaf,
                        int32 o) const override;
 
-  void UpdateModel(LeafStat* leaf, const InputTarget* target,
+  void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
                    int example) const override;
 
-  void InitModel(LeafStat* leaf) const override {}
+  void InitModel(decision_trees::Leaf* leaf) const override {}
 
   void ExportModel(const LeafStat& stat,
                    decision_trees::Leaf* leaf) const override;
@@ -103,10 +102,10 @@ class SparseOrDenseClassificationLeafModelOperator : public LeafModelOperator {
   float GetOutputValue(const decision_trees::Leaf& leaf,
                        int32 o) const override;
 
-  void UpdateModel(LeafStat* leaf, const InputTarget* target,
+  void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
                    int example) const override;
 
-  void InitModel(LeafStat* leaf) const override {}
+  void InitModel(decision_trees::Leaf* leaf) const override {}
 
   void ExportModel(const LeafStat& stat,
                    decision_trees::Leaf* leaf) const override;
@@ -129,10 +128,10 @@ class RegressionLeafModelOperator : public LeafModelOperator {
   // updating model and just using the seeded values.  Can add this in
   // with additional_data, though protobuf::Any is slow.  Maybe make it
   // optional.  Maybe make any update optional.
-  void UpdateModel(LeafStat* leaf, const InputTarget* target,
+  void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
                    int example) const override {}
 
-  void InitModel(LeafStat* leaf) const override;
+  void InitModel(decision_trees::Leaf* leaf) const override;
 
   void ExportModel(const LeafStat& stat,
                    decision_trees::Leaf* leaf) const override;
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
index 35268d15d3e..ffd92c01f9a 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
@@ -63,12 +63,8 @@ constexpr char kRegressionStatProto[] =
 "}";
 
 void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) {
-  std::unique_ptr<LeafStat> leaf(new LeafStat);
-  op->InitModel(leaf.get());
-
   Leaf l;
-  op->ExportModel(*leaf, &l);
-
+  op->InitModel(&l);
   // Make sure it was initialized correctly.
   for (int i = 0; i < kNumClasses; ++i) {
     EXPECT_EQ(op->GetOutputValue(l, i), 0);
@@ -80,11 +76,10 @@ void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) {
       new TestableInputTarget(labels, weights, 1));
 
   // Update and check value.
-  op->UpdateModel(leaf.get(), target.get(), 0);
-  op->UpdateModel(leaf.get(), target.get(), 1);
-  op->UpdateModel(leaf.get(), target.get(), 2);
+  op->UpdateModel(&l, target.get(), 0);
+  op->UpdateModel(&l, target.get(), 1);
+  op->UpdateModel(&l, target.get(), 2);
 
-  op->ExportModel(*leaf, &l);
   EXPECT_FLOAT_EQ(op->GetOutputValue(l, 1), 3.4);
 }
 
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
index 632408fd718..ccc412600c7 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
@@ -71,13 +71,13 @@ void SplitCollectionOperator::ExtractFromProto(
 }
 
 void SplitCollectionOperator::PackToProto(FertileStats* stats_proto) const {
-  for (int i = 0; i < stats_proto->node_to_slot_size(); ++i) {
-    auto* new_slot = stats_proto->mutable_node_to_slot(i);
-    const auto& stats = stats_.at(new_slot->node_id());
+  for (const auto& pair : stats_) {
+    auto* new_slot = stats_proto->add_node_to_slot();
+    new_slot->set_node_id(pair.first);
     if (params_.checkpoint_stats()) {
-      stats->PackToProto(new_slot);
+      pair.second->PackToProto(new_slot);
     }
-    new_slot->set_depth(stats->depth());
+    new_slot->set_depth(pair.second->depth());
   }
 }
 
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
index 6990e82678b..6c21c0bd344 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
@@ -62,6 +62,14 @@ class SplitCollectionOperator {
   // Create a new GrowStats for the given node id and initialize it.
   virtual void InitializeSlot(int32 node_id, int32 depth);
 
+  // Called when the resource is deserialized, possibly needing an
+  // initialization.
+  virtual void MaybeInitialize() {
+    if (stats_.empty()) {
+      InitializeSlot(0, 0);
+    }
+  }
+
   // Perform any necessary cleanup for any tracked state for the slot.
   virtual void ClearSlot(int32 node_id) {
     stats_.erase(node_id);
diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc
index 168f079f523..1227a70a2e9 100644
--- a/tensorflow/contrib/tensor_forest/ops/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc
@@ -115,6 +115,58 @@ sparse_input_shape: The shape tensor from the SparseTensor input.
 predictions: `predictions[i][j]` is the probability that input i is class j.
 )doc");
 
+REGISTER_OP("TraverseTreeV4")
+    .Attr("input_spec: string")
+    .Attr("params: string")
+    .Input("tree_handle: resource")
+    .Input("input_data: float")
+    .Input("sparse_input_indices: int64")
+    .Input("sparse_input_values: float")
+    .Input("sparse_input_shape: int64")
+    .Output("leaf_ids: int32")
+    .SetShapeFn([](InferenceContext* c) {
+      DimensionHandle num_points = c->UnknownDim();
+
+      if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0 &&
+          c->Value(c->Dim(c->input(1), 0)) > 0) {
+        num_points = c->Dim(c->input(1), 0);
+      }
+
+      c->set_output(0, c->Vector(num_points));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Outputs the leaf ids for the given input data.
+
+params: A serialized TensorForestParams proto.
+tree_handle: The handle to the tree.
+input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
+   gives the j-th feature of the i-th input.
+sparse_input_indices: The indices tensor from the SparseTensor input.
+sparse_input_values: The values tensor from the SparseTensor input.
+sparse_input_shape: The shape tensor from the SparseTensor input.
+leaf_ids: `leaf_ids[i]` is the leaf id for input i.
+)doc");
+
+REGISTER_OP("UpdateModelV4")
+    .Attr("params: string")
+    .Input("tree_handle: resource")
+    .Input("leaf_ids: int32")
+    .Input("input_labels: float")
+    .Input("input_weights: float")
+    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+    .Doc(R"doc(
+Updates the given leaves for each example with the new labels.
+
+params: A serialized TensorForestParams proto.
+tree_handle: The handle to the tree.
+leaf_ids: `leaf_ids[i]` is the leaf id for input i.
+input_labels: The training batch's labels as a 1 or 2-d tensor.
+  'input_labels[i][j]' gives the j-th label/target for the i-th input.
+input_weights: The training batch's eample weights as a 1-d tensor.
+  'input_weights[i]' gives the weight for the i-th input.
+)doc");
+
 REGISTER_OP("FeatureUsageCounts")
     .Attr("params: string")
     .Input("tree_handle: resource")
diff --git a/tensorflow/contrib/tensor_forest/ops/stats_ops.cc b/tensorflow/contrib/tensor_forest/ops/stats_ops.cc
index 96527497689..e8b5c5d8a6e 100644
--- a/tensorflow/contrib/tensor_forest/ops/stats_ops.cc
+++ b/tensorflow/contrib/tensor_forest/ops/stats_ops.cc
@@ -98,6 +98,7 @@ REGISTER_OP("ProcessInputV4")
     .Input("sparse_input_shape: int64")
     .Input("input_labels: float")
     .Input("input_weights: float")
+    .Input("leaf_ids: int32")
     .Output("finished_nodes: int32")
     .SetShapeFn([](InferenceContext* c) {
       c->set_output(0, c->Vector(c->UnknownDim()));
@@ -122,6 +123,7 @@ input_weights: The training batch's eample weights as a 1-d tensor.
   'input_weights[i]' gives the weight for the i-th input.
 finished_nodes: A 1-d tensor of node ids that have finished and are ready to
   grow.
+leaf_ids: `leaf_ids[i]` is the leaf id for input i.
 )doc");
 
 REGISTER_OP("FinalizeTree")
diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py
index 4c7218305b5..d240e2f6dec 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py
@@ -18,12 +18,13 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.tensor_forest.python.ops import gen_model_ops
-from tensorflow.contrib.tensor_forest.python.ops import stats_ops
 
 # pylint: disable=unused-import
 from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import feature_usage_counts
+from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import traverse_tree_v4
 from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_predictions_v4
 from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_size
+from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import update_model_v4
 # pylint: enable=unused-import
 
 from tensorflow.contrib.util import loader
@@ -59,13 +60,7 @@ class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject):
       name: the name to save the tree variable under.
     """
     self.params = params
-    deps = []
-    if stats_handle is not None:
-      deps.append(stats_ops.finalize_tree(
-          tree_handle, stats_handle,
-          params=params.serialized_params_proto))
-    with ops.control_dependencies(deps):
-      tensor = gen_model_ops.tree_serialize(tree_handle)
+    tensor = gen_model_ops.tree_serialize(tree_handle)
     # slice_spec is useful for saving a slice from a variable.
     # It's not meaningful the tree variable. So we just pass an empty value.
     slice_spec = ""
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py
index 7e6f00a13df..8198c228dd6 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py
@@ -27,6 +27,7 @@ from tensorflow.contrib.tensor_forest.python import tensor_forest
 from tensorflow.contrib.tensor_forest.python.ops import model_ops
 from tensorflow.contrib.tensor_forest.python.ops import stats_ops
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.platform import tf_logging as logging
 
@@ -240,6 +241,22 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs):
     if input_data is None:
       input_data = []
 
+    leaf_ids = model_ops.traverse_tree_v4(
+        self.variables.tree,
+        input_data,
+        sparse_indices,
+        sparse_values,
+        sparse_shape,
+        input_spec=data_spec.SerializeToString(),
+        params=self.params.serialized_params_proto)
+
+    update_model = model_ops.update_model_v4(
+        self.variables.tree,
+        leaf_ids,
+        input_labels,
+        input_weights,
+        params=self.params.serialized_params_proto)
+
     finished_nodes = stats_ops.process_input_v4(
         self.variables.tree,
         self.variables.stats,
@@ -249,13 +266,17 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs):
         sparse_shape,
         input_labels,
         input_weights,
+        leaf_ids,
         input_spec=data_spec.SerializeToString(),
         random_seed=random_seed,
         params=self.params.serialized_params_proto)
 
-    return stats_ops.grow_tree_v4(self.variables.tree, self.variables.stats,
-                                  finished_nodes,
-                                  params=self.params.serialized_params_proto)
+    with ops.control_dependencies([update_model]):
+      return stats_ops.grow_tree_v4(
+          self.variables.tree,
+          self.variables.stats,
+          finished_nodes,
+          params=self.params.serialized_params_proto)
 
   def inference_graph(self, input_data, data_spec, sparse_features=None):
     sparse_indices = []