diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc index a55e2dbc723..ded6c5a151f 100644 --- a/tensorflow/core/ops/boosted_trees_ops.cc +++ b/tensorflow/core/ops/boosted_trees_ops.cc @@ -645,24 +645,29 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2") shape_inference::ShapeHandle shape_handle; for (int i = 0; i < num_groups; ++i) { int offset = i + 1; + // Feature ids TF_RETURN_IF_ERROR(c->WithRank(c->input(offset), 1, &shape_handle)); + // TODO(nponomareva): replace this with input("name",vector of shapes). + auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)}); + TF_RETURN_IF_ERROR( + c->Merge(c->input(offset), shape_rank_1, &shape_handle)); // Dimension ids. TF_RETURN_IF_ERROR( c->WithRank(c->input(offset + num_features), 1, &shape_handle)); + TF_RETURN_IF_ERROR( + c->Merge(c->input(offset), shape_rank_1, &shape_handle)); // Node ids. TF_RETURN_IF_ERROR( c->WithRank(c->input(offset + num_features * 2), 1, &shape_handle)); - auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)}); - auto shape_rank_2 = - c->MakeShape({c->Dim(shape_handle, 0), logits_dimension}); + TF_RETURN_IF_ERROR( + c->Merge(c->input(offset), shape_rank_1, &shape_handle)); // Gains. TF_RETURN_IF_ERROR( c->WithRank(c->input(offset + num_features * 3), 1, &shape_handle)); - // TODO(nponomareva): replace this with input("name",vector of shapes). TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 3), shape_rank_1, &shape_handle)); @@ -673,6 +678,8 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2") shape_rank_1, &shape_handle)); // Left and right node contribs. + auto shape_rank_2 = + c->MakeShape({c->Dim(shape_handle, 0), logits_dimension}); TF_RETURN_IF_ERROR( c->WithRank(c->input(offset + num_features * 5), 2, &shape_handle)); TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 5), diff --git a/tensorflow/python/kernel_tests/boosted_trees/BUILD b/tensorflow/python/kernel_tests/boosted_trees/BUILD index a802d7fc05a..5b318324d4c 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/BUILD +++ b/tensorflow/python/kernel_tests/boosted_trees/BUILD @@ -69,7 +69,6 @@ tf_py_test( name = "training_ops_test", size = "small", srcs = ["training_ops_test.py"], - tags = ["noasan"], # b/148159528 deps = [ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py", "//tensorflow/python:array_ops", diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py index 88282001abc..fbac51ea1fb 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py @@ -38,7 +38,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowWithEmptyEnsemble(self): """Test growing an empty ensemble.""" with self.cached_session() as session: - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble') tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() @@ -148,7 +148,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowWithEmptyEnsembleV2(self): """Test growing an empty ensemble.""" with self.cached_session() as session: - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble') tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() @@ -257,7 +257,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowWithEmptyEnsembleV2EqualitySplit(self): """Test growing an empty ensemble.""" with self.cached_session() as session: - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble') tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() @@ -366,7 +366,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowWithEmptyEnsembleV2MultiClass(self): """Test growing an empty ensemble for multi-class case.""" with self.cached_session() as session: - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble') tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() @@ -500,7 +500,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testBiasCenteringOnEmptyEnsemble(self): """Test growing with bias centering on an empty ensemble.""" with self.cached_session() as session: - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble') tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() @@ -2665,7 +2665,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testPostPruningOfSomeNodes(self): """Test growing an ensemble with post-pruning.""" with self.cached_session() as session: - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) @@ -3000,7 +3000,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testPostPruningOfSomeNodesMultiClassV2(self): """Test growing an ensemble with post-pruning.""" with self.cached_session() as session: - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) @@ -3454,8 +3454,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testPostPruningOfAllNodes(self): """Test growing an ensemble with post-pruning, with all nodes are pruned.""" with self.cached_session() as session: - # Create empty ensemble. - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) @@ -3638,7 +3637,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testPostPruningOfAllNodesMultiClassV2(self): """Test growing an ensemble with post-pruning, with all nodes are pruned.""" with self.cached_session() as session: - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) @@ -3753,7 +3752,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): # Prepare inputs. # All have negative gain. - group1_feature_ids = [3] + group1_feature_ids = [3, 0] group1_nodes = np.array([1, 2], dtype=np.int32) group1_gains = np.array([-0.2, -0.5], dtype=np.float32) group1_dimensions = np.array([0, 4], dtype=np.int32) @@ -3869,7 +3868,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testPostPruningChangesNothing(self): """Test growing an ensemble with post-pruning with all gains >0.""" with self.cached_session() as session: - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) @@ -3970,7 +3969,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testPostPruningChangesNothingMultiClassV2(self): """Test growing an ensemble with post-pruning with all gains >0.""" with self.cached_session() as session: - # Create empty ensemble. + # Create an empty ensemble. tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) @@ -4095,6 +4094,48 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual(new_stamp, 1) self.assertProtoEquals(expected_result, res_ensemble) + @test_util.run_deprecated_v1 + def testMismatchedInputLength(self): + """Tests raises invalid argument error when input list lengths mismatch.""" + with self.cached_session() as session: + # Create an empty ensemble. + tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() + tree_ensemble = boosted_trees_ops.TreeEnsemble( + 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare inputs. + length_one_feature_ids = [3] # Should be length 2 to match others. + nodes = np.array([1, 2], dtype=np.int32) + gains = np.array([-0.2, -0.5], dtype=np.float32) + dimensions = np.array([0, 4], dtype=np.int32) + thresholds = np.array([77, 79], dtype=np.int32) + left_node_contribs = np.array([[0.023, -0.99], [0.3, 5.979]], + dtype=np.float32) + right_node_contribs = np.array([[0.012343, 0.63], [24, 0.289]], + dtype=np.float32) + split_types = np.array( + [_INEQUALITY_DEFAULT_LEFT, _INEQUALITY_DEFAULT_LEFT]) + with self.assertRaisesRegexp(Exception, + r'Dimension 0 in both shapes must be equal'): + grow_op = boosted_trees_ops.update_ensemble_v2( + tree_ensemble_handle, + learning_rate=1.0, + pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING, + max_depth=2, + feature_ids=[length_one_feature_ids], + dimension_ids=[dimensions], + node_ids=[nodes], + gains=[gains], + thresholds=[thresholds], + left_node_contribs=[left_node_contribs], + right_node_contribs=[right_node_contribs], + split_types=[split_types], + logits_dimension=2) + + session.run(grow_op) + if __name__ == '__main__': googletest.main()