Try only one test

This commit is contained in:
Mihai Maruseac 2020-12-22 17:52:17 -08:00
parent 289a0ba853
commit 9bcec1c979

View File

@ -21,32 +21,25 @@ _p = print
_p("IMPORTING", file=sys.stderr)
from google.protobuf import text_format
_p("IMPORTING 2", file=sys.stderr)
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
_p("IMPORTING 3", file=sys.stderr)
from tensorflow.python.framework import ops
_p("IMPORTING 4", file=sys.stderr)
from tensorflow.python.framework import test_util
_p("IMPORTING 5", file=sys.stderr)
from tensorflow.python.ops import boosted_trees_ops
_p("IMPORTING 6", file=sys.stderr)
from tensorflow.python.ops import resources
_p("IMPORTING 7", file=sys.stderr)
from tensorflow.python.platform import googletest
_p("IMPORTING 8", file=sys.stderr)
class ResourceOpsTest(test_util.TensorFlowTestCase):
"""Tests resource_ops."""
@test_util.run_deprecated_v1
def testCreate(self):
with self.cached_session():
ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
resources.initialize_resources(resources.shared_resources()).run()
stamp_token = ensemble.get_stamp_token()
self.assertEqual(0, self.evaluate(stamp_token))
(_, num_trees, num_finalized_trees, num_attempted_layers,
nodes_range) = ensemble.get_states()
self.assertEqual(0, self.evaluate(num_trees))
self.assertEqual(0, self.evaluate(num_finalized_trees))
self.assertEqual(0, self.evaluate(num_attempted_layers))
self.assertAllEqual([0, 1], self.evaluate(nodes_range))
@test_util.run_deprecated_v1
def testCreateWithProto(self):
_p("testCreateWithProto 1", file=sys.stderr)
@ -177,81 +170,6 @@ class ResourceOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([16, 19], self.evaluate(nodes_range))
_p("testCreateWithProto 12", file=sys.stderr)
@test_util.run_deprecated_v1
def testSerializeDeserialize(self):
with self.cached_session():
# Initialize.
ensemble = boosted_trees_ops.TreeEnsemble('ensemble', stamp_token=5)
resources.initialize_resources(resources.shared_resources()).run()
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
nodes_range) = ensemble.get_states()
self.assertEqual(5, self.evaluate(stamp_token))
self.assertEqual(0, self.evaluate(num_trees))
self.assertEqual(0, self.evaluate(num_finalized_trees))
self.assertEqual(0, self.evaluate(num_attempted_layers))
self.assertAllEqual([0, 1], self.evaluate(nodes_range))
# Deserialize.
ensemble_proto = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
trees {
nodes {
bucketized_split {
feature_id: 75
threshold: 21
left_id: 1
right_id: 2
}
metadata {
gain: -1.4
}
}
nodes {
leaf {
scalar: -0.6
}
}
nodes {
leaf {
scalar: 0.165
}
}
}
tree_weights: 0.5
tree_metadata {
num_layers_grown: 4 # it's fake intentionally.
is_finalized: false
}
growing_metadata {
num_trees_attempted: 1
num_layers_attempted: 5
last_layer_node_start: 3
last_layer_node_end: 7
}
""", ensemble_proto)
with ops.control_dependencies([
ensemble.deserialize(
stamp_token=3,
serialized_proto=ensemble_proto.SerializeToString())
]):
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
nodes_range) = ensemble.get_states()
self.assertEqual(3, self.evaluate(stamp_token))
self.assertEqual(1, self.evaluate(num_trees))
# This reads from metadata, not really counting the layers.
self.assertEqual(5, self.evaluate(num_attempted_layers))
self.assertEqual(0, self.evaluate(num_finalized_trees))
self.assertAllEqual([3, 7], self.evaluate(nodes_range))
# Serialize.
new_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
new_stamp_token, new_serialized = ensemble.serialize()
self.assertEqual(3, self.evaluate(new_stamp_token))
new_ensemble_proto.ParseFromString(new_serialized.eval())
self.assertProtoEquals(ensemble_proto, new_ensemble_proto)
if __name__ == '__main__':
googletest.main()