From 814cce54c9e096891245d399e67024f854583cc8 Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Mon, 15 Jul 2019 18:00:32 -0700
Subject: [PATCH] In DatasetSpec constructor, convert dataset_shape to a shape.

PiperOrigin-RevId: 258275543
---
 tensorflow/python/data/ops/dataset_ops.py             |  7 ++-----
 tensorflow/python/data/util/structure_test.py         | 11 +++++++++++
 .../api/golden/v1/tensorflow.data.-dataset-spec.pbtxt |  2 +-
 ...sorflow.data.experimental.-dataset-structure.pbtxt |  2 +-
 .../api/golden/v2/tensorflow.data.-dataset-spec.pbtxt |  2 +-
 5 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index f82231cb856..f9c42df17d3 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -2449,12 +2449,9 @@ class DatasetSpec(type_spec.BatchableTypeSpec):
 
   __slots__ = ["_element_spec", "_dataset_shape"]
 
-  def __init__(self, element_spec, dataset_shape=None):
+  def __init__(self, element_spec, dataset_shape=()):
     self._element_spec = element_spec
-    if dataset_shape:
-      self._dataset_shape = dataset_shape
-    else:
-      self._dataset_shape = tensor_shape.TensorShape([])
+    self._dataset_shape = tensor_shape.as_shape(dataset_shape)
 
   @property
   def value_type(self):
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
index abd725a8413..8781a1933c5 100644
--- a/tensorflow/python/data/util/structure_test.py
+++ b/tensorflow/python/data/util/structure_test.py
@@ -22,6 +22,7 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import constant_op
@@ -707,6 +708,16 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
 
   # pylint: enable=g-long-lambda
 
+  def testDatasetSpecConstructor(self):
+    rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32)
+    st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)
+    t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string)
+    element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec}
+    ds_struct = dataset_ops.DatasetSpec(element_spec, [5])
+    self.assertEqual(ds_struct._element_spec, element_spec)
+    # Note: shape was automatically converted from a list to a TensorShape.
+    self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt
index 25bce513504..369aef45e9f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt
@@ -10,7 +10,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
   }
   member_method {
     name: "from_value"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt
index b4801277012..474c725a696 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt
@@ -10,7 +10,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
   }
   member_method {
     name: "from_value"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt
index 25bce513504..369aef45e9f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt
@@ -10,7 +10,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
   }
   member_method {
     name: "from_value"