diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index 170fda90b68..b6c1376b6ad 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -165,7 +165,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/util:structure", ], ) diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py index 80ca7104d85..db10ea3b7f8 100644 --- a/tensorflow/python/data/experimental/ops/grouping.py +++ b/tensorflow/python/data/experimental/ops/grouping.py @@ -21,6 +21,7 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -448,7 +449,10 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset): def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping defun for reduce_func.""" - nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access + nested_dataset = dataset_ops.DatasetStructure( + structure.Structure._from_legacy_structure( # pylint: disable=protected-access + input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes)) wrapped_func = dataset_ops.StructuredFunctionWrapper( reduce_func, self._transformation_name(), @@ -456,11 +460,13 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset): input_shapes=(tensor_shape.scalar(), nested_dataset), input_types=(dtypes.int64, nested_dataset)) if not isinstance( - wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access + wrapped_func.output_structure, dataset_ops.DatasetStructure): raise TypeError("`reduce_func` must return a `Dataset` object.") - self._output_classes = wrapped_func.output_classes.output_classes - self._output_types = wrapped_func.output_types.output_types - self._output_shapes = wrapped_func.output_shapes.output_shapes + # pylint: disable=protected-access + element_structure = wrapped_func.output_structure._element_structure + self._output_classes = element_structure._to_legacy_output_classes() + self._output_types = element_structure._to_legacy_output_types() + self._output_shapes = element_structure._to_legacy_output_shapes() self._reduce_func = wrapped_func.function @property diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 21eed2b070a..fa1f6d701a4 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -117,8 +117,12 @@ tf_py_test( "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:optional_ops", + "//tensorflow/python/data/util:structure", ], ) diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py index a5324af4d0c..1f22a37c2e0 100644 --- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py @@ -24,10 +24,14 @@ import numpy as np from tensorflow.core.framework import graph_pb2 from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import optional_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test @@ -249,6 +253,63 @@ class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertTrue(ds.options().experimental_autotune) self.assertTrue(ds.options().experimental_filter_fusion) + # pylint: disable=g-long-lambda + @parameterized.named_parameters( + ("Tensor", lambda: constant_op.constant(37.0), + structure.TensorStructure(dtypes.float32, [])), + ("SparseTensor", lambda: sparse_tensor.SparseTensor( + indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32), + dense_shape=[1]), + structure.SparseTensorStructure(dtypes.int32, [1])), + ("Nest", lambda: { + "a": constant_op.constant(37.0), + "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))}, + structure.NestedStructure({ + "a": structure.TensorStructure(dtypes.float32, []), + "b": (structure.TensorStructure(dtypes.string, [1]), + structure.TensorStructure(dtypes.string, []))})), + ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices( + constant_op.constant([1, 2, 3])), + dataset_ops.DatasetStructure( + structure.TensorStructure(dtypes.int32, []))), + ("Optional", lambda: optional_ops.Optional.from_value(37.0), + optional_ops.OptionalStructure( + structure.TensorStructure(dtypes.float32, []))), + ) + def testDatasetStructure(self, tf_value_fn, expected_element_structure): + dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn()) + dataset_structure = structure.Structure.from_value(dataset) + self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure) + + # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing + # the element structure. + self.assertTrue(expected_element_structure.is_compatible_with( + dataset_structure._element_structure)) + self.assertTrue(dataset_structure._element_structure.is_compatible_with( + expected_element_structure)) + + self.assertEqual([dtypes.variant], dataset_structure._flat_types) + self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes) + + # Assert that the `Dataset` survives a round-trip via _from_tensor_list() + # and _to_tensor_list(). + round_trip_dataset = dataset_structure._from_tensor_list( + dataset_structure._to_tensor_list(dataset)) + + value = tf_value_fn() + + if isinstance(value, dataset_ops.Dataset): + self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x)) + elif isinstance(value, optional_ops.Optional): + self.assertDatasetProduces( + round_trip_dataset.map(lambda opt: opt.get_value()), + [self.evaluate(value.get_value())], + requires_initialization=True) + else: + self.assertDatasetProduces( + round_trip_dataset, [self.evaluate(tf_value_fn())], + requires_initialization=True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 4a11619112b..5c0cfe994d9 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse +from tensorflow.python.data.util import structure as structure_lib from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -1868,57 +1869,6 @@ class SparseTensorSliceDataset(DatasetSource): return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64) -class _NestedDatasetComponent(object): - """The structure of a `Dataset` nested in a component of another `Dataset`. - - A `StructuredFunctionWrapper` around a function that returns a `Dataset` as - one of its components will have a `NestedDatasetComponent` in the - corresponding position in the `output_classes`, `output_shapes`, and - `output_types` properties. - - TODO(b/110122868): Add this class, or something equivalent, to the public API. - We are considering revising the public API for accessing Dataset structure - (`output_classes` etc.) based on experience with nested datasets and other - custom component types. - """ - - def __init__(self, - dataset=None, - output_shapes=None, - output_types=None, - output_classes=None): - if dataset is None: - if (output_classes is None or output_shapes is None or - output_types is None): - raise ValueError( - "Either `dataset`, or all of `output_classes`, " - "`output_shapes`, and `output_types` must be specified.") - self._output_classes = output_classes - self._output_shapes = output_shapes - self._output_types = output_types - else: - if not (output_classes is None and output_shapes is None and - output_types is None): - raise ValueError( - "Either `dataset`, or all of `output_classes`, " - "`output_shapes`, and `output_types` must be specified.") - self._output_classes = dataset.output_classes - self._output_shapes = dataset.output_shapes - self._output_types = dataset.output_types - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - class _VariantDataset(DatasetV2): """A Dataset wrapper around a `tf.variant`-typed function argument.""" @@ -1935,15 +1885,73 @@ class _VariantDataset(DatasetV2): @property def output_classes(self): - return self._structure.output_classes + return self._structure._to_legacy_output_classes() # pylint: disable=protected-access @property def output_shapes(self): - return self._structure.output_shapes + return self._structure._to_legacy_output_shapes() # pylint: disable=protected-access @property def output_types(self): - return self._structure.output_types + return self._structure._to_legacy_output_types() # pylint: disable=protected-access + + +class DatasetStructure(structure_lib.Structure): + """Represents a `Dataset` of structured values.""" + + def __init__(self, element_structure): + self._element_structure = element_structure + + @property + def _flat_shapes(self): + return [tensor_shape.scalar()] + + @property + def _flat_types(self): + return [dtypes.variant] + + def is_compatible_with(self, other): + # pylint: disable=protected-access + return (isinstance(other, DatasetStructure) and + self._element_structure.is_compatible_with( + other._element_structure)) + + def _to_tensor_list(self, value): + return [value._as_variant_tensor()] # pylint: disable=protected-access + + def _from_tensor_list(self, flat_value): + if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or + not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "DatasetStructure corresponds to a single tf.variant scalar.") + return self._from_compatible_tensor_list(flat_value) + + def _from_compatible_tensor_list(self, flat_value): + # pylint: disable=protected-access + return _VariantDataset(flat_value[0], self._element_structure) + + @staticmethod + def from_value(value): + # TODO(b/110122868): We can simplify this when a `Dataset` object has a + # `Structure`-valued property. + element_structure = structure_lib.Structure._from_legacy_structure( + value.output_types, value.output_shapes, value.output_classes) + return DatasetStructure(element_structure) + + def _to_legacy_output_types(self): + return self + + def _to_legacy_output_shapes(self): + return self + + def _to_legacy_output_classes(self): + return self + + +# pylint: disable=protected-access +structure_lib.Structure._register_custom_converter(DatasetV2, + DatasetStructure.from_value) +# pylint: enable=protected-access class StructuredFunctionWrapper(object): @@ -2001,6 +2009,9 @@ class StructuredFunctionWrapper(object): self._input_types = dataset.output_types self._input_classes = dataset.output_classes + self._input_structure = structure_lib.Structure._from_legacy_structure( # pylint: disable=protected-access + self._input_types, self._input_shapes, self._input_classes) + self._transformation_name = transformation_name readable_transformation_name = transformation_name.replace( ".", "_")[:-2] if len(transformation_name) > 2 else "" @@ -2008,35 +2019,18 @@ class StructuredFunctionWrapper(object): readable_transformation_name, function_utils.get_func_name(func), str(ops.uid()) - ]) if defun_kwargs is None: defun_kwargs = {} @function.Defun( - *self._defun_args(), func_name=self._func_name, **defun_kwargs) + *self._input_structure._flat_types, func_name=self._func_name, # pylint: disable=protected-access + **defun_kwargs) def tf_data_structured_function_wrapper(*args): """Wrapper for passing nested structures to and from tf.data functions.""" - flat_args = [] - for arg, arg_class, arg_shape, arg_type in zip( - args, - nest.flatten(self._input_classes), - nest.flatten(self._input_shapes), - nest.flatten(self._input_types)): - # TODO(b/110122868): Add a registration mechanism for new component - # types. - if arg_class is sparse_tensor_lib.SparseTensor: - arg = sparse.deserialize_sparse_tensors( - arg, arg_type, arg_shape, arg_class) - arg.indices.set_shape([None, arg_shape.ndims]) - arg.dense_shape.set_shape([arg_shape.ndims]) - elif isinstance(arg_class, _NestedDatasetComponent): - arg = _VariantDataset(arg, arg_class) - else: - arg.set_shape(arg_shape) - flat_args.append(arg) - nested_args = nest.pack_sequence_as(self._input_classes, flat_args) + # pylint: disable=protected-access + nested_args = self._input_structure._from_compatible_tensor_list(args) if not _should_unpack_args(nested_args): nested_args = (nested_args,) @@ -2054,50 +2048,14 @@ class StructuredFunctionWrapper(object): if isinstance(ret, list): ret = tuple(ret) - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - flat_ret = [] - flat_classes = [] - flat_shapes = [] - flat_types = [] - for t in nest.flatten(ret): - # TODO(b/110122868): Add a registration mechanism for new component - # types. - if sparse_tensor_lib.is_sparse(t): - t = sparse_tensor_lib.SparseTensor.from_value(t) - flat_ret.append(sparse.serialize_sparse_tensors(t)) - flat_classes.append(sparse_tensor_lib.SparseTensor) - flat_shapes.append(t.get_shape()) - flat_types.append(t.dtype) - elif isinstance(t, DatasetV2): - flat_ret.append(t._as_variant_tensor()) # pylint: disable=protected-access - component = _NestedDatasetComponent(t) - flat_classes.append(component) - flat_shapes.append(component) - flat_types.append(component) - if t.options() != Options(): - warnings.warn("Encountered a nested dataset with non-default " - "options. These options will not be propagated to " - "the outer dataset.") - else: - try: - t = ops.convert_to_tensor(t) - except (ValueError, TypeError): - raise TypeError("Unsupported return value from function passed to " - "%s: %s." % (transformation_name, t)) - flat_ret.append(t) - flat_classes.append(ops.Tensor) - flat_shapes.append(t.get_shape()) - flat_types.append(t.dtype) - - ret = nest.pack_sequence_as(ret, flat_ret) - self._output_classes = nest.pack_sequence_as(ret, flat_classes) - self._output_shapes = nest.pack_sequence_as(ret, flat_shapes) - self._output_types = nest.pack_sequence_as(ret, flat_types) + try: + self._output_structure = structure_lib.Structure.from_value(ret) + except (ValueError, TypeError): + raise TypeError("Unsupported return value from function passed to " + "%s: %s." % (transformation_name, ret)) _warn_if_collections(transformation_name) - - return flat_ret + return self._output_structure._to_tensor_list(ret) self._function = tf_data_structured_function_wrapper if add_to_graph: @@ -2108,32 +2066,21 @@ class StructuredFunctionWrapper(object): # in case (e.g.) we need to rerun the function. self._function._create_definition_if_needed() # pylint: disable=protected-access - def _defun_args(self): - """Returns a flat list of `tf.DType` for the input element structure.""" - ret = [] - for input_type, input_class in zip(nest.flatten(self._input_types), - nest.flatten(self._input_classes)): - # TODO(b/110122868): Add a registration mechanism for new component types. - if input_class is sparse_tensor_lib.SparseTensor: - ret.append(dtypes.variant) - elif isinstance(input_class, _NestedDatasetComponent): - ret.append(dtypes.variant) - else: - assert isinstance(input_type, dtypes.DType) - ret.append(input_type) - return ret + @property + def output_structure(self): + return self._output_structure @property def output_classes(self): - return self._output_classes + return self._output_structure._to_legacy_output_classes() # pylint: disable=protected-access @property def output_shapes(self): - return self._output_shapes + return self._output_structure._to_legacy_output_shapes() # pylint: disable=protected-access @property def output_types(self): - return self._output_types + return self._output_structure._to_legacy_output_types() # pylint: disable=protected-access @property def function(self): @@ -2156,30 +2103,12 @@ def flat_structure(dataset): A dictionary of keyword arguments that can be passed to many Dataset op constructors. """ - output_classes = [] - output_shapes = [] - output_types = [] - for output_class, output_shape, output_type in zip( - nest.flatten(dataset.output_classes), nest.flatten(dataset.output_shapes), - nest.flatten(dataset.output_types)): - if isinstance(output_class, _NestedDatasetComponent): - output_classes.append(output_class.output_classes) - output_shapes.append(output_shape.output_shapes) - output_types.append(output_type.output_types) - else: - output_classes.append(output_class) - output_shapes.append(output_shape) - output_types.append(output_type) - - output_classes = nest.pack_sequence_as(dataset.output_classes, output_classes) - output_shapes = nest.pack_sequence_as(dataset.output_shapes, output_shapes) - output_types = nest.pack_sequence_as(dataset.output_types, output_types) - + # pylint: disable=protected-access + structure = structure_lib.Structure._from_legacy_structure( + dataset.output_types, dataset.output_shapes, dataset.output_classes) return { - "output_shapes": - nest.flatten(sparse.as_dense_shapes(output_shapes, output_classes)), - "output_types": - nest.flatten(sparse.as_dense_types(output_types, output_classes)), + "output_shapes": structure._flat_shapes, + "output_types": structure._flat_types, } @@ -2902,11 +2831,13 @@ class FlatMapDataset(UnaryDataset): wrapped_func = StructuredFunctionWrapper( map_func, self._transformation_name(), dataset=input_dataset) - if not isinstance(wrapped_func.output_classes, _NestedDatasetComponent): + if not isinstance(wrapped_func.output_structure, DatasetStructure): raise TypeError("`map_func` must return a `Dataset` object.") - self._output_classes = wrapped_func.output_classes.output_classes - self._output_types = wrapped_func.output_types.output_types - self._output_shapes = wrapped_func.output_shapes.output_shapes + # pylint: disable=protected-access + element_structure = wrapped_func.output_structure._element_structure + self._output_classes = element_structure._to_legacy_output_classes() + self._output_types = element_structure._to_legacy_output_types() + self._output_shapes = element_structure._to_legacy_output_shapes() self._map_func = wrapped_func.function def _as_variant_tensor(self): @@ -3048,10 +2979,9 @@ class WindowDataset(UnaryDataset): self._output_classes = nest.pack_sequence_as( input_dataset.output_classes, [ - _NestedDatasetComponent( # pylint: disable=protected-access - output_classes=output_class, - output_shapes=output_shape, - output_types=output_type) + DatasetStructure( + structure_lib.Structure._from_legacy_structure( # pylint: disable=protected-access + output_type, output_shape, output_class)) for output_class, output_shape, output_type in zip( nest.flatten(input_dataset.output_classes), nest.flatten(input_dataset.output_shapes), diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py index 91cf883ce94..4113b7ed315 100644 --- a/tensorflow/python/data/ops/optional_ops.py +++ b/tensorflow/python/data/ops/optional_ops.py @@ -183,19 +183,13 @@ class OptionalStructure(structure.Structure): return OptionalStructure(value.value_structure) def _to_legacy_output_types(self): - raise NotImplementedError("The `output_types` property is not supported on " - "structured objects containing an `Optional`. " - "Use the corresponding `structure` property.") + return self def _to_legacy_output_shapes(self): - raise NotImplementedError("The `output_shapes` property is not supported on" - " structured objects containing an `Optional`. " - "Use the corresponding `structure` property.") + return self def _to_legacy_output_classes(self): - raise NotImplementedError("The `output_classes` property is not supported " - "on structured objects containing an `Optional`. " - "Use the corresponding `structure` property.") + return self # pylint: disable=protected-access diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index 9a3118297db..3cf67b07453 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -208,14 +208,16 @@ class Structure(object): flat_ret = [] for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes, flat_classes): - if issubclass(flat_class, sparse_tensor_lib.SparseTensor): + if isinstance(flat_class, Structure): + flat_ret.append(flat_class) + elif issubclass(flat_class, sparse_tensor_lib.SparseTensor): flat_ret.append(SparseTensorStructure(flat_type, flat_shape)) elif issubclass(flat_class, ops.Tensor): flat_ret.append(TensorStructure(flat_type, flat_shape)) else: # NOTE(mrry): Since legacy structures produced by iterators only - # comprise Tensors, SparseTensors, and nests, we do not need to support - # all structure types here. + # comprise Tensors, SparseTensors, and nests, we do not need to + # support all structure types here. raise TypeError( "Could not build a structure for output class %r" % flat_type) @@ -381,6 +383,13 @@ class TensorStructure(Structure): return self._from_compatible_tensor_list(flat_value) def _from_compatible_tensor_list(self, flat_value): + # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()` + # op here and return that, instead of mutating the input's shape using + # `Tensor.set_shape()`. However, that would add extra ops on the arguments + # of each `tf.data` function, which could impact performance. When this + # bug is resolved, we should be able to add the `ensure_shape()` ops and + # optimize them away using contextual shape information. + flat_value[0].set_shape(self._shape) return flat_value[0] @staticmethod @@ -406,7 +415,11 @@ class SparseTensorStructure(Structure): @property def _flat_shapes(self): - return [tensor_shape.vector(3)] + # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`, + # but a `SparseTensorStructure` can also represent a batch of boxed + # `SparseTensor` objects with shape `(?, 3)` (and batches of batches, etc.), + # so the flat shape must be unknown. + return [tensor_shape.unknown_shape(None)] @property def _flat_types(self): @@ -428,8 +441,11 @@ class SparseTensorStructure(Structure): return self._from_compatible_tensor_list(flat_value) def _from_compatible_tensor_list(self, flat_value): - return sparse_ops.deserialize_sparse( + ret = sparse_ops.deserialize_sparse( flat_value[0], dtype=self._dtype, rank=self._dense_shape.ndims) + ret.indices.set_shape([None, self._dense_shape.ndims]) + ret.dense_shape.set_shape([self._dense_shape.ndims]) + return ret @staticmethod def from_value(value): diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py index 630a0c912bc..65a41a50f11 100644 --- a/tensorflow/python/data/util/structure_test.py +++ b/tensorflow/python/data/util/structure_test.py @@ -44,7 +44,7 @@ class StructureTest(test.TestCase, parameterized.TestCase): [dtypes.float32], [[]]), (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), - structure.SparseTensorStructure, [dtypes.variant], [[3]]), + structure.SparseTensorStructure, [dtypes.variant], [None]), (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), (lambda: { @@ -58,14 +58,17 @@ class StructureTest(test.TestCase, parameterized.TestCase): sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) }, structure.NestedStructure, - [dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]])) + [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None])) def testFlatStructure(self, value_fn, expected_structure, expected_types, expected_shapes): value = value_fn() s = structure.Structure.from_value(value) self.assertIsInstance(s, expected_structure) self.assertEqual(expected_types, s._flat_types) - self.assertEqual(expected_shapes, s._flat_shapes) + for expected, actual in zip(expected_shapes, s._flat_shapes): + self.assertTrue(actual.is_compatible_with(expected)) + self.assertTrue( + tensor_shape.as_shape(expected).is_compatible_with(actual)) @parameterized.parameters( (lambda: constant_op.constant(37.0), lambda: [