diff --git a/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py index 4f69424386d..2911fe733ba 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py @@ -39,10 +39,80 @@ class ResourceOpsTest(test_util.TensorFlowTestCase): ensemble = boosted_trees_ops.TreeEnsemble('ensemble') resources.initialize_resources(resources.shared_resources()).run() _p("testCreate 2", file=sys.stderr) + + @test_util.run_deprecated_v1 + def testCreateWithProto(self): + _p("testCreateWithProto 1", file=sys.stderr) with self.cached_session(): - ensemble = boosted_trees_ops.TreeEnsemble('ensemble') + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + text_format.Merge( + """ + trees { + nodes { + bucketized_split { + feature_id: 4 + left_id: 1 + right_id: 2 + } + metadata { + gain: 7.62 + } + } + nodes { + bucketized_split { + threshold: 21 + left_id: 3 + right_id: 4 + } + metadata { + gain: 1.4 + original_leaf { + scalar: 7.14 + } + } + } + nodes { + bucketized_split { + feature_id: 1 + threshold: 7 + left_id: 5 + right_id: 6 + } + metadata { + gain: 2.7 + original_leaf { + scalar: -4.375 + } + } + } + nodes { + leaf { + scalar: 6.54 + } + } + nodes { + leaf { + scalar: 7.305 + } + } + nodes { + leaf { + scalar: -4.525 + } + } + nodes { + leaf { + scalar: -4.145 + } + } + } + """, ensemble_proto) + ensemble = boosted_trees_ops.TreeEnsemble( + 'ensemble', + stamp_token=7, + serialized_proto=ensemble_proto.SerializeToString()) resources.initialize_resources(resources.shared_resources()).run() - _p("testCreate 3", file=sys.stderr) + _p("testCreateWithProto 2", file=sys.stderr) if __name__ == '__main__':