Fix failing ASan test due to out of boundary memory access
PiperOrigin-RevId: 292701009 Change-Id: I2f964c9958f0b2863071707a463f03b3f4e16c11
This commit is contained in:
parent
d15e205f9e
commit
ea437ce27c
@ -645,24 +645,29 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2")
|
||||
shape_inference::ShapeHandle shape_handle;
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int offset = i + 1;
|
||||
|
||||
// Feature ids
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(offset), 1, &shape_handle));
|
||||
// TODO(nponomareva): replace this with input("name",vector of shapes).
|
||||
auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(c->input(offset), shape_rank_1, &shape_handle));
|
||||
|
||||
// Dimension ids.
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(offset + num_features), 1, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(c->input(offset), shape_rank_1, &shape_handle));
|
||||
|
||||
// Node ids.
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(offset + num_features * 2), 1, &shape_handle));
|
||||
auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
|
||||
auto shape_rank_2 =
|
||||
c->MakeShape({c->Dim(shape_handle, 0), logits_dimension});
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(c->input(offset), shape_rank_1, &shape_handle));
|
||||
|
||||
// Gains.
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(offset + num_features * 3), 1, &shape_handle));
|
||||
// TODO(nponomareva): replace this with input("name",vector of shapes).
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 3),
|
||||
shape_rank_1, &shape_handle));
|
||||
|
||||
@ -673,6 +678,8 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2")
|
||||
shape_rank_1, &shape_handle));
|
||||
|
||||
// Left and right node contribs.
|
||||
auto shape_rank_2 =
|
||||
c->MakeShape({c->Dim(shape_handle, 0), logits_dimension});
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(offset + num_features * 5), 2, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 5),
|
||||
|
@ -69,7 +69,6 @@ tf_py_test(
|
||||
name = "training_ops_test",
|
||||
size = "small",
|
||||
srcs = ["training_ops_test.py"],
|
||||
tags = ["noasan"], # b/148159528
|
||||
deps = [
|
||||
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
|
@ -38,7 +38,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testGrowWithEmptyEnsemble(self):
|
||||
"""Test growing an empty ensemble."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
|
||||
tree_ensemble_handle = tree_ensemble.resource_handle
|
||||
resources.initialize_resources(resources.shared_resources()).run()
|
||||
@ -148,7 +148,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testGrowWithEmptyEnsembleV2(self):
|
||||
"""Test growing an empty ensemble."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
|
||||
tree_ensemble_handle = tree_ensemble.resource_handle
|
||||
resources.initialize_resources(resources.shared_resources()).run()
|
||||
@ -257,7 +257,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testGrowWithEmptyEnsembleV2EqualitySplit(self):
|
||||
"""Test growing an empty ensemble."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
|
||||
tree_ensemble_handle = tree_ensemble.resource_handle
|
||||
resources.initialize_resources(resources.shared_resources()).run()
|
||||
@ -366,7 +366,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testGrowWithEmptyEnsembleV2MultiClass(self):
|
||||
"""Test growing an empty ensemble for multi-class case."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
|
||||
tree_ensemble_handle = tree_ensemble.resource_handle
|
||||
resources.initialize_resources(resources.shared_resources()).run()
|
||||
@ -500,7 +500,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testBiasCenteringOnEmptyEnsemble(self):
|
||||
"""Test growing with bias centering on an empty ensemble."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
|
||||
tree_ensemble_handle = tree_ensemble.resource_handle
|
||||
resources.initialize_resources(resources.shared_resources()).run()
|
||||
@ -2665,7 +2665,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testPostPruningOfSomeNodes(self):
|
||||
"""Test growing an ensemble with post-pruning."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble(
|
||||
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
|
||||
@ -3000,7 +3000,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testPostPruningOfSomeNodesMultiClassV2(self):
|
||||
"""Test growing an ensemble with post-pruning."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble(
|
||||
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
|
||||
@ -3454,8 +3454,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testPostPruningOfAllNodes(self):
|
||||
"""Test growing an ensemble with post-pruning, with all nodes are pruned."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble(
|
||||
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
|
||||
@ -3638,7 +3637,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testPostPruningOfAllNodesMultiClassV2(self):
|
||||
"""Test growing an ensemble with post-pruning, with all nodes are pruned."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble(
|
||||
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
|
||||
@ -3753,7 +3752,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Prepare inputs.
|
||||
# All have negative gain.
|
||||
group1_feature_ids = [3]
|
||||
group1_feature_ids = [3, 0]
|
||||
group1_nodes = np.array([1, 2], dtype=np.int32)
|
||||
group1_gains = np.array([-0.2, -0.5], dtype=np.float32)
|
||||
group1_dimensions = np.array([0, 4], dtype=np.int32)
|
||||
@ -3869,7 +3868,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testPostPruningChangesNothing(self):
|
||||
"""Test growing an ensemble with post-pruning with all gains >0."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble(
|
||||
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
|
||||
@ -3970,7 +3969,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
def testPostPruningChangesNothingMultiClassV2(self):
|
||||
"""Test growing an ensemble with post-pruning with all gains >0."""
|
||||
with self.cached_session() as session:
|
||||
# Create empty ensemble.
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble(
|
||||
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
|
||||
@ -4095,6 +4094,48 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(new_stamp, 1)
|
||||
self.assertProtoEquals(expected_result, res_ensemble)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMismatchedInputLength(self):
|
||||
"""Tests raises invalid argument error when input list lengths mismatch."""
|
||||
with self.cached_session() as session:
|
||||
# Create an empty ensemble.
|
||||
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||
tree_ensemble = boosted_trees_ops.TreeEnsemble(
|
||||
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
|
||||
tree_ensemble_handle = tree_ensemble.resource_handle
|
||||
resources.initialize_resources(resources.shared_resources()).run()
|
||||
|
||||
# Prepare inputs.
|
||||
length_one_feature_ids = [3] # Should be length 2 to match others.
|
||||
nodes = np.array([1, 2], dtype=np.int32)
|
||||
gains = np.array([-0.2, -0.5], dtype=np.float32)
|
||||
dimensions = np.array([0, 4], dtype=np.int32)
|
||||
thresholds = np.array([77, 79], dtype=np.int32)
|
||||
left_node_contribs = np.array([[0.023, -0.99], [0.3, 5.979]],
|
||||
dtype=np.float32)
|
||||
right_node_contribs = np.array([[0.012343, 0.63], [24, 0.289]],
|
||||
dtype=np.float32)
|
||||
split_types = np.array(
|
||||
[_INEQUALITY_DEFAULT_LEFT, _INEQUALITY_DEFAULT_LEFT])
|
||||
with self.assertRaisesRegexp(Exception,
|
||||
r'Dimension 0 in both shapes must be equal'):
|
||||
grow_op = boosted_trees_ops.update_ensemble_v2(
|
||||
tree_ensemble_handle,
|
||||
learning_rate=1.0,
|
||||
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
||||
max_depth=2,
|
||||
feature_ids=[length_one_feature_ids],
|
||||
dimension_ids=[dimensions],
|
||||
node_ids=[nodes],
|
||||
gains=[gains],
|
||||
thresholds=[thresholds],
|
||||
left_node_contribs=[left_node_contribs],
|
||||
right_node_contribs=[right_node_contribs],
|
||||
split_types=[split_types],
|
||||
logits_dimension=2)
|
||||
|
||||
session.run(grow_op)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user