Try only one test
This commit is contained in:
parent
289a0ba853
commit
9bcec1c979
@ -21,32 +21,25 @@ _p = print
|
||||
_p("IMPORTING", file=sys.stderr)
|
||||
|
||||
from google.protobuf import text_format
|
||||
_p("IMPORTING 2", file=sys.stderr)
|
||||
|
||||
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
|
||||
_p("IMPORTING 3", file=sys.stderr)
|
||||
from tensorflow.python.framework import ops
|
||||
_p("IMPORTING 4", file=sys.stderr)
|
||||
from tensorflow.python.framework import test_util
|
||||
_p("IMPORTING 5", file=sys.stderr)
|
||||
from tensorflow.python.ops import boosted_trees_ops
|
||||
_p("IMPORTING 6", file=sys.stderr)
|
||||
from tensorflow.python.ops import resources
|
||||
_p("IMPORTING 7", file=sys.stderr)
|
||||
from tensorflow.python.platform import googletest
|
||||
_p("IMPORTING 8", file=sys.stderr)
|
||||
|
||||
|
||||
class ResourceOpsTest(test_util.TensorFlowTestCase):
|
||||
"""Tests resource_ops."""
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testCreate(self):
|
||||
with self.cached_session():
|
||||
ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
|
||||
resources.initialize_resources(resources.shared_resources()).run()
|
||||
stamp_token = ensemble.get_stamp_token()
|
||||
self.assertEqual(0, self.evaluate(stamp_token))
|
||||
(_, num_trees, num_finalized_trees, num_attempted_layers,
|
||||
nodes_range) = ensemble.get_states()
|
||||
self.assertEqual(0, self.evaluate(num_trees))
|
||||
self.assertEqual(0, self.evaluate(num_finalized_trees))
|
||||
self.assertEqual(0, self.evaluate(num_attempted_layers))
|
||||
self.assertAllEqual([0, 1], self.evaluate(nodes_range))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testCreateWithProto(self):
|
||||
_p("testCreateWithProto 1", file=sys.stderr)
|
||||
@ -177,81 +170,6 @@ class ResourceOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual([16, 19], self.evaluate(nodes_range))
|
||||
_p("testCreateWithProto 12", file=sys.stderr)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSerializeDeserialize(self):
|
||||
with self.cached_session():
|
||||
# Initialize.
|
||||
ensemble = boosted_trees_ops.TreeEnsemble('ensemble', stamp_token=5)
|
||||
resources.initialize_resources(resources.shared_resources()).run()
|
||||
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
|
||||
nodes_range) = ensemble.get_states()
|
||||
self.assertEqual(5, self.evaluate(stamp_token))
|
||||
self.assertEqual(0, self.evaluate(num_trees))
|
||||
self.assertEqual(0, self.evaluate(num_finalized_trees))
|
||||
self.assertEqual(0, self.evaluate(num_attempted_layers))
|
||||
self.assertAllEqual([0, 1], self.evaluate(nodes_range))
|
||||
|
||||
# Deserialize.
|
||||
ensemble_proto = boosted_trees_pb2.TreeEnsemble()
|
||||
text_format.Merge(
|
||||
"""
|
||||
trees {
|
||||
nodes {
|
||||
bucketized_split {
|
||||
feature_id: 75
|
||||
threshold: 21
|
||||
left_id: 1
|
||||
right_id: 2
|
||||
}
|
||||
metadata {
|
||||
gain: -1.4
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
leaf {
|
||||
scalar: -0.6
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
leaf {
|
||||
scalar: 0.165
|
||||
}
|
||||
}
|
||||
}
|
||||
tree_weights: 0.5
|
||||
tree_metadata {
|
||||
num_layers_grown: 4 # it's fake intentionally.
|
||||
is_finalized: false
|
||||
}
|
||||
growing_metadata {
|
||||
num_trees_attempted: 1
|
||||
num_layers_attempted: 5
|
||||
last_layer_node_start: 3
|
||||
last_layer_node_end: 7
|
||||
}
|
||||
""", ensemble_proto)
|
||||
with ops.control_dependencies([
|
||||
ensemble.deserialize(
|
||||
stamp_token=3,
|
||||
serialized_proto=ensemble_proto.SerializeToString())
|
||||
]):
|
||||
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
|
||||
nodes_range) = ensemble.get_states()
|
||||
self.assertEqual(3, self.evaluate(stamp_token))
|
||||
self.assertEqual(1, self.evaluate(num_trees))
|
||||
# This reads from metadata, not really counting the layers.
|
||||
self.assertEqual(5, self.evaluate(num_attempted_layers))
|
||||
self.assertEqual(0, self.evaluate(num_finalized_trees))
|
||||
self.assertAllEqual([3, 7], self.evaluate(nodes_range))
|
||||
|
||||
|
||||
# Serialize.
|
||||
new_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
|
||||
new_stamp_token, new_serialized = ensemble.serialize()
|
||||
self.assertEqual(3, self.evaluate(new_stamp_token))
|
||||
new_ensemble_proto.ParseFromString(new_serialized.eval())
|
||||
self.assertProtoEquals(ensemble_proto, new_ensemble_proto)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user