In DatasetSpec constructor, convert dataset_shape to a shape.
PiperOrigin-RevId: 258275543
This commit is contained in:
parent
280336728a
commit
814cce54c9
@ -2449,12 +2449,9 @@ class DatasetSpec(type_spec.BatchableTypeSpec):
|
||||
|
||||
__slots__ = ["_element_spec", "_dataset_shape"]
|
||||
|
||||
def __init__(self, element_spec, dataset_shape=None):
|
||||
def __init__(self, element_spec, dataset_shape=()):
|
||||
self._element_spec = element_spec
|
||||
if dataset_shape:
|
||||
self._dataset_shape = dataset_shape
|
||||
else:
|
||||
self._dataset_shape = tensor_shape.TensorShape([])
|
||||
self._dataset_shape = tensor_shape.as_shape(dataset_shape)
|
||||
|
||||
@property
|
||||
def value_type(self):
|
||||
|
@ -22,6 +22,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -707,6 +708,16 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
|
||||
# pylint: enable=g-long-lambda
|
||||
|
||||
def testDatasetSpecConstructor(self):
|
||||
rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32)
|
||||
st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)
|
||||
t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string)
|
||||
element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec}
|
||||
ds_struct = dataset_ops.DatasetSpec(element_spec, [5])
|
||||
self.assertEqual(ds_struct._element_spec, element_spec)
|
||||
# Note: shape was automatically converted from a list to a TensorShape.
|
||||
self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -10,7 +10,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
|
@ -10,7 +10,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
|
@ -10,7 +10,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
|
Loading…
Reference in New Issue
Block a user