More debug print

This commit is contained in:
Mihai Maruseac 2020-12-22 16:31:57 -08:00
parent 8be2cc5149
commit 289a0ba853

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys import sys
_p = print _p = print
_p("IMPORTING", file=sys.stderr)
from google.protobuf import text_format from google.protobuf import text_format
@ -48,11 +49,11 @@ class ResourceOpsTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testCreateWithProto(self): def testCreateWithProto(self):
_p("testCreateWithProto 1") _p("testCreateWithProto 1", file=sys.stderr)
with self.cached_session(): with self.cached_session():
_p("testCreateWithProto 2") _p("testCreateWithProto 2", file=sys.stderr)
ensemble_proto = boosted_trees_pb2.TreeEnsemble() ensemble_proto = boosted_trees_pb2.TreeEnsemble()
_p("testCreateWithProto 3") _p("testCreateWithProto 3", file=sys.stderr)
text_format.Merge( text_format.Merge(
""" """
trees { trees {
@ -154,27 +155,27 @@ class ResourceOpsTest(test_util.TensorFlowTestCase):
last_layer_node_end: 19 last_layer_node_end: 19
} }
""", ensemble_proto) """, ensemble_proto)
_p("testCreateWithProto 4") _p("testCreateWithProto 4", file=sys.stderr)
ensemble = boosted_trees_ops.TreeEnsemble( ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', 'ensemble',
stamp_token=7, stamp_token=7,
serialized_proto=ensemble_proto.SerializeToString()) serialized_proto=ensemble_proto.SerializeToString())
_p("testCreateWithProto 5") _p("testCreateWithProto 5", file=sys.stderr)
resources.initialize_resources(resources.shared_resources()).run() resources.initialize_resources(resources.shared_resources()).run()
_p("testCreateWithProto 6") _p("testCreateWithProto 6", file=sys.stderr)
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers, (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
nodes_range) = ensemble.get_states() nodes_range) = ensemble.get_states()
_p("testCreateWithProto 7") _p("testCreateWithProto 7", file=sys.stderr)
self.assertEqual(7, self.evaluate(stamp_token)) self.assertEqual(7, self.evaluate(stamp_token))
_p("testCreateWithProto 8") _p("testCreateWithProto 8", file=sys.stderr)
self.assertEqual(2, self.evaluate(num_trees)) self.assertEqual(2, self.evaluate(num_trees))
_p("testCreateWithProto 9") _p("testCreateWithProto 9", file=sys.stderr)
self.assertEqual(1, self.evaluate(num_finalized_trees)) self.assertEqual(1, self.evaluate(num_finalized_trees))
_p("testCreateWithProto 10") _p("testCreateWithProto 10", file=sys.stderr)
self.assertEqual(6, self.evaluate(num_attempted_layers)) self.assertEqual(6, self.evaluate(num_attempted_layers))
_p("testCreateWithProto 11") _p("testCreateWithProto 11", file=sys.stderr)
self.assertAllEqual([16, 19], self.evaluate(nodes_range)) self.assertAllEqual([16, 19], self.evaluate(nodes_range))
_p("testCreateWithProto 12") _p("testCreateWithProto 12", file=sys.stderr)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testSerializeDeserialize(self): def testSerializeDeserialize(self):