From c1dbd1682246b42450921db8681fec5040aaa719 Mon Sep 17 00:00:00 2001
From: Jiri Simsa <jsimsa@google.com>
Date: Tue, 21 Apr 2020 20:20:47 -0700
Subject: [PATCH] [tf.data] Fix an issue where `tf.data.DatasetSpec` could not
 be specified in `input_signature` of tf.function.

Fixes: https://github.com/tensorflow/tensorflow/issues/38733
PiperOrigin-RevId: 307733846
Change-Id: I28b7a4372fc585f8894df9928e3e56844429e260
---
 tensorflow/python/data/kernel_tests/BUILD     | 15 ++++++
 .../data/kernel_tests/dataset_spec_test.py    | 54 +++++++++++++++++++
 tensorflow/python/data/ops/dataset_ops.py     |  2 +-
 tensorflow/python/framework/type_spec.py      |  6 ++-
 tensorflow/python/util/nest.py                |  2 +-
 5 files changed, 76 insertions(+), 3 deletions(-)
 create mode 100644 tensorflow/python/data/kernel_tests/dataset_spec_test.py

diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 5b5f137afb2..ec567b8c3b4 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -117,6 +117,21 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "dataset_spec_test",
+    size = "small",
+    srcs = ["dataset_spec_test.py"],
+    deps = [
+        ":test_base",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 tf_py_test(
     name = "enumerate_test",
     size = "small",
diff --git a/tensorflow/python/data/kernel_tests/dataset_spec_test.py b/tensorflow/python/data/kernel_tests/dataset_spec_test.py
new file mode 100644
index 00000000000..781a972ea33
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/dataset_spec_test.py
@@ -0,0 +1,54 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.DatasetSpec`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import combinations
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.platform import test
+
+
+class DatasetSpecTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testInputSignature(self):
+    dataset = dataset_ops.Dataset.from_tensor_slices(
+        np.arange(10).astype(np.int32)).batch(5)
+
+    @def_function.function(input_signature=[
+        dataset_ops.DatasetSpec(
+            tensor_spec.TensorSpec(
+                shape=(None,), dtype=dtypes.int32, name=None),
+            tensor_shape.TensorShape([]))
+    ])
+    def fn(_):
+      pass
+
+    fn(dataset)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 7dcec3248ce..eb7963da332 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -3050,7 +3050,7 @@ class DatasetSpec(type_spec.BatchableTypeSpec):
 
   @property
   def value_type(self):
-    return _VariantDataset
+    return Dataset
 
   def _serialize(self):
     return (self._element_spec, self._dataset_shape)
diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py
index 490574bbc1b..8da3265e810 100644
--- a/tensorflow/python/framework/type_spec.py
+++ b/tensorflow/python/framework/type_spec.py
@@ -83,7 +83,11 @@ class TypeSpec(object):
 
   @abc.abstractproperty
   def value_type(self):
-    """The Python type for values that are compatible with this TypeSpec."""
+    """The Python type for values that are compatible with this TypeSpec.
+
+    In particular, all values that are compatible with this TypeSpec must be an
+    instance of this type.
+    """
     raise NotImplementedError("%s.value_type" % type(self).__name__)
 
   def is_compatible_with(self, spec_or_value):
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 517030193de..695cc4cc909 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -231,7 +231,7 @@ def _yield_sorted_items(iterable):
       yield field, getattr(iterable, field)
   elif _is_composite_tensor(iterable):
     type_spec = iterable._type_spec  # pylint: disable=protected-access
-    yield type(iterable).__name__, type_spec._to_components(iterable)  # pylint: disable=protected-access
+    yield type_spec.value_type.__name__, type_spec._to_components(iterable)  # pylint: disable=protected-access
   elif _is_type_spec(iterable):
     # Note: to allow CompositeTensors and their TypeSpecs to have matching
     # structures, we need to use the same key string here.