From ddc0620062f5ea3d7f45b600941671cea8ffa2ad Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Mon, 17 Aug 2020 12:15:14 -0700
Subject: [PATCH] For the default implementation of TypeSpec.is_compatible,
 don't require that nested values in the left-hand-side and right-hand-side
 have identical types if one is a TypeSpee.

PiperOrigin-RevId: 327071494
Change-Id: I976df94c0d56bc3e8343dc873c5d1324aa69a150
---
 tensorflow/python/framework/type_spec.py      |  4 +-
 tensorflow/python/framework/type_spec_test.py | 68 ++++++++++++++++++-
 2 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py
index 4bf2ad791d7..ebfce25d6db 100644
--- a/tensorflow/python/framework/type_spec.py
+++ b/tensorflow/python/framework/type_spec.py
@@ -380,6 +380,8 @@ class TypeSpec(object):
   @staticmethod
   def __is_compatible(a, b):
     """Returns true if the given type serializations compatible."""
+    if isinstance(a, TypeSpec):
+      return a.is_compatible_with(b)
     if type(a) is not type(b):
       return False
     if isinstance(a, (list, tuple)):
@@ -388,7 +390,7 @@ class TypeSpec(object):
     if isinstance(a, dict):
       return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
           TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
-    if isinstance(a, (TypeSpec, tensor_shape.TensorShape, dtypes.DType)):
+    if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)):
       return a.is_compatible_with(b)
     return a == b
 
diff --git a/tensorflow/python/framework/type_spec_test.py b/tensorflow/python/framework/type_spec_test.py
index 46e1ea32d72..bcffd43ee6a 100644
--- a/tensorflow/python/framework/type_spec_test.py
+++ b/tensorflow/python/framework/type_spec_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework import type_spec
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import googletest
 
 
@@ -67,7 +68,8 @@ class TwoTensorsSpec(type_spec.TypeSpec):
     return (value.x, value.y)
 
   def _from_components(self, components):
-    return TwoTensors(*components)
+    x, y = components
+    return TwoTensors(x, y, self.color)
 
   def _serialize(self):
     return (self.x_shape, self.x_dtype, self.y_shape, self.y_dtype, self.color)
@@ -82,6 +84,54 @@ type_spec.register_type_spec_from_value_converter(
     TwoTensors, TwoTensorsSpec.from_value)
 
 
+class TwoComposites(object):
+  """A simple value type to test TypeSpec.
+
+  Contains two composite tensorstensors (x, y) and a string (color).
+  """
+
+  def __init__(self, x, y, color="red"):
+    assert isinstance(color, str)
+    self.x = ops.convert_to_tensor_or_composite(x)
+    self.y = ops.convert_to_tensor_or_composite(y)
+    self.color = color
+
+
+class TwoCompositesSpec(type_spec.TypeSpec):
+  """A TypeSpec for the TwoTensors value type."""
+
+  def __init__(self, x_spec, y_spec, color="red"):
+    self.x_spec = x_spec
+    self.y_spec = y_spec
+    self.color = color
+
+  value_type = property(lambda self: TwoComposites)
+
+  @property
+  def _component_specs(self):
+    return (self.x_spec, self.y_spec)
+
+  def _to_components(self, value):
+    return (value.x, value.y)
+
+  def _from_components(self, components):
+    x, y = components
+    return TwoTensors(x, y, self.color)
+
+  def _serialize(self):
+    return (self.x_spec, self.y_spec, self.color)
+
+  @classmethod
+  def from_value(cls, value):
+    return cls(type_spec.type_spec_from_value(value.x),
+               type_spec.type_spec_from_value(value.y),
+               value.color)
+
+
+type_spec.register_type_spec_from_value_converter(
+    TwoComposites, TwoCompositesSpec.from_value)
+
+
 class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(
@@ -283,5 +333,21 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     spec = type_spec.type_spec_from_value(value)
     self.assertEqual(spec, TwoTensorsSpec.from_value(value))
 
+  def testNestedRagged(self):
+    # Check that TwoCompositeSpecs are compatible if one has a nested
+    # RaggedTensorSpec w/ ragged_rank=0 and the other has a corresponding
+    # nested TensorSpec.
+    spec1 = TwoCompositesSpec(
+        ragged_tensor.RaggedTensorSpec([10], dtypes.int32, ragged_rank=0),
+        tensor_spec.TensorSpec(None, dtypes.int32))
+    spec2 = TwoCompositesSpec(
+        tensor_spec.TensorSpec([10], dtypes.int32),
+        tensor_spec.TensorSpec(None, dtypes.int32))
+    spec3 = TwoCompositesSpec(
+        tensor_spec.TensorSpec([12], dtypes.int32),
+        tensor_spec.TensorSpec(None, dtypes.int32))
+    self.assertTrue(spec1.is_compatible_with(spec2))
+    self.assertFalse(spec1.is_compatible_with(spec3))
+
 if __name__ == "__main__":
   googletest.main()