diff --git a/RELEASE.md b/RELEASE.md index f84f7614041..e49397d1dcf 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -56,6 +56,9 @@ the dataset elements. This avoids the need for explicitly specifying the `element_spec` argument of `tf.data.experimental.load` when loading the previously saved dataset. + * Add `.element_spec` property to `tf.data.DatasetSpec` to access the + inner spec. This can be used to extract the structure of nested + datasets. * XLA compilation: * `tf.function(experimental_compile=True)` has become a stable API, renamed `tf.function(jit_compile=True)`. diff --git a/tensorflow/python/data/kernel_tests/dataset_spec_test.py b/tensorflow/python/data/kernel_tests/dataset_spec_test.py index 781a972ea33..1053b0e4a4e 100644 --- a/tensorflow/python/data/kernel_tests/dataset_spec_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_spec_test.py @@ -49,6 +49,12 @@ class DatasetSpecTest(test_base.DatasetTestBase, parameterized.TestCase): fn(dataset) + @combinations.generate(test_base.default_test_combinations()) + def testDatasetSpecInnerSpec(self): + inner_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32) + ds_spec = dataset_ops.DatasetSpec(inner_spec) + self.assertEqual(ds_spec.element_spec, inner_spec) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index acd2af7bd0b..1c949254f48 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -3285,6 +3285,11 @@ class DatasetSpec(type_spec.BatchableTypeSpec): def value_type(self): return Dataset + @property + def element_spec(self): + """The inner element spec.""" + return self._element_spec + def _serialize(self): return (self._element_spec, self._dataset_shape) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt index 369aef45e9f..f56e1198f10 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset-spec.pbtxt @@ -4,6 +4,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "element_spec" + mtype: "" + } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt index 474c725a696..fc65345f061 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-dataset-structure.pbtxt @@ -4,6 +4,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "element_spec" + mtype: "" + } member { name: "value_type" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt index 369aef45e9f..f56e1198f10 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset-spec.pbtxt @@ -4,6 +4,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "element_spec" + mtype: "" + } member { name: "value_type" mtype: ""