Add a .element_spec
property to DatasetSpec
Currently there is no way of recovering the inner spec. This is an issue for example when using nested datasets when `ds.element_spec` will return the outer dataset spec `{'nested_ds': DatasetSpec({'img': TensorSpec(...)})})` but it's not possible to access the inner DatasetSpec. PiperOrigin-RevId: 353636988 Change-Id: I4bcfb3ab31a0761834a2837075264f8117973861
This commit is contained in:
parent
24041b6681
commit
5a421f2ad0
@ -56,6 +56,9 @@
|
|||||||
the dataset elements. This avoids the need for explicitly specifying the
|
the dataset elements. This avoids the need for explicitly specifying the
|
||||||
`element_spec` argument of `tf.data.experimental.load` when loading the
|
`element_spec` argument of `tf.data.experimental.load` when loading the
|
||||||
previously saved dataset.
|
previously saved dataset.
|
||||||
|
* Add `.element_spec` property to `tf.data.DatasetSpec` to access the
|
||||||
|
inner spec. This can be used to extract the structure of nested
|
||||||
|
datasets.
|
||||||
* XLA compilation:
|
* XLA compilation:
|
||||||
* `tf.function(experimental_compile=True)` has become a stable API,
|
* `tf.function(experimental_compile=True)` has become a stable API,
|
||||||
renamed `tf.function(jit_compile=True)`.
|
renamed `tf.function(jit_compile=True)`.
|
||||||
|
@ -49,6 +49,12 @@ class DatasetSpecTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
fn(dataset)
|
fn(dataset)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testDatasetSpecInnerSpec(self):
|
||||||
|
inner_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32)
|
||||||
|
ds_spec = dataset_ops.DatasetSpec(inner_spec)
|
||||||
|
self.assertEqual(ds_spec.element_spec, inner_spec)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -3285,6 +3285,11 @@ class DatasetSpec(type_spec.BatchableTypeSpec):
|
|||||||
def value_type(self):
|
def value_type(self):
|
||||||
return Dataset
|
return Dataset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def element_spec(self):
|
||||||
|
"""The inner element spec."""
|
||||||
|
return self._element_spec
|
||||||
|
|
||||||
def _serialize(self):
|
def _serialize(self):
|
||||||
return (self._element_spec, self._dataset_shape)
|
return (self._element_spec, self._dataset_shape)
|
||||||
|
|
||||||
|
@ -4,6 +4,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "value_type"
|
name: "value_type"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -4,6 +4,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "value_type"
|
name: "value_type"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -4,6 +4,10 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "element_spec"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "value_type"
|
name: "value_type"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
Loading…
Reference in New Issue
Block a user