[tf.data] Experimental API tf.data.experimental.{from,to}_variant()
for creating a tf.data.Dataset
from the variant tensor and representing a tf.data.Dataset
as a variant tensor respectively.
This CL also exposes an experimental API `tf.data.experimental.get_structure()` for obtaining the structure of a dataset element, which is required to create a `tf.data.Dataset` from a variant tensor (as the tensor itself does not contain the structure information). PiperOrigin-RevId: 241834334
This commit is contained in:
parent
40ebe5778e
commit
92b4656a72
@ -49,8 +49,10 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
|
|||||||
@@copy_to_device
|
@@copy_to_device
|
||||||
@@dense_to_sparse_batch
|
@@dense_to_sparse_batch
|
||||||
@@enumerate_dataset
|
@@enumerate_dataset
|
||||||
|
@@from_variant
|
||||||
@@get_next_as_optional
|
@@get_next_as_optional
|
||||||
@@get_single_element
|
@@get_single_element
|
||||||
|
@@get_structure
|
||||||
@@group_by_reducer
|
@@group_by_reducer
|
||||||
@@group_by_window
|
@@group_by_window
|
||||||
@@ignore_errors
|
@@ignore_errors
|
||||||
@ -68,6 +70,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
|
|||||||
@@scan
|
@@scan
|
||||||
@@shuffle_and_repeat
|
@@shuffle_and_repeat
|
||||||
@@take_while
|
@@take_while
|
||||||
|
@@to_variant
|
||||||
@@unbatch
|
@@unbatch
|
||||||
@@unique
|
@@unique
|
||||||
|
|
||||||
@ -125,6 +128,9 @@ from tensorflow.python.data.experimental.ops.threading_options import ThreadingO
|
|||||||
from tensorflow.python.data.experimental.ops.unique import unique
|
from tensorflow.python.data.experimental.ops.unique import unique
|
||||||
from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
|
from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
|
||||||
from tensorflow.python.data.ops.dataset_ops import DatasetStructure
|
from tensorflow.python.data.ops.dataset_ops import DatasetStructure
|
||||||
|
from tensorflow.python.data.ops.dataset_ops import from_variant
|
||||||
|
from tensorflow.python.data.ops.dataset_ops import get_structure
|
||||||
|
from tensorflow.python.data.ops.dataset_ops import to_variant
|
||||||
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
|
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
|
||||||
from tensorflow.python.data.ops.optional_ops import Optional
|
from tensorflow.python.data.ops.optional_ops import Optional
|
||||||
from tensorflow.python.data.ops.optional_ops import OptionalStructure
|
from tensorflow.python.data.ops.optional_ops import OptionalStructure
|
||||||
|
@ -752,6 +752,18 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "variant_test",
|
||||||
|
srcs = ["variant_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
"//tensorflow/python/data/kernel_tests:test_base",
|
||||||
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "wrap_unwrap_test",
|
name = "wrap_unwrap_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -26,8 +26,6 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
_NUMPY_RANDOM_SEED = 42
|
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class SleepTest(test_base.DatasetTestBase):
|
class SleepTest(test_base.DatasetTestBase):
|
||||||
|
@ -0,0 +1,48 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for `tf.data.experimental.{from,to}_variant()`."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.data.experimental.ops import cardinality
|
||||||
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
class VariantTest(test_base.DatasetTestBase):
|
||||||
|
|
||||||
|
def testRoundtripRange(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(10)
|
||||||
|
variant = dataset_ops.to_variant(dataset)
|
||||||
|
dataset = dataset_ops.from_variant(variant,
|
||||||
|
dataset_ops.get_structure(dataset))
|
||||||
|
self.assertDatasetProduces(dataset, range(10))
|
||||||
|
self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
|
||||||
|
|
||||||
|
def testRoundtripMap(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(10).map(lambda x: x*x)
|
||||||
|
variant = dataset_ops.to_variant(dataset)
|
||||||
|
dataset = dataset_ops.from_variant(variant,
|
||||||
|
dataset_ops.get_structure(dataset))
|
||||||
|
self.assertDatasetProduces(dataset, [x * x for x in range(10)])
|
||||||
|
self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Experimental API for controlling threading in `tf.data` pipelines."""
|
"""Experimental API for manually injecting delays into `tf.data` pipelines."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
@ -1874,8 +1874,7 @@ def make_initializable_iterator(dataset, shared_name=None):
|
|||||||
return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access
|
return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/110122868): Replace this method with a public API for reflecting on
|
@tf_export("data.experimental.get_structure")
|
||||||
# dataset structure.
|
|
||||||
def get_structure(dataset_or_iterator):
|
def get_structure(dataset_or_iterator):
|
||||||
"""Returns the `tf.data.experimental.Structure` of a `Dataset` or `Iterator`.
|
"""Returns the `tf.data.experimental.Structure` of a `Dataset` or `Iterator`.
|
||||||
|
|
||||||
@ -2172,6 +2171,34 @@ class _VariantDataset(DatasetV2):
|
|||||||
return self._structure
|
return self._structure
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("data.experimental.from_variant")
|
||||||
|
def from_variant(variant, structure):
|
||||||
|
"""Constructs a dataset from the given variant and structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
variant: A scalar `tf.variant` tensor representing a dataset.
|
||||||
|
structure: A `tf.data.experimental.Structure` object representing the
|
||||||
|
structure of each element in the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `tf.data.Dataset` instance.
|
||||||
|
"""
|
||||||
|
return _VariantDataset(variant, structure) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("data.experimental.to_variant")
|
||||||
|
def to_variant(dataset):
|
||||||
|
"""Returns a variant representing the given dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: A `tf.data.Dataset`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A scalar `tf.variant` tensor representing the given dataset.
|
||||||
|
"""
|
||||||
|
return dataset._variant_tensor # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
@tf_export("data.experimental.DatasetStructure")
|
@tf_export("data.experimental.DatasetStructure")
|
||||||
class DatasetStructure(structure_lib.Structure):
|
class DatasetStructure(structure_lib.Structure):
|
||||||
"""Represents a `Dataset` of structured values."""
|
"""Represents a `Dataset` of structured values."""
|
||||||
|
@ -116,6 +116,10 @@ tf_module {
|
|||||||
name: "enumerate_dataset"
|
name: "enumerate_dataset"
|
||||||
argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
|
argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "from_variant"
|
||||||
|
argspec: "args=[\'variant\', \'structure\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "get_next_as_optional"
|
name: "get_next_as_optional"
|
||||||
argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
|
||||||
@ -124,6 +128,10 @@ tf_module {
|
|||||||
name: "get_single_element"
|
name: "get_single_element"
|
||||||
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_structure"
|
||||||
|
argspec: "args=[\'dataset_or_iterator\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group_by_reducer"
|
name: "group_by_reducer"
|
||||||
argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None"
|
||||||
@ -192,6 +200,10 @@ tf_module {
|
|||||||
name: "take_while"
|
name: "take_while"
|
||||||
argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "to_variant"
|
||||||
|
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "unbatch"
|
name: "unbatch"
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -116,6 +116,10 @@ tf_module {
|
|||||||
name: "enumerate_dataset"
|
name: "enumerate_dataset"
|
||||||
argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
|
argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "from_variant"
|
||||||
|
argspec: "args=[\'variant\', \'structure\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "get_next_as_optional"
|
name: "get_next_as_optional"
|
||||||
argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
|
||||||
@ -124,6 +128,10 @@ tf_module {
|
|||||||
name: "get_single_element"
|
name: "get_single_element"
|
||||||
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_structure"
|
||||||
|
argspec: "args=[\'dataset_or_iterator\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "group_by_reducer"
|
name: "group_by_reducer"
|
||||||
argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None"
|
||||||
@ -188,6 +196,10 @@ tf_module {
|
|||||||
name: "take_while"
|
name: "take_while"
|
||||||
argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'predicate\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "to_variant"
|
||||||
|
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "unbatch"
|
name: "unbatch"
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
Loading…
Reference in New Issue
Block a user