Add a .element_spec
property to DatasetSpec
Currently there is no way of recovering the inner spec. This is an issue for example when using nested datasets when `ds.element_spec` will return the outer dataset spec `{'nested_ds': DatasetSpec({'img': TensorSpec(...)})})` but it's not possible to access the inner DatasetSpec. PiperOrigin-RevId: 353636988 Change-Id: I4bcfb3ab31a0761834a2837075264f8117973861
This commit is contained in:
parent
24041b6681
commit
5a421f2ad0
@ -56,6 +56,9 @@
|
||||
the dataset elements. This avoids the need for explicitly specifying the
|
||||
`element_spec` argument of `tf.data.experimental.load` when loading the
|
||||
previously saved dataset.
|
||||
* Add `.element_spec` property to `tf.data.DatasetSpec` to access the
|
||||
inner spec. This can be used to extract the structure of nested
|
||||
datasets.
|
||||
* XLA compilation:
|
||||
* `tf.function(experimental_compile=True)` has become a stable API,
|
||||
renamed `tf.function(jit_compile=True)`.
|
||||
|
@ -49,6 +49,12 @@ class DatasetSpecTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
fn(dataset)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDatasetSpecInnerSpec(self):
|
||||
inner_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32)
|
||||
ds_spec = dataset_ops.DatasetSpec(inner_spec)
|
||||
self.assertEqual(ds_spec.element_spec, inner_spec)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -3285,6 +3285,11 @@ class DatasetSpec(type_spec.BatchableTypeSpec):
|
||||
def value_type(self):
|
||||
return Dataset
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
"""The inner element spec."""
|
||||
return self._element_spec
|
||||
|
||||
def _serialize(self):
|
||||
return (self._element_spec, self._dataset_shape)
|
||||
|
||||
|
@ -4,6 +4,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,6 +4,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,6 +4,10 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "element_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user