From c1dbd1682246b42450921db8681fec5040aaa719 Mon Sep 17 00:00:00 2001 From: Jiri Simsa <jsimsa@google.com> Date: Tue, 21 Apr 2020 20:20:47 -0700 Subject: [PATCH] [tf.data] Fix an issue where `tf.data.DatasetSpec` could not be specified in `input_signature` of tf.function. Fixes: https://github.com/tensorflow/tensorflow/issues/38733 PiperOrigin-RevId: 307733846 Change-Id: I28b7a4372fc585f8894df9928e3e56844429e260 --- tensorflow/python/data/kernel_tests/BUILD | 15 ++++++ .../data/kernel_tests/dataset_spec_test.py | 54 +++++++++++++++++++ tensorflow/python/data/ops/dataset_ops.py | 2 +- tensorflow/python/framework/type_spec.py | 6 ++- tensorflow/python/util/nest.py | 2 +- 5 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 tensorflow/python/data/kernel_tests/dataset_spec_test.py diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 5b5f137afb2..ec567b8c3b4 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -117,6 +117,21 @@ tf_py_test( ], ) +tf_py_test( + name = "dataset_spec_test", + size = "small", + srcs = ["dataset_spec_test.py"], + deps = [ + ":test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + tf_py_test( name = "enumerate_test", size = "small", diff --git a/tensorflow/python/data/kernel_tests/dataset_spec_test.py b/tensorflow/python/data/kernel_tests/dataset_spec_test.py new file mode 100644 index 00000000000..781a972ea33 --- /dev/null +++ b/tensorflow/python/data/kernel_tests/dataset_spec_test.py @@ -0,0 +1,54 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `tf.data.DatasetSpec`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import def_function +from tensorflow.python.framework import combinations +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.platform import test + + +class DatasetSpecTest(test_base.DatasetTestBase, parameterized.TestCase): + + @combinations.generate(test_base.default_test_combinations()) + def testInputSignature(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + np.arange(10).astype(np.int32)).batch(5) + + @def_function.function(input_signature=[ + dataset_ops.DatasetSpec( + tensor_spec.TensorSpec( + shape=(None,), dtype=dtypes.int32, name=None), + tensor_shape.TensorShape([])) + ]) + def fn(_): + pass + + fn(dataset) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 7dcec3248ce..eb7963da332 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -3050,7 +3050,7 @@ class DatasetSpec(type_spec.BatchableTypeSpec): @property def value_type(self): - return _VariantDataset + return Dataset def _serialize(self): return (self._element_spec, self._dataset_shape) diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py index 490574bbc1b..8da3265e810 100644 --- a/tensorflow/python/framework/type_spec.py +++ b/tensorflow/python/framework/type_spec.py @@ -83,7 +83,11 @@ class TypeSpec(object): @abc.abstractproperty def value_type(self): - """The Python type for values that are compatible with this TypeSpec.""" + """The Python type for values that are compatible with this TypeSpec. + + In particular, all values that are compatible with this TypeSpec must be an + instance of this type. + """ raise NotImplementedError("%s.value_type" % type(self).__name__) def is_compatible_with(self, spec_or_value): diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 517030193de..695cc4cc909 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -231,7 +231,7 @@ def _yield_sorted_items(iterable): yield field, getattr(iterable, field) elif _is_composite_tensor(iterable): type_spec = iterable._type_spec # pylint: disable=protected-access - yield type(iterable).__name__, type_spec._to_components(iterable) # pylint: disable=protected-access + yield type_spec.value_type.__name__, type_spec._to_components(iterable) # pylint: disable=protected-access elif _is_type_spec(iterable): # Note: to allow CompositeTensors and their TypeSpecs to have matching # structures, we need to use the same key string here.