[tf.data] Fix an issue where tf.data.DatasetSpec could not be specified in input_signature of tf.function.

Fixes: https://github.com/tensorflow/tensorflow/issues/38733
PiperOrigin-RevId: 307733846
Change-Id: I28b7a4372fc585f8894df9928e3e56844429e260
This commit is contained in:
Jiri Simsa 2020-04-21 20:20:47 -07:00 committed by TensorFlower Gardener
parent 47ea7eeb96
commit c1dbd16822
5 changed files with 76 additions and 3 deletions

View File

@ -117,6 +117,21 @@ tf_py_test(
],
)
tf_py_test(
name = "dataset_spec_test",
size = "small",
srcs = ["dataset_spec_test.py"],
deps = [
":test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "enumerate_test",
size = "small",

View File

@ -0,0 +1,54 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.DatasetSpec`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.platform import test
class DatasetSpecTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testInputSignature(self):
dataset = dataset_ops.Dataset.from_tensor_slices(
np.arange(10).astype(np.int32)).batch(5)
@def_function.function(input_signature=[
dataset_ops.DatasetSpec(
tensor_spec.TensorSpec(
shape=(None,), dtype=dtypes.int32, name=None),
tensor_shape.TensorShape([]))
])
def fn(_):
pass
fn(dataset)
if __name__ == "__main__":
test.main()

View File

@ -3050,7 +3050,7 @@ class DatasetSpec(type_spec.BatchableTypeSpec):
@property
def value_type(self):
return _VariantDataset
return Dataset
def _serialize(self):
return (self._element_spec, self._dataset_shape)

View File

@ -83,7 +83,11 @@ class TypeSpec(object):
@abc.abstractproperty
def value_type(self):
"""The Python type for values that are compatible with this TypeSpec."""
"""The Python type for values that are compatible with this TypeSpec.
In particular, all values that are compatible with this TypeSpec must be an
instance of this type.
"""
raise NotImplementedError("%s.value_type" % type(self).__name__)
def is_compatible_with(self, spec_or_value):

View File

@ -231,7 +231,7 @@ def _yield_sorted_items(iterable):
yield field, getattr(iterable, field)
elif _is_composite_tensor(iterable):
type_spec = iterable._type_spec # pylint: disable=protected-access
yield type(iterable).__name__, type_spec._to_components(iterable) # pylint: disable=protected-access
yield type_spec.value_type.__name__, type_spec._to_components(iterable) # pylint: disable=protected-access
elif _is_type_spec(iterable):
# Note: to allow CompositeTensors and their TypeSpecs to have matching
# structures, we need to use the same key string here.