For the default implementation of TypeSpec.is_compatible, don't require that nested values in the left-hand-side and right-hand-side have identical types if one is a TypeSpee.
PiperOrigin-RevId: 327071494 Change-Id: I976df94c0d56bc3e8343dc873c5d1324aa69a150
This commit is contained in:
parent
2d98952a90
commit
ddc0620062
tensorflow/python/framework
@ -380,6 +380,8 @@ class TypeSpec(object):
|
||||
@staticmethod
|
||||
def __is_compatible(a, b):
|
||||
"""Returns true if the given type serializations compatible."""
|
||||
if isinstance(a, TypeSpec):
|
||||
return a.is_compatible_with(b)
|
||||
if type(a) is not type(b):
|
||||
return False
|
||||
if isinstance(a, (list, tuple)):
|
||||
@ -388,7 +390,7 @@ class TypeSpec(object):
|
||||
if isinstance(a, dict):
|
||||
return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
|
||||
TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
|
||||
if isinstance(a, (TypeSpec, tensor_shape.TensorShape, dtypes.DType)):
|
||||
if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)):
|
||||
return a.is_compatible_with(b)
|
||||
return a == b
|
||||
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@ -67,7 +68,8 @@ class TwoTensorsSpec(type_spec.TypeSpec):
|
||||
return (value.x, value.y)
|
||||
|
||||
def _from_components(self, components):
|
||||
return TwoTensors(*components)
|
||||
x, y = components
|
||||
return TwoTensors(x, y, self.color)
|
||||
|
||||
def _serialize(self):
|
||||
return (self.x_shape, self.x_dtype, self.y_shape, self.y_dtype, self.color)
|
||||
@ -82,6 +84,54 @@ type_spec.register_type_spec_from_value_converter(
|
||||
TwoTensors, TwoTensorsSpec.from_value)
|
||||
|
||||
|
||||
class TwoComposites(object):
|
||||
"""A simple value type to test TypeSpec.
|
||||
|
||||
Contains two composite tensorstensors (x, y) and a string (color).
|
||||
"""
|
||||
|
||||
def __init__(self, x, y, color="red"):
|
||||
assert isinstance(color, str)
|
||||
self.x = ops.convert_to_tensor_or_composite(x)
|
||||
self.y = ops.convert_to_tensor_or_composite(y)
|
||||
self.color = color
|
||||
|
||||
|
||||
class TwoCompositesSpec(type_spec.TypeSpec):
|
||||
"""A TypeSpec for the TwoTensors value type."""
|
||||
|
||||
def __init__(self, x_spec, y_spec, color="red"):
|
||||
self.x_spec = x_spec
|
||||
self.y_spec = y_spec
|
||||
self.color = color
|
||||
|
||||
value_type = property(lambda self: TwoComposites)
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
return (self.x_spec, self.y_spec)
|
||||
|
||||
def _to_components(self, value):
|
||||
return (value.x, value.y)
|
||||
|
||||
def _from_components(self, components):
|
||||
x, y = components
|
||||
return TwoTensors(x, y, self.color)
|
||||
|
||||
def _serialize(self):
|
||||
return (self.x_spec, self.y_spec, self.color)
|
||||
|
||||
@classmethod
|
||||
def from_value(cls, value):
|
||||
return cls(type_spec.type_spec_from_value(value.x),
|
||||
type_spec.type_spec_from_value(value.y),
|
||||
value.color)
|
||||
|
||||
|
||||
type_spec.register_type_spec_from_value_converter(
|
||||
TwoComposites, TwoCompositesSpec.from_value)
|
||||
|
||||
|
||||
class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -283,5 +333,21 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
spec = type_spec.type_spec_from_value(value)
|
||||
self.assertEqual(spec, TwoTensorsSpec.from_value(value))
|
||||
|
||||
def testNestedRagged(self):
|
||||
# Check that TwoCompositeSpecs are compatible if one has a nested
|
||||
# RaggedTensorSpec w/ ragged_rank=0 and the other has a corresponding
|
||||
# nested TensorSpec.
|
||||
spec1 = TwoCompositesSpec(
|
||||
ragged_tensor.RaggedTensorSpec([10], dtypes.int32, ragged_rank=0),
|
||||
tensor_spec.TensorSpec(None, dtypes.int32))
|
||||
spec2 = TwoCompositesSpec(
|
||||
tensor_spec.TensorSpec([10], dtypes.int32),
|
||||
tensor_spec.TensorSpec(None, dtypes.int32))
|
||||
spec3 = TwoCompositesSpec(
|
||||
tensor_spec.TensorSpec([12], dtypes.int32),
|
||||
tensor_spec.TensorSpec(None, dtypes.int32))
|
||||
self.assertTrue(spec1.is_compatible_with(spec2))
|
||||
self.assertFalse(spec1.is_compatible_with(spec3))
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user