[tf.data] Add deprecation notices to experimental APIs for which a non-experimental alternative exists, and remove the tf.data.experimental.filter_for_shard API altogether in favor of the tf.data.Dataset.shard API.
PiperOrigin-RevId: 238030932
parent a4c589d5c7
commit f1d30ce1be
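
For pipelines that used the removed API, a minimal migration sketch follows (the file path, shard count, and worker index are illustrative placeholders, not part of this change):

```python
import tensorflow as tf

NUM_WORKERS = 4   # hypothetical number of workers
WORKER_INDEX = 0  # hypothetical index of this worker

d = tf.data.TFRecordDataset("/path/to/data.tfrecord")  # placeholder path
# Before: d = d.apply(tf.data.experimental.filter_for_shard(NUM_WORKERS, WORKER_INDEX))
# After: use the core sharding API directly.
d = d.shard(NUM_WORKERS, WORKER_INDEX)
```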
@@ -48,7 +48,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
 @@copy_to_device
 @@dense_to_sparse_batch
 @@enumerate_dataset
-@@filter_for_shard
 @@get_next_as_optional
 @@get_single_element
 @@group_by_reducer
@@ -92,7 +91,6 @@ from tensorflow.python.data.experimental.ops.cardinality import UNKNOWN as UNKNO
 from tensorflow.python.data.experimental.ops.counter import Counter
 from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset
 from tensorflow.python.data.experimental.ops.error_ops import ignore_errors
-from tensorflow.python.data.experimental.ops.filter_for_shard_ops import filter_for_shard
 from tensorflow.python.data.experimental.ops.get_single_element import get_single_element
 from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_length
 from tensorflow.python.data.experimental.ops.grouping import group_by_reducer
@@ -162,18 +162,6 @@ py_library(
     ],
 )
 
-py_library(
-    name = "filter_for_shard_ops",
-    srcs = ["filter_for_shard_ops.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:ops",
-        "//tensorflow/python:tensor_util",
-    ],
-)
-
 py_library(
     name = "error_ops",
     srcs = ["error_ops.py"],
@@ -466,7 +454,6 @@ py_library(
         ":distribute",
         ":enumerate_ops",
         ":error_ops",
-        ":filter_for_shard_ops",
         ":get_single_element",
         ":grouping",
         ":indexed_dataset_ops",
@@ -665,6 +665,11 @@ def map_and_batch_with_legacy_function(map_func,
   return _apply_fn
 
 
+@deprecation.deprecated(
+    None,
+    "Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by "
+    "`tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data "
+    "optimizations will take care of using the fused implementation.")
 @tf_export("data.experimental.map_and_batch")
 def map_and_batch(map_func,
                   batch_size,
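
A sketch of the replacement named in the deprecation message above; the map function, batch size, and use of AUTOTUNE are illustrative, not part of this change:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

# Deprecated fused form:
#   dataset = dataset.apply(tf.data.experimental.map_and_batch(
#       lambda x: x * 2, batch_size=8, drop_remainder=True))

# Suggested form; the static map-and-batch optimization re-fuses these ops.
dataset = dataset.map(lambda x: x * 2,
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(8, drop_remainder=True)
```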
@@ -1,106 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Naive shard dataset transformation."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import math_ops
-from tensorflow.python.util.tf_export import tf_export
-
-
-@tf_export("data.experimental.filter_for_shard")
-def filter_for_shard(num_shards, shard_index):
-  """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
-
-  This dataset operator is very useful when running distributed training, as
-  it allows each worker to read a unique subset.
-
-  When reading a single input file, you can skip elements as follows:
-
-  ```python
-  d = tf.data.TFRecordDataset(FLAGS.input_file)
-  d = d.apply(tf.data.experimental.naive_shard(FLAGS.num_workers,
-                                               FLAGS.worker_index))
-  d = d.repeat(FLAGS.num_epochs)
-  d = d.shuffle(FLAGS.shuffle_buffer_size)
-  d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
-  ```
-
-  Important caveats:
-
-  - Be sure to shard before you use any randomizing operator (such as
-    shuffle).
-  - Generally it is best if the shard operator is used early in the dataset
-    pipeline. For example, when reading from a set of TFRecord files, shard
-    before converting the dataset to input samples. This avoids reading every
-    file on every worker. The following is an example of an efficient
-    sharding strategy within a complete pipeline:
-
-  ```python
-  d = Dataset.list_files(FLAGS.pattern)
-  d = d.apply(tf.data.experimental.naive_shard(FLAGS.num_workers,
-                                               FLAGS.worker_index))
-  d = d.repeat(FLAGS.num_epochs)
-  d = d.shuffle(FLAGS.shuffle_buffer_size)
-  d = d.interleave(tf.data.TFRecordDataset,
-                   cycle_length=FLAGS.num_readers, block_length=1)
-  d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
-  ```
-
-  Args:
-    num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
-      shards operating in parallel.
-    shard_index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
-
-  Returns:
-    A `Dataset` transformation function, which can be passed to
-    `tf.data.Dataset.apply`.
-
-  Raises:
-    ValueError: if `num_shards` or `shard_index` are illegal values. Note: error
-      checking is done on a best-effort basis, and errors aren't guaranteed to
-      be caught upon dataset creation. (e.g. providing in a placeholder tensor
-      bypasses the early checking, and will instead result in an error during
-      a session.run call.)
-  """
-  num_shards = ops.convert_to_tensor(
-      num_shards, name="num_shards", dtype=dtypes.int64)
-  num_shards_static = tensor_util.constant_value(num_shards)
-  shard_index = ops.convert_to_tensor(shard_index, name="shard_index",
-                                      dtype=dtypes.int64)
-  shard_index_static = tensor_util.constant_value(shard_index)
-
-  if num_shards_static is not None and num_shards_static < 1:
-    raise ValueError("num_shards must be >= 1; got: %s" % num_shards_static)
-  if shard_index_static is not None and shard_index_static < 0:
-    raise ValueError("shard_index must be >= 0; got: %s" % shard_index_static)
-  if (shard_index_static is not None and num_shards_static is not None and
-      shard_index_static >= num_shards_static):
-    raise ValueError("shard_index must be < num_shards; %s is not < %s" %
-                     (shard_index_static, num_shards_static))
-
-  def filter_fn(elem_index, _):
-    mod_result = math_ops.mod(elem_index, num_shards)
-    return math_ops.equal(mod_result, shard_index)
-
-  def _apply_fn(dataset):
-    # pylint: disable=protected-access
-    return dataset._enumerate().filter(filter_fn).map(lambda _, elem: elem)
-
-  return _apply_fn
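
For reference, the deleted transformation's logic can be sketched with public tf.data ops; `Dataset.enumerate` stands in for the private `_enumerate` used above, and the shard count and index are illustrative:

```python
import tensorflow as tf

NUM_SHARDS = 4   # illustrative shard count
SHARD_INDEX = 1  # illustrative worker index

dataset = tf.data.Dataset.range(10)

# What filter_for_shard did: keep every element whose index is congruent
# to SHARD_INDEX modulo NUM_SHARDS.
sharded = (dataset.enumerate()
           .filter(lambda i, _: tf.math.equal(i % NUM_SHARDS, SHARD_INDEX))
           .map(lambda _, elem: elem))

# The supported equivalent after this change:
sharded = dataset.shard(NUM_SHARDS, SHARD_INDEX)
```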
@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
 from tensorflow.python.ops import gen_stateless_random_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -82,6 +83,11 @@ class _ParallelInterleaveDataset(dataset_ops.UnaryDataset):
     return "tf.data.experimental.parallel_interleave()"
 
 
+@deprecation.deprecated(
+    None,
+    "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
+    "num_parallel_calls=tf.data.experimental.AUTOTUNE)` instead. If sloppy "
+    "execution is desired, use `tf.data.Options.experimental_deterministic`.")
 @tf_export("data.experimental.parallel_interleave")
 def parallel_interleave(map_func,
                         cycle_length,
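
A sketch of the migration suggested by the new deprecation message; the file pattern, cycle length, and options usage are illustrative:

```python
import tensorflow as tf

filenames = tf.data.Dataset.list_files("/path/to/*.tfrecord")  # placeholder pattern

# Deprecated form:
#   dataset = filenames.apply(tf.data.experimental.parallel_interleave(
#       tf.data.TFRecordDataset, cycle_length=4, sloppy=True))

# Suggested form, relaxing determinism via options (the analogue of sloppy=True).
dataset = filenames.interleave(
    tf.data.TFRecordDataset, cycle_length=4, block_length=1,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
options = tf.data.Options()
options.experimental_deterministic = False
dataset = dataset.with_options(options)
```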
@@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -50,6 +51,11 @@ class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset):
                                                    variant_tensor)
 
 
+@deprecation.deprecated(
+    None,
+    "Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by "
+    "`tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take "
+    "care of using the fused implementation.")
 @tf_export("data.experimental.shuffle_and_repeat")
 def shuffle_and_repeat(buffer_size, count=None, seed=None):
   """Shuffles and repeats a Dataset returning a new permutation for each epoch.
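
A sketch of the suggested replacement; buffer size, repeat count, and seed are illustrative:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

# Deprecated fused form:
#   dataset = dataset.apply(
#       tf.data.experimental.shuffle_and_repeat(buffer_size=100, count=10, seed=42))

# Suggested form; the static shuffle-and-repeat optimization re-fuses these ops.
dataset = dataset.shuffle(100, seed=42).repeat(10)
```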
@@ -26,7 +26,6 @@ py_library(
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:util",
-       "//tensorflow/python/data/experimental/ops:filter_for_shard_ops",
        "//tensorflow/python/data/experimental/ops:optimization_options",
        "//tensorflow/python/data/experimental/ops:stats_options",
        "//tensorflow/python/data/experimental/ops:threading_options",
@@ -368,7 +368,6 @@ py_library(
     srcs = ["input_ops.py"],
     deps = [
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python/data/experimental/ops:filter_for_shard_ops",
         "//tensorflow/python/data/util:nest",
     ],
 )
@@ -40,10 +40,9 @@ def auto_shard_dataset(dataset, num_shards, index):
     dataset: A `tf.data.Dataset` instance, typically the result of a bunch of
       dataset transformations.
     num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
-      shards operating in parallel. Same usage as in
-      `tf.data.experimental.filter_for_shard`.
+      shards operating in parallel. Same usage as in `tf.data.Dataset.shard`.
     index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
-      Same usage as in `Dataset.shard`.
+      Same usage as in `tf.data.Dataset.shard`.
 
   Returns:
     A modified `Dataset` obtained by updating the pipeline sharded by the
@@ -112,10 +112,6 @@ tf_module {
     name: "enumerate_dataset"
     argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
   }
-  member_method {
-    name: "filter_for_shard"
-    argspec: "args=[\'num_shards\', \'shard_index\'], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "get_next_as_optional"
     argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
@@ -112,10 +112,6 @@ tf_module {
     name: "enumerate_dataset"
     argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
   }
-  member_method {
-    name: "filter_for_shard"
-    argspec: "args=[\'num_shards\', \'shard_index\'], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "get_next_as_optional"
     argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"