For the default implementation of TypeSpec.is_compatible, don't require that nested values in the left-hand-side and right-hand-side have identical types if one is a TypeSpee.

PiperOrigin-RevId: 327071494
Change-Id: I976df94c0d56bc3e8343dc873c5d1324aa69a150
This commit is contained in:
Edward Loper 2020-08-17 12:15:14 -07:00 committed by TensorFlower Gardener
parent 2d98952a90
commit ddc0620062
2 changed files with 70 additions and 2 deletions
tensorflow/python/framework

View File

@ -380,6 +380,8 @@ class TypeSpec(object):
@staticmethod
def __is_compatible(a, b):
"""Returns true if the given type serializations compatible."""
if isinstance(a, TypeSpec):
return a.is_compatible_with(b)
if type(a) is not type(b):
return False
if isinstance(a, (list, tuple)):
@ -388,7 +390,7 @@ class TypeSpec(object):
if isinstance(a, dict):
return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
if isinstance(a, (TypeSpec, tensor_shape.TensorShape, dtypes.DType)):
if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)):
return a.is_compatible_with(b)
return a == b

View File

@ -29,6 +29,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest
@ -67,7 +68,8 @@ class TwoTensorsSpec(type_spec.TypeSpec):
return (value.x, value.y)
def _from_components(self, components):
return TwoTensors(*components)
x, y = components
return TwoTensors(x, y, self.color)
def _serialize(self):
return (self.x_shape, self.x_dtype, self.y_shape, self.y_dtype, self.color)
@ -82,6 +84,54 @@ type_spec.register_type_spec_from_value_converter(
TwoTensors, TwoTensorsSpec.from_value)
class TwoComposites(object):
"""A simple value type to test TypeSpec.
Contains two composite tensorstensors (x, y) and a string (color).
"""
def __init__(self, x, y, color="red"):
assert isinstance(color, str)
self.x = ops.convert_to_tensor_or_composite(x)
self.y = ops.convert_to_tensor_or_composite(y)
self.color = color
class TwoCompositesSpec(type_spec.TypeSpec):
"""A TypeSpec for the TwoTensors value type."""
def __init__(self, x_spec, y_spec, color="red"):
self.x_spec = x_spec
self.y_spec = y_spec
self.color = color
value_type = property(lambda self: TwoComposites)
@property
def _component_specs(self):
return (self.x_spec, self.y_spec)
def _to_components(self, value):
return (value.x, value.y)
def _from_components(self, components):
x, y = components
return TwoTensors(x, y, self.color)
def _serialize(self):
return (self.x_spec, self.y_spec, self.color)
@classmethod
def from_value(cls, value):
return cls(type_spec.type_spec_from_value(value.x),
type_spec.type_spec_from_value(value.y),
value.color)
type_spec.register_type_spec_from_value_converter(
TwoComposites, TwoCompositesSpec.from_value)
class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@parameterized.named_parameters(
@ -283,5 +333,21 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
spec = type_spec.type_spec_from_value(value)
self.assertEqual(spec, TwoTensorsSpec.from_value(value))
def testNestedRagged(self):
# Check that TwoCompositeSpecs are compatible if one has a nested
# RaggedTensorSpec w/ ragged_rank=0 and the other has a corresponding
# nested TensorSpec.
spec1 = TwoCompositesSpec(
ragged_tensor.RaggedTensorSpec([10], dtypes.int32, ragged_rank=0),
tensor_spec.TensorSpec(None, dtypes.int32))
spec2 = TwoCompositesSpec(
tensor_spec.TensorSpec([10], dtypes.int32),
tensor_spec.TensorSpec(None, dtypes.int32))
spec3 = TwoCompositesSpec(
tensor_spec.TensorSpec([12], dtypes.int32),
tensor_spec.TensorSpec(None, dtypes.int32))
self.assertTrue(spec1.is_compatible_with(spec2))
self.assertFalse(spec1.is_compatible_with(spec3))
if __name__ == "__main__":
googletest.main()