From 92b4656a723b9e506f135a9dbe68d20800092d1c Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Wed, 3 Apr 2019 16:57:33 -0700 Subject: [PATCH] [tf.data] Experimental API `tf.data.experimental.{from,to}_variant()` for creating a `tf.data.Dataset` from the variant tensor and representing a `tf.data.Dataset` as a variant tensor respectively. This CL also exposes an experimental API `tf.data.experimental.get_structure()` for obtaining the structure of a dataset element, which is required to create a `tf.data.Dataset` from a variant tensor (as the tensor itself does not contain the structure information). PiperOrigin-RevId: 241834334 --- .../python/data/experimental/__init__.py | 6 +++ .../data/experimental/kernel_tests/BUILD | 12 +++++ .../experimental/kernel_tests/sleep_test.py | 2 - .../experimental/kernel_tests/variant_test.py | 48 +++++++++++++++++++ .../python/data/experimental/ops/sleep.py | 2 +- tensorflow/python/data/ops/dataset_ops.py | 31 +++++++++++- .../v1/tensorflow.data.experimental.pbtxt | 12 +++++ .../v2/tensorflow.data.experimental.pbtxt | 12 +++++ 8 files changed, 120 insertions(+), 5 deletions(-) create mode 100644 tensorflow/python/data/experimental/kernel_tests/variant_test.py diff --git a/tensorflow/python/data/experimental/__init__.py b/tensorflow/python/data/experimental/__init__.py index 98454f3ac7d..3013eff4d7b 100644 --- a/tensorflow/python/data/experimental/__init__.py +++ b/tensorflow/python/data/experimental/__init__.py @@ -49,8 +49,10 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@copy_to_device @@dense_to_sparse_batch @@enumerate_dataset +@@from_variant @@get_next_as_optional @@get_single_element +@@get_structure @@group_by_reducer @@group_by_window @@ignore_errors @@ -68,6 +70,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@scan @@shuffle_and_repeat @@take_while +@@to_variant @@unbatch @@unique @@ -125,6 +128,9 @@ from tensorflow.python.data.experimental.ops.threading_options import ThreadingO from tensorflow.python.data.experimental.ops.unique import unique from tensorflow.python.data.experimental.ops.writers import TFRecordWriter from tensorflow.python.data.ops.dataset_ops import DatasetStructure +from tensorflow.python.data.ops.dataset_ops import from_variant +from tensorflow.python.data.ops.dataset_ops import get_structure +from tensorflow.python.data.ops.dataset_ops import to_variant from tensorflow.python.data.ops.iterator_ops import get_next_as_optional from tensorflow.python.data.ops.optional_ops import Optional from tensorflow.python.data.ops.optional_ops import OptionalStructure diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index 3eea06d27fe..317b54adca8 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -752,6 +752,18 @@ py_test( ], ) +py_test( + name = "variant_test", + srcs = ["variant_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + cuda_py_test( name = "wrap_unwrap_test", size = "small", diff --git a/tensorflow/python/data/experimental/kernel_tests/sleep_test.py b/tensorflow/python/data/experimental/kernel_tests/sleep_test.py index 4733c2a8330..f438b7358c4 100644 --- a/tensorflow/python/data/experimental/kernel_tests/sleep_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/sleep_test.py @@ -26,8 +26,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.platform import test -_NUMPY_RANDOM_SEED = 42 - @test_util.run_all_in_graph_and_eager_modes class SleepTest(test_base.DatasetTestBase): diff --git a/tensorflow/python/data/experimental/kernel_tests/variant_test.py b/tensorflow/python/data/experimental/kernel_tests/variant_test.py new file mode 100644 index 00000000000..6a3a1424d12 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/variant_test.py @@ -0,0 +1,48 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `tf.data.experimental.{from,to}_variant()`.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.ops import cardinality +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +@test_util.run_all_in_graph_and_eager_modes +class VariantTest(test_base.DatasetTestBase): + + def testRoundtripRange(self): + dataset = dataset_ops.Dataset.range(10) + variant = dataset_ops.to_variant(dataset) + dataset = dataset_ops.from_variant(variant, + dataset_ops.get_structure(dataset)) + self.assertDatasetProduces(dataset, range(10)) + self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10) + + def testRoundtripMap(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x*x) + variant = dataset_ops.to_variant(dataset) + dataset = dataset_ops.from_variant(variant, + dataset_ops.get_structure(dataset)) + self.assertDatasetProduces(dataset, [x * x for x in range(10)]) + self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/ops/sleep.py b/tensorflow/python/data/experimental/ops/sleep.py index b66edc7a194..15913aabef6 100644 --- a/tensorflow/python/data/experimental/ops/sleep.py +++ b/tensorflow/python/data/experimental/ops/sleep.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Experimental API for controlling threading in `tf.data` pipelines.""" +"""Experimental API for manually injecting delays into `tf.data` pipelines.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 378756cf1f9..6593d4f544a 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1874,8 +1874,7 @@ def make_initializable_iterator(dataset, shared_name=None): return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access -# TODO(b/110122868): Replace this method with a public API for reflecting on -# dataset structure. +@tf_export("data.experimental.get_structure") def get_structure(dataset_or_iterator): """Returns the `tf.data.experimental.Structure` of a `Dataset` or `Iterator`. @@ -2172,6 +2171,34 @@ class _VariantDataset(DatasetV2): return self._structure +@tf_export("data.experimental.from_variant") +def from_variant(variant, structure): + """Constructs a dataset from the given variant and structure. + + Args: + variant: A scalar `tf.variant` tensor representing a dataset. + structure: A `tf.data.experimental.Structure` object representing the + structure of each element in the dataset. + + Returns: + A `tf.data.Dataset` instance. + """ + return _VariantDataset(variant, structure) # pylint: disable=protected-access + + +@tf_export("data.experimental.to_variant") +def to_variant(dataset): + """Returns a variant representing the given dataset. + + Args: + dataset: A `tf.data.Dataset`. + + Returns: + A scalar `tf.variant` tensor representing the given dataset. + """ + return dataset._variant_tensor # pylint: disable=protected-access + + @tf_export("data.experimental.DatasetStructure") class DatasetStructure(structure_lib.Structure): """Represents a `Dataset` of structured values.""" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt index d62c67cfb2a..6cec9ac90a2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt @@ -116,6 +116,10 @@ tf_module { name: "enumerate_dataset" argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], " } + member_method { + name: "from_variant" + argspec: "args=[\'variant\', \'structure\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "get_next_as_optional" argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None" @@ -124,6 +128,10 @@ tf_module { name: "get_single_element" argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_structure" + argspec: "args=[\'dataset_or_iterator\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group_by_reducer" argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None" @@ -192,6 +200,10 @@ tf_module { name: "take_while" argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "to_variant" + argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "unbatch" argspec: "args=[], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt index 1d7ac210305..801ef083886 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt @@ -116,6 +116,10 @@ tf_module { name: "enumerate_dataset" argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], " } + member_method { + name: "from_variant" + argspec: "args=[\'variant\', \'structure\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "get_next_as_optional" argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None" @@ -124,6 +128,10 @@ tf_module { name: "get_single_element" argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_structure" + argspec: "args=[\'dataset_or_iterator\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "group_by_reducer" argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None" @@ -188,6 +196,10 @@ tf_module { name: "take_while" argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "to_variant" + argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "unbatch" argspec: "args=[], varargs=None, keywords=None, defaults=None"