Fix failing ASan test due to out of boundary memory access

PiperOrigin-RevId: 292701009
Change-Id: I2f964c9958f0b2863071707a463f03b3f4e16c11
This commit is contained in:
A. Unique TensorFlower 2020-02-01 05:44:44 -08:00 committed by TensorFlower Gardener
parent d15e205f9e
commit ea437ce27c
3 changed files with 65 additions and 18 deletions

View File

@ -645,24 +645,29 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2")
shape_inference::ShapeHandle shape_handle;
for (int i = 0; i < num_groups; ++i) {
int offset = i + 1;
// Feature ids
TF_RETURN_IF_ERROR(c->WithRank(c->input(offset), 1, &shape_handle));
// TODO(nponomareva): replace this with input("name",vector of shapes).
auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
TF_RETURN_IF_ERROR(
c->Merge(c->input(offset), shape_rank_1, &shape_handle));
// Dimension ids.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(offset + num_features), 1, &shape_handle));
TF_RETURN_IF_ERROR(
c->Merge(c->input(offset), shape_rank_1, &shape_handle));
// Node ids.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(offset + num_features * 2), 1, &shape_handle));
auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
auto shape_rank_2 =
c->MakeShape({c->Dim(shape_handle, 0), logits_dimension});
TF_RETURN_IF_ERROR(
c->Merge(c->input(offset), shape_rank_1, &shape_handle));
// Gains.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(offset + num_features * 3), 1, &shape_handle));
// TODO(nponomareva): replace this with input("name",vector of shapes).
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 3),
shape_rank_1, &shape_handle));
@ -673,6 +678,8 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2")
shape_rank_1, &shape_handle));
// Left and right node contribs.
auto shape_rank_2 =
c->MakeShape({c->Dim(shape_handle, 0), logits_dimension});
TF_RETURN_IF_ERROR(
c->WithRank(c->input(offset + num_features * 5), 2, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 5),

View File

@ -69,7 +69,6 @@ tf_py_test(
name = "training_ops_test",
size = "small",
srcs = ["training_ops_test.py"],
tags = ["noasan"], # b/148159528
deps = [
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
"//tensorflow/python:array_ops",

View File

@ -38,7 +38,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowWithEmptyEnsemble(self):
"""Test growing an empty ensemble."""
with self.cached_session() as session:
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
resources.initialize_resources(resources.shared_resources()).run()
@ -148,7 +148,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowWithEmptyEnsembleV2(self):
"""Test growing an empty ensemble."""
with self.cached_session() as session:
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
resources.initialize_resources(resources.shared_resources()).run()
@ -257,7 +257,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowWithEmptyEnsembleV2EqualitySplit(self):
"""Test growing an empty ensemble."""
with self.cached_session() as session:
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
resources.initialize_resources(resources.shared_resources()).run()
@ -366,7 +366,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowWithEmptyEnsembleV2MultiClass(self):
"""Test growing an empty ensemble for multi-class case."""
with self.cached_session() as session:
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
resources.initialize_resources(resources.shared_resources()).run()
@ -500,7 +500,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testBiasCenteringOnEmptyEnsemble(self):
"""Test growing with bias centering on an empty ensemble."""
with self.cached_session() as session:
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
resources.initialize_resources(resources.shared_resources()).run()
@ -2665,7 +2665,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningOfSomeNodes(self):
"""Test growing an ensemble with post-pruning."""
with self.cached_session() as session:
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
@ -3000,7 +3000,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningOfSomeNodesMultiClassV2(self):
"""Test growing an ensemble with post-pruning."""
with self.cached_session() as session:
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
@ -3454,8 +3454,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningOfAllNodes(self):
"""Test growing an ensemble with post-pruning, with all nodes are pruned."""
with self.cached_session() as session:
# Create empty ensemble.
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
@ -3638,7 +3637,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningOfAllNodesMultiClassV2(self):
"""Test growing an ensemble with post-pruning, with all nodes are pruned."""
with self.cached_session() as session:
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
@ -3753,7 +3752,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
# Prepare inputs.
# All have negative gain.
group1_feature_ids = [3]
group1_feature_ids = [3, 0]
group1_nodes = np.array([1, 2], dtype=np.int32)
group1_gains = np.array([-0.2, -0.5], dtype=np.float32)
group1_dimensions = np.array([0, 4], dtype=np.int32)
@ -3869,7 +3868,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningChangesNothing(self):
"""Test growing an ensemble with post-pruning with all gains >0."""
with self.cached_session() as session:
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
@ -3970,7 +3969,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningChangesNothingMultiClassV2(self):
"""Test growing an ensemble with post-pruning with all gains >0."""
with self.cached_session() as session:
# Create empty ensemble.
# Create an empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
@ -4095,6 +4094,48 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(new_stamp, 1)
self.assertProtoEquals(expected_result, res_ensemble)
@test_util.run_deprecated_v1
def testMismatchedInputLength(self):
"""Tests raises invalid argument error when input list lengths mismatch."""
with self.cached_session() as session:
# Create an empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
tree_ensemble_handle = tree_ensemble.resource_handle
resources.initialize_resources(resources.shared_resources()).run()
# Prepare inputs.
length_one_feature_ids = [3] # Should be length 2 to match others.
nodes = np.array([1, 2], dtype=np.int32)
gains = np.array([-0.2, -0.5], dtype=np.float32)
dimensions = np.array([0, 4], dtype=np.int32)
thresholds = np.array([77, 79], dtype=np.int32)
left_node_contribs = np.array([[0.023, -0.99], [0.3, 5.979]],
dtype=np.float32)
right_node_contribs = np.array([[0.012343, 0.63], [24, 0.289]],
dtype=np.float32)
split_types = np.array(
[_INEQUALITY_DEFAULT_LEFT, _INEQUALITY_DEFAULT_LEFT])
with self.assertRaisesRegexp(Exception,
r'Dimension 0 in both shapes must be equal'):
grow_op = boosted_trees_ops.update_ensemble_v2(
tree_ensemble_handle,
learning_rate=1.0,
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
max_depth=2,
feature_ids=[length_one_feature_ids],
dimension_ids=[dimensions],
node_ids=[nodes],
gains=[gains],
thresholds=[thresholds],
left_node_contribs=[left_node_contribs],
right_node_contribs=[right_node_contribs],
split_types=[split_types],
logits_dimension=2)
session.run(grow_op)
if __name__ == '__main__':
googletest.main()