From ddc0620062f5ea3d7f45b600941671cea8ffa2ad Mon Sep 17 00:00:00 2001 From: Edward Loper <edloper@google.com> Date: Mon, 17 Aug 2020 12:15:14 -0700 Subject: [PATCH] For the default implementation of TypeSpec.is_compatible, don't require that nested values in the left-hand-side and right-hand-side have identical types if one is a TypeSpee. PiperOrigin-RevId: 327071494 Change-Id: I976df94c0d56bc3e8343dc873c5d1324aa69a150 --- tensorflow/python/framework/type_spec.py | 4 +- tensorflow/python/framework/type_spec_test.py | 68 ++++++++++++++++++- 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py index 4bf2ad791d7..ebfce25d6db 100644 --- a/tensorflow/python/framework/type_spec.py +++ b/tensorflow/python/framework/type_spec.py @@ -380,6 +380,8 @@ class TypeSpec(object): @staticmethod def __is_compatible(a, b): """Returns true if the given type serializations compatible.""" + if isinstance(a, TypeSpec): + return a.is_compatible_with(b) if type(a) is not type(b): return False if isinstance(a, (list, tuple)): @@ -388,7 +390,7 @@ class TypeSpec(object): if isinstance(a, dict): return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all( TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys())) - if isinstance(a, (TypeSpec, tensor_shape.TensorShape, dtypes.DType)): + if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)): return a.is_compatible_with(b) return a == b diff --git a/tensorflow/python/framework/type_spec_test.py b/tensorflow/python/framework/type_spec_test.py index 46e1ea32d72..bcffd43ee6a 100644 --- a/tensorflow/python/framework/type_spec_test.py +++ b/tensorflow/python/framework/type_spec_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.framework import type_spec +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import googletest @@ -67,7 +68,8 @@ class TwoTensorsSpec(type_spec.TypeSpec): return (value.x, value.y) def _from_components(self, components): - return TwoTensors(*components) + x, y = components + return TwoTensors(x, y, self.color) def _serialize(self): return (self.x_shape, self.x_dtype, self.y_shape, self.y_dtype, self.color) @@ -82,6 +84,54 @@ type_spec.register_type_spec_from_value_converter( TwoTensors, TwoTensorsSpec.from_value) +class TwoComposites(object): + """A simple value type to test TypeSpec. + + Contains two composite tensorstensors (x, y) and a string (color). + """ + + def __init__(self, x, y, color="red"): + assert isinstance(color, str) + self.x = ops.convert_to_tensor_or_composite(x) + self.y = ops.convert_to_tensor_or_composite(y) + self.color = color + + +class TwoCompositesSpec(type_spec.TypeSpec): + """A TypeSpec for the TwoTensors value type.""" + + def __init__(self, x_spec, y_spec, color="red"): + self.x_spec = x_spec + self.y_spec = y_spec + self.color = color + + value_type = property(lambda self: TwoComposites) + + @property + def _component_specs(self): + return (self.x_spec, self.y_spec) + + def _to_components(self, value): + return (value.x, value.y) + + def _from_components(self, components): + x, y = components + return TwoTensors(x, y, self.color) + + def _serialize(self): + return (self.x_spec, self.y_spec, self.color) + + @classmethod + def from_value(cls, value): + return cls(type_spec.type_spec_from_value(value.x), + type_spec.type_spec_from_value(value.y), + value.color) + + +type_spec.register_type_spec_from_value_converter( + TwoComposites, TwoCompositesSpec.from_value) + + class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase): @parameterized.named_parameters( @@ -283,5 +333,21 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase): spec = type_spec.type_spec_from_value(value) self.assertEqual(spec, TwoTensorsSpec.from_value(value)) + def testNestedRagged(self): + # Check that TwoCompositeSpecs are compatible if one has a nested + # RaggedTensorSpec w/ ragged_rank=0 and the other has a corresponding + # nested TensorSpec. + spec1 = TwoCompositesSpec( + ragged_tensor.RaggedTensorSpec([10], dtypes.int32, ragged_rank=0), + tensor_spec.TensorSpec(None, dtypes.int32)) + spec2 = TwoCompositesSpec( + tensor_spec.TensorSpec([10], dtypes.int32), + tensor_spec.TensorSpec(None, dtypes.int32)) + spec3 = TwoCompositesSpec( + tensor_spec.TensorSpec([12], dtypes.int32), + tensor_spec.TensorSpec(None, dtypes.int32)) + self.assertTrue(spec1.is_compatible_with(spec2)) + self.assertFalse(spec1.is_compatible_with(spec3)) + if __name__ == "__main__": googletest.main()