diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index 4e5f1db7e02..82d3601a6a8 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -253,15 +253,24 @@ void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
 }
 
 // Add a tree to the ensemble and returns a new tree_id.
-int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) {
-  return AddNewTreeWithLogits(weight, 0.0);
+int32 BoostedTreesEnsembleResource::AddNewTree(const float weight,
+                                               const int32 logits_dimension) {
+  const std::vector<float> empty_leaf(logits_dimension);
+  return AddNewTreeWithLogits(weight, empty_leaf, logits_dimension);
 }
 
-int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight,
-                                                         const float logits) {
+int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(
+    const float weight, const std::vector<float>& logits,
+    const int32 logits_dimension) {
   const int32 new_tree_id = tree_ensemble_->trees_size();
   auto* node = tree_ensemble_->add_trees()->add_nodes();
-  node->mutable_leaf()->set_scalar(logits);
+  if (logits_dimension == 1) {
+    node->mutable_leaf()->set_scalar(logits[0]);
+  } else {
+    for (int32 i = 0; i < logits_dimension; ++i) {
+      node->mutable_leaf()->mutable_vector()->add_value(logits[i]);
+    }
+  }
   tree_ensemble_->add_tree_weights(weight);
   tree_ensemble_->add_tree_metadata();
 
diff --git a/tensorflow/core/kernels/boosted_trees/resources.h b/tensorflow/core/kernels/boosted_trees/resources.h
index 572b14757cf..70155e89071 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.h
+++ b/tensorflow/core/kernels/boosted_trees/resources.h
@@ -102,10 +102,12 @@ class BoostedTreesEnsembleResource : public StampedResource {
   int32 right_id(const int32 tree_id, const int32 node_id) const;
 
   // Add a tree to the ensemble and returns a new tree_id.
-  int32 AddNewTree(const float weight);
+  int32 AddNewTree(const float weight, const int32 logits_dimension);
 
   // Adds new tree with one node to the ensemble and sets node's value to logits
-  int32 AddNewTreeWithLogits(const float weight, const float logits);
+  int32 AddNewTreeWithLogits(const float weight,
+                             const std::vector<float>& logits,
+                             const int32 logits_dimension);
 
   // Grows the tree by adding a bucketized split and leaves.
   void AddBucketizedSplitNode(
diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc
index ca4f2e011be..7816c2c07eb 100644
--- a/tensorflow/core/kernels/boosted_trees/training_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/training_ops.cc
@@ -136,7 +136,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
         }
         if (ensemble_resource->num_trees() > 0) {
           // Create a dummy new tree with an empty node.
-          ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
+          ensemble_resource->AddNewTree(kLayerByLayerTreeWeight, 1);
         }
       }
       // If we managed to split, update the node range. If we didn't, don't
@@ -159,7 +159,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
     // boosting.
     if (num_trees <= 0) {
       // Create a new tree with a no-op leaf.
-      current_tree = resource->AddNewTree(kLayerByLayerTreeWeight);
+      current_tree = resource->AddNewTree(kLayerByLayerTreeWeight, 1);
     }
     return current_tree;
   }
@@ -357,7 +357,7 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
         }
         if (ensemble_resource->num_trees() > 0) {
           // Create a dummy new tree with an empty node.
-          ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
+          ensemble_resource->AddNewTree(kLayerByLayerTreeWeight, logits_dim_);
         }
       }
       // If we managed to split, update the node range. If we didn't, don't
@@ -380,7 +380,7 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
     // boosting.
     if (num_trees <= 0) {
       // Create a new tree with a no-op leaf.
-      current_tree = resource->AddNewTree(kLayerByLayerTreeWeight);
+      current_tree = resource->AddNewTree(kLayerByLayerTreeWeight, logits_dim_);
     }
     return current_tree;
   }
@@ -504,7 +504,8 @@ class BoostedTreesCenterBiasOp : public OpKernel {
     float current_bias = 0.0;
     bool continue_centering = true;
     if (ensemble_resource->num_trees() == 0) {
-      ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, logits);
+      ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, {logits},
+                                              1);
       current_bias = logits;
     } else {
       const auto& current_biases = ensemble_resource->node_value(0, 0);
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index e7961fc4c07..ed554ea9288 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -2459,9 +2459,20 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
             }
           }
         }
+        trees {
+          nodes {
+            leaf {
+              vector {
+                value: 0
+                value: 0
+              }
+            }
+          }
+        }
         tree_weights: 0.1
         tree_weights: 0.2
         tree_weights: 1.0
+        tree_weights: 1.0
       """, tree_ensemble_config)
 
       # Create existing ensemble with one root split
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
index 3713fd289da..5e82fe44316 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
@@ -438,6 +438,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
             }
             metadata {
               gain: 7.65
+              original_leaf {
+                vector {
+                  value: 0.0
+                  value: 0.0
+                }
+              }
             }
           }
           nodes {
@@ -464,7 +470,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
         trees {
           nodes {
             leaf {
-              scalar: 0.0
+              vector {
+                value: 0.0
+                value: 0.0
+              }
             }
           }
         }
@@ -1103,7 +1112,6 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
         trees {
           nodes {
             leaf {
-              scalar: 0.0
             }
           }
         }
@@ -1352,7 +1360,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
         trees {
           nodes {
             leaf {
-              scalar: 0.0
+              vector {
+                value: 0.0
+                value: 0.0
+              }
             }
           }
         }
@@ -3065,6 +3076,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
             }
             metadata {
               gain: -0.2
+              original_leaf {
+                vector {
+                  value: 0.0
+                  value: 0.0
+                }
+              }
             }
           }
           nodes {
@@ -3152,6 +3169,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
             }
             metadata {
               gain: -0.2
+              original_leaf {
+                vector {
+                  value: 0.0
+                  value: 0.0
+                }
+              }
             }
           }
           nodes {
@@ -3301,6 +3324,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
             }
             metadata {
               gain: -0.2
+              original_leaf {
+                vector {
+                  value: 0.0
+                  value: 0.0
+                }
+              }
             }
           }
           nodes {
@@ -3357,6 +3386,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
         trees {
           nodes {
             leaf {
+              vector {
+                value: 0
+                value: 0
+              }
             }
           }
         }
@@ -3682,6 +3715,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
             }
             metadata {
               gain: -0.62
+              original_leaf {
+                vector {
+                  value: 0.0
+                  value: 0.0
+                }
+              }
             }
           }
           nodes {
@@ -3766,12 +3805,20 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
       trees {
         nodes {
           leaf {
+            vector {
+              value: 0
+              value: 0
+            }
           }
         }
       }
       trees {
         nodes {
           leaf {
+            vector {
+              value: 0
+              value: 0
+            }
           }
         }
       }
@@ -4000,6 +4047,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
             }
             metadata {
               gain: 7.62
+              original_leaf {
+                vector {
+                  value: 0.0
+                  value: 0.0
+                }
+              }
             }
           }
           nodes {
@@ -4026,6 +4079,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
         trees {
           nodes {
             leaf {
+              vector {
+                value: 0
+                value: 0
+              }
             }
           }
         }