Add a .element_spec property to DatasetSpec

Currently there is no way of recovering the inner spec.

This is an issue for example when using nested datasets when `ds.element_spec` will return the outer dataset spec `{'nested_ds': DatasetSpec({'img': TensorSpec(...)})})` but it's not possible to access the inner DatasetSpec.

PiperOrigin-RevId: 353636988
Change-Id: I4bcfb3ab31a0761834a2837075264f8117973861
This commit is contained in:
Etienne Pot 2021-01-25 06:51:18 -08:00 committed by TensorFlower Gardener
parent 24041b6681
commit 5a421f2ad0
6 changed files with 26 additions and 0 deletions

View File

@ -56,6 +56,9 @@
the dataset elements. This avoids the need for explicitly specifying the
`element_spec` argument of `tf.data.experimental.load` when loading the
previously saved dataset.
* Add `.element_spec` property to `tf.data.DatasetSpec` to access the
inner spec. This can be used to extract the structure of nested
datasets.
* XLA compilation:
* `tf.function(experimental_compile=True)` has become a stable API,
renamed `tf.function(jit_compile=True)`.

View File

@ -49,6 +49,12 @@ class DatasetSpecTest(test_base.DatasetTestBase, parameterized.TestCase):
fn(dataset)
@combinations.generate(test_base.default_test_combinations())
def testDatasetSpecInnerSpec(self):
inner_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32)
ds_spec = dataset_ops.DatasetSpec(inner_spec)
self.assertEqual(ds_spec.element_spec, inner_spec)
if __name__ == "__main__":
test.main()

View File

@ -3285,6 +3285,11 @@ class DatasetSpec(type_spec.BatchableTypeSpec):
def value_type(self):
return Dataset
@property
def element_spec(self):
"""The inner element spec."""
return self._element_spec
def _serialize(self):
return (self._element_spec, self._dataset_shape)

View File

@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "value_type"
mtype: "<type \'property\'>"

View File

@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "value_type"
mtype: "<type \'property\'>"

View File

@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "element_spec"
mtype: "<type \'property\'>"
}
member {
name: "value_type"
mtype: "<type \'property\'>"