In DatasetSpec constructor, convert dataset_shape to a shape.
PiperOrigin-RevId: 258275543
This commit is contained in:
parent
280336728a
commit
814cce54c9
@ -2449,12 +2449,9 @@ class DatasetSpec(type_spec.BatchableTypeSpec):
|
|||||||
|
|
||||||
__slots__ = ["_element_spec", "_dataset_shape"]
|
__slots__ = ["_element_spec", "_dataset_shape"]
|
||||||
|
|
||||||
def __init__(self, element_spec, dataset_shape=None):
|
def __init__(self, element_spec, dataset_shape=()):
|
||||||
self._element_spec = element_spec
|
self._element_spec = element_spec
|
||||||
if dataset_shape:
|
self._dataset_shape = tensor_shape.as_shape(dataset_shape)
|
||||||
self._dataset_shape = dataset_shape
|
|
||||||
else:
|
|
||||||
self._dataset_shape = tensor_shape.TensorShape([])
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value_type(self):
|
def value_type(self):
|
||||||
|
@ -22,6 +22,7 @@ from absl.testing import parameterized
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.util import nest
|
from tensorflow.python.data.util import nest
|
||||||
from tensorflow.python.data.util import structure
|
from tensorflow.python.data.util import structure
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -707,6 +708,16 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
|||||||
|
|
||||||
# pylint: enable=g-long-lambda
|
# pylint: enable=g-long-lambda
|
||||||
|
|
||||||
|
def testDatasetSpecConstructor(self):
|
||||||
|
rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32)
|
||||||
|
st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)
|
||||||
|
t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string)
|
||||||
|
element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec}
|
||||||
|
ds_struct = dataset_ops.DatasetSpec(element_spec, [5])
|
||||||
|
self.assertEqual(ds_struct._element_spec, element_spec)
|
||||||
|
# Note: shape was automatically converted from a list to a TensorShape.
|
||||||
|
self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -10,7 +10,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "from_value"
|
name: "from_value"
|
||||||
|
@ -10,7 +10,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "from_value"
|
name: "from_value"
|
||||||
|
@ -10,7 +10,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'element_spec\', \'dataset_shape\'], varargs=None, keywords=None, defaults=[\'()\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "from_value"
|
name: "from_value"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user