From 814cce54c9e096891245d399e67024f854583cc8 Mon Sep 17 00:00:00 2001 From: Edward Loper <edloper@google.com> Date: Mon, 15 Jul 2019 18:00:32 -0700 Subject: [PATCH] In DatasetSpec constructor, convert dataset_shape to a shape. PiperOrigin-RevId: 258275543 --- tensorflow/python/data/ops/dataset_ops.py | 7 ++----- tensorflow/python/data/util/structure_test.py | 11 +++++++++++ .../api/golden/v1/tensorflow.data.-dataset-spec.pbtxt | 2 +- ...sorflow.data.experimental.-dataset-structure.pbtxt | 2 +- .../api/golden/v2/tensorflow.data.-dataset-spec.pbtxt | 2 +- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index f82231cb856..f9c42df17d3 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2449,12 +2449,9 @@ class DatasetSpec(type_spec.BatchableTypeSpec): __slots__ = ["_element_spec", "_dataset_shape"] - def __init__(self, element_spec, dataset_shape=None): + def __init__(self, element_spec, dataset_shape=()): self._element_spec = element_spec - if dataset_shape: - self._dataset_shape = dataset_shape - else: - self._dataset_shape = tensor_shape.TensorShape([]) + self._dataset_shape = tensor_shape.as_shape(dataset_shape) @property def value_type(self): diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py index abd725a8413..8781a1933c5 100644 --- a/tensorflow/python/data/util/structure_test.py +++ b/tensorflow/python/data/util/structure_test.py @@ -22,6 +22,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import structure from tensorflow.python.framework import constant_op @@ -707,6 +708,16 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase, # pylint: enable=g-long-lambda + def testDatasetSpecConstructor(self): + rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32) + st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32) + t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string) + element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec} + ds_struct = dataset_ops.DatasetSpec(element_spec, [5]) + self.assertEqual(ds_struct._element_spec, element_spec) + # Note: shape was automatically converted from a list to a TensorShape. + self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5])) + if __name__ == "__main__": test.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt index 25bce513504..369aef45e9f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt @@ -10,7 +10,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], " } member_method { name: "from_value" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt index b4801277012..474c725a696 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt @@ -10,7 +10,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], " } member_method { name: "from_value" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt index 25bce513504..369aef45e9f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt @@ -10,7 +10,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], " } member_method { name: "from_value"