[tf.data] Experimental API `tf.data.experimental.{from,to}_variant()`
for creating a `tf.data.Dataset`
from a variant tensor and for representing a `tf.data.Dataset`
as a variant tensor, respectively.
This CL also exposes an experimental API `tf.data.experimental.get_structure()` for obtaining the structure of a dataset element, which is required to create a `tf.data.Dataset` from a variant tensor (as the tensor itself does not contain the structure information).
PiperOrigin-RevId: 241834334
This commit is contained in:
parent
40ebe5778e
commit
92b4656a72
@ -49,8 +49,10 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
|
||||
@@copy_to_device
|
||||
@@dense_to_sparse_batch
|
||||
@@enumerate_dataset
|
||||
@@from_variant
|
||||
@@get_next_as_optional
|
||||
@@get_single_element
|
||||
@@get_structure
|
||||
@@group_by_reducer
|
||||
@@group_by_window
|
||||
@@ignore_errors
|
||||
@ -68,6 +70,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
|
||||
@@scan
|
||||
@@shuffle_and_repeat
|
||||
@@take_while
|
||||
@@to_variant
|
||||
@@unbatch
|
||||
@@unique
|
||||
|
||||
@ -125,6 +128,9 @@ from tensorflow.python.data.experimental.ops.threading_options import ThreadingO
|
||||
from tensorflow.python.data.experimental.ops.unique import unique
|
||||
from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
|
||||
from tensorflow.python.data.ops.dataset_ops import DatasetStructure
|
||||
from tensorflow.python.data.ops.dataset_ops import from_variant
|
||||
from tensorflow.python.data.ops.dataset_ops import get_structure
|
||||
from tensorflow.python.data.ops.dataset_ops import to_variant
|
||||
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
|
||||
from tensorflow.python.data.ops.optional_ops import Optional
|
||||
from tensorflow.python.data.ops.optional_ops import OptionalStructure
|
||||
|
@ -752,6 +752,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "variant_test",
|
||||
srcs = ["variant_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "wrap_unwrap_test",
|
||||
size = "small",
|
||||
|
@ -26,8 +26,6 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
_NUMPY_RANDOM_SEED = 42
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class SleepTest(test_base.DatasetTestBase):
|
||||
|
@ -0,0 +1,48 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.{from,to}_variant()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class VariantTest(test_base.DatasetTestBase):
|
||||
|
||||
def testRoundtripRange(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
variant = dataset_ops.to_variant(dataset)
|
||||
dataset = dataset_ops.from_variant(variant,
|
||||
dataset_ops.get_structure(dataset))
|
||||
self.assertDatasetProduces(dataset, range(10))
|
||||
self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
|
||||
|
||||
def testRoundtripMap(self):
|
||||
dataset = dataset_ops.Dataset.range(10).map(lambda x: x*x)
|
||||
variant = dataset_ops.to_variant(dataset)
|
||||
dataset = dataset_ops.from_variant(variant,
|
||||
dataset_ops.get_structure(dataset))
|
||||
self.assertDatasetProduces(dataset, [x * x for x in range(10)])
|
||||
self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Experimental API for controlling threading in `tf.data` pipelines."""
|
||||
"""Experimental API for manually injecting delays into `tf.data` pipelines."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
@ -1874,8 +1874,7 @@ def make_initializable_iterator(dataset, shared_name=None):
|
||||
return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access
|
||||
|
||||
|
||||
# TODO(b/110122868): Replace this method with a public API for reflecting on
|
||||
# dataset structure.
|
||||
@tf_export("data.experimental.get_structure")
|
||||
def get_structure(dataset_or_iterator):
|
||||
"""Returns the `tf.data.experimental.Structure` of a `Dataset` or `Iterator`.
|
||||
|
||||
@ -2172,6 +2171,34 @@ class _VariantDataset(DatasetV2):
|
||||
return self._structure
|
||||
|
||||
|
||||
@tf_export("data.experimental.from_variant")
|
||||
def from_variant(variant, structure):
|
||||
"""Constructs a dataset from the given variant and structure.
|
||||
|
||||
Args:
|
||||
variant: A scalar `tf.variant` tensor representing a dataset.
|
||||
structure: A `tf.data.experimental.Structure` object representing the
|
||||
structure of each element in the dataset.
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` instance.
|
||||
"""
|
||||
return _VariantDataset(variant, structure) # pylint: disable=protected-access
|
||||
|
||||
|
||||
@tf_export("data.experimental.to_variant")
|
||||
def to_variant(dataset):
|
||||
"""Returns a variant representing the given dataset.
|
||||
|
||||
Args:
|
||||
dataset: A `tf.data.Dataset`.
|
||||
|
||||
Returns:
|
||||
A scalar `tf.variant` tensor representing the given dataset.
|
||||
"""
|
||||
return dataset._variant_tensor # pylint: disable=protected-access
|
||||
|
||||
|
||||
@tf_export("data.experimental.DatasetStructure")
|
||||
class DatasetStructure(structure_lib.Structure):
|
||||
"""Represents a `Dataset` of structured values."""
|
||||
|
@ -116,6 +116,10 @@ tf_module {
|
||||
name: "enumerate_dataset"
|
||||
argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_variant"
|
||||
argspec: "args=[\'variant\', \'structure\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_next_as_optional"
|
||||
argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -124,6 +128,10 @@ tf_module {
|
||||
name: "get_single_element"
|
||||
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_structure"
|
||||
argspec: "args=[\'dataset_or_iterator\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group_by_reducer"
|
||||
argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -192,6 +200,10 @@ tf_module {
|
||||
name: "take_while"
|
||||
argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "to_variant"
|
||||
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "unbatch"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -116,6 +116,10 @@ tf_module {
|
||||
name: "enumerate_dataset"
|
||||
argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_variant"
|
||||
argspec: "args=[\'variant\', \'structure\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_next_as_optional"
|
||||
argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -124,6 +128,10 @@ tf_module {
|
||||
name: "get_single_element"
|
||||
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_structure"
|
||||
argspec: "args=[\'dataset_or_iterator\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group_by_reducer"
|
||||
argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -188,6 +196,10 @@ tf_module {
|
||||
name: "take_while"
|
||||
argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "to_variant"
|
||||
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "unbatch"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user