[nest] assert_shallow_structure now allows shallow to be a _ListWrapper/list subclass.

Bugfix.  Behavior is now in line with assert_same_structure.

This is inline with allowing the shallow structure to be any matching wrapt
wrapped object, as DictWrapper/TupleWrappers are, and in line with allowing
namedtuples having different key names.

PiperOrigin-RevId: 351431844
Change-Id: I990b7aee30bd9ccbdcdada0180d453d99d2e3013
This commit is contained in:
Eugene Brevdo 2021-01-12 13:06:14 -08:00 committed by TensorFlower Gardener
parent 637a24786c
commit 6b926e0a6b
2 changed files with 16 additions and 0 deletions

View File

@ -1067,6 +1067,11 @@ def assert_shallow_structure(shallow_tree,
input_type=type(input_tree),
shallow_type=type(shallow_tree)))
elif isinstance(shallow_tree, list) and isinstance(input_tree, list):
# List subclasses are considered the same,
# e.g. python list vs. _ListWrapper.
pass
elif ((_is_composite_tensor(shallow_tree) or
_is_composite_tensor(input_tree)) and
(_is_type_spec(shallow_tree) or _is_type_spec(input_tree))):

View File

@ -56,6 +56,10 @@ class _CustomMapping(collections_abc.Mapping):
return len(self._wrapped)
class _CustomList(list):
pass
class _CustomSequenceThatRaisesException(collections.Sequence):
def __len__(self):
@ -606,6 +610,13 @@ class NestTest(parameterized.TestCase, test.TestCase):
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
# This assertion is expected to pass: two list-types with same number
# of fields are considered identical.
inp_shallow = _CustomList([1, 2])
inp_deep = [1, 2]
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
def testFlattenUpTo(self):
# Shallow tree ends at scalar.
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]