Merge pull request #42397 from yongtang:42329-tf.nest.assert_same_structure-crash

PiperOrigin-RevId: 327541695
Change-Id: I29ed66c2a161b10e4aa44772bc2b167b74a73cd5
This commit is contained in:
TensorFlower Gardener 2020-08-19 17:46:17 -07:00
commit 313edafd6f
2 changed files with 16 additions and 0 deletions

View File

@ -392,6 +392,10 @@ def assert_same_structure(nest1, nest2, check_types=True,
TypeError: If the two structures differ in the type of sequence in any of
their substructures. Only possible if `check_types` is `True`.
"""
# Convert to bool explicitly as otherwise pybind will not be able# to handle
# type mismatch message correctly. See GitHub issue 42329 for details.
check_types = bool(check_types)
expand_composites = bool(expand_composites)
try:
_pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
expand_composites)

View File

@ -1218,6 +1218,18 @@ class NestTest(parameterized.TestCase, test.TestCase):
expected,
)
def testInvalidCheckTypes(self):
with self.assertRaises((ValueError, TypeError)):
nest.assert_same_structure(
nest1=array_ops.zeros((1)),
nest2=array_ops.ones((1, 1, 1)),
check_types=array_ops.ones((2)))
with self.assertRaises((ValueError, TypeError)):
nest.assert_same_structure(
nest1=array_ops.zeros((1)),
nest2=array_ops.ones((1, 1, 1)),
expand_composites=array_ops.ones((2)))
class NestBenchmark(test.Benchmark):