In DatasetSpec constructor, convert dataset_shape to a shape.

PiperOrigin-RevId: 258275543
This commit is contained in:
Edward Loper 2019-07-15 18:00:32 -07:00 committed by TensorFlower Gardener
parent 280336728a
commit 814cce54c9
5 changed files with 16 additions and 8 deletions

View File

@ -2449,12 +2449,9 @@ class DatasetSpec(type_spec.BatchableTypeSpec):
__slots__ = ["_element_spec", "_dataset_shape"]
def __init__(self, element_spec, dataset_shape=None):
def __init__(self, element_spec, dataset_shape=()):
self._element_spec = element_spec
if dataset_shape:
self._dataset_shape = dataset_shape
else:
self._dataset_shape = tensor_shape.TensorShape([])
self._dataset_shape = tensor_shape.as_shape(dataset_shape)
@property
def value_type(self):

View File

@ -22,6 +22,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
@ -707,6 +708,16 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
# pylint: enable=g-long-lambda
def testDatasetSpecConstructor(self):
rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32)
st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)
t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string)
element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec}
ds_struct = dataset_ops.DatasetSpec(element_spec, [5])
self.assertEqual(ds_struct._element_spec, element_spec)
# Note: shape was automatically converted from a list to a TensorShape.
self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))
if __name__ == "__main__":
test.main()

View File

@ -10,7 +10,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
}
member_method {
name: "from_value"

View File

@ -10,7 +10,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
}
member_method {
name: "from_value"

View File

@ -10,7 +10,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
}
member_method {
name: "from_value"