[tf.data] Adding deprecation warnings to experimental APIs for which a non-experimental alternative exists, and removing the tf.data.experimental.filter_for_shard API altogether in favor of the tf.data.Dataset.shard API.

PiperOrigin-RevId: 238030932
Jiri Simsa 2019-03-12 09:27:00 -07:00 committed by TensorFlower Gardener
parent a4c589d5c7
commit f1d30ce1be
11 changed files with 19 additions and 134 deletions
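For reference, a minimal sketch of the migration this change implies, using the non-experimental sharding API; the file name and worker values below are illustrative placeholders, not part of the diff:

```python
import tensorflow as tf

# Illustrative values; in practice these come from the training setup.
num_workers, worker_index = 4, 0

d = tf.data.TFRecordDataset("input.tfrecord")

# Before (removed in this commit):
# d = d.apply(tf.data.experimental.filter_for_shard(num_workers, worker_index))

# After: the core sharding API selects every num_workers-th element,
# starting at worker_index.
d = d.shard(num_workers, worker_index)
```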


@@ -48,7 +48,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@copy_to_device
@@dense_to_sparse_batch
@@enumerate_dataset
@@filter_for_shard
@@get_next_as_optional
@@get_single_element
@@group_by_reducer
@@ -92,7 +91,6 @@ from tensorflow.python.data.experimental.ops.cardinality import UNKNOWN as UNKNO
from tensorflow.python.data.experimental.ops.counter import Counter
from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset
from tensorflow.python.data.experimental.ops.error_ops import ignore_errors
from tensorflow.python.data.experimental.ops.filter_for_shard_ops import filter_for_shard
from tensorflow.python.data.experimental.ops.get_single_element import get_single_element
from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_length
from tensorflow.python.data.experimental.ops.grouping import group_by_reducer


@@ -162,18 +162,6 @@ py_library(
],
)
py_library(
name = "filter_for_shard_ops",
srcs = ["filter_for_shard_ops.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
"//tensorflow/python:ops",
"//tensorflow/python:tensor_util",
],
)
py_library(
name = "error_ops",
srcs = ["error_ops.py"],
@@ -466,7 +454,6 @@ py_library(
":distribute",
":enumerate_ops",
":error_ops",
":filter_for_shard_ops",
":get_single_element",
":grouping",
":indexed_dataset_ops",


@@ -665,6 +665,11 @@ def map_and_batch_with_legacy_function(map_func,
return _apply_fn
@deprecation.deprecated(
None,
"Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by "
"`tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data "
"optimizations will take care of using the fused implementation.")
@tf_export("data.experimental.map_and_batch")
def map_and_batch(map_func,
batch_size,

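A hedged sketch of the replacement named in this deprecation message; the map function, batch size, and parallelism values are placeholders:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(1000)

# Deprecated fused form:
# ds = ds.apply(tf.data.experimental.map_and_batch(
#     lambda x: x * 2, batch_size=32, num_parallel_calls=4, drop_remainder=True))

# Replacement: map followed by batch; static tf.data optimizations
# take care of fusing the two when beneficial.
ds = ds.map(lambda x: x * 2, num_parallel_calls=4)
ds = ds.batch(32, drop_remainder=True)
```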

@@ -1,106 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Naive shard dataset transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.filter_for_shard")
def filter_for_shard(num_shards, shard_index):
"""Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
This dataset operator is very useful when running distributed training, as
it allows each worker to read a unique subset.
When reading a single input file, you can skip elements as follows:
```python
d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.apply(tf.data.experimental.filter_for_shard(FLAGS.num_workers,
                                                  FLAGS.worker_index))
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
```
Important caveats:
- Be sure to shard before you use any randomizing operator (such as
shuffle).
- Generally it is best if the shard operator is used early in the dataset
pipeline. For example, when reading from a set of TFRecord files, shard
before converting the dataset to input samples. This avoids reading every
file on every worker. The following is an example of an efficient
sharding strategy within a complete pipeline:
```python
d = Dataset.list_files(FLAGS.pattern)
d = d.apply(tf.data.experimental.filter_for_shard(FLAGS.num_workers,
                                                  FLAGS.worker_index))
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
cycle_length=FLAGS.num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
```
Args:
num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
shards operating in parallel.
shard_index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
Returns:
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
Raises:
ValueError: if `num_shards` or `shard_index` are illegal values. Note: error
checking is done on a best-effort basis, and errors aren't guaranteed to
be caught upon dataset creation. (e.g. providing a placeholder tensor
bypasses the early checking, and will instead result in an error during
a session.run call.)
"""
num_shards = ops.convert_to_tensor(
num_shards, name="num_shards", dtype=dtypes.int64)
num_shards_static = tensor_util.constant_value(num_shards)
shard_index = ops.convert_to_tensor(shard_index, name="shard_index",
dtype=dtypes.int64)
shard_index_static = tensor_util.constant_value(shard_index)
if num_shards_static is not None and num_shards_static < 1:
raise ValueError("num_shards must be >= 1; got: %s" % num_shards_static)
if shard_index_static is not None and shard_index_static < 0:
raise ValueError("shard_index must be >= 0; got: %s" % shard_index_static)
if (shard_index_static is not None and num_shards_static is not None and
shard_index_static >= num_shards_static):
raise ValueError("shard_index must be < num_shards; %s is not < %s" %
(shard_index_static, num_shards_static))
def filter_fn(elem_index, _):
mod_result = math_ops.mod(elem_index, num_shards)
return math_ops.equal(mod_result, shard_index)
def _apply_fn(dataset):
# pylint: disable=protected-access
return dataset._enumerate().filter(filter_fn).map(lambda _, elem: elem)
return _apply_fn


@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -82,6 +83,11 @@ class _ParallelInterleaveDataset(dataset_ops.UnaryDataset):
return "tf.data.experimental.parallel_interleave()"
@deprecation.deprecated(
None,
"Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
"num_parallel_calls=tf.data.experimental.AUTOTUNE)` instead. If sloppy "
"execution is desired, use `tf.data.Options.experimental_determinstic`.")
@tf_export("data.experimental.parallel_interleave")
def parallel_interleave(map_func,
cycle_length,

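A hedged sketch of the replacement named in this deprecation message; the file names and cycle length are placeholders, and the determinism setting assumes a TF version that exposes `tf.data.Options.experimental_deterministic`:

```python
import tensorflow as tf

# Placeholder shard files; only read when the dataset is iterated.
filenames = tf.data.Dataset.from_tensor_slices(
    ["shard-0.tfrecord", "shard-1.tfrecord"])

# Deprecated:
# ds = filenames.apply(tf.data.experimental.parallel_interleave(
#     tf.data.TFRecordDataset, cycle_length=2, sloppy=True))

# Replacement:
ds = filenames.interleave(
    tf.data.TFRecordDataset,
    cycle_length=2,
    block_length=1,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Analogue of sloppy=True: allow non-deterministic element order.
options = tf.data.Options()
options.experimental_deterministic = False
ds = ds.with_options(options)
```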

@@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -50,6 +51,11 @@ class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset):
variant_tensor)
@deprecation.deprecated(
None,
"Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by "
"`tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take "
"care of using the fused implementation.")
@tf_export("data.experimental.shuffle_and_repeat")
def shuffle_and_repeat(buffer_size, count=None, seed=None):
"""Shuffles and repeats a Dataset returning a new permutation for each epoch.

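A hedged sketch of the replacement named in this deprecation message; buffer size, repeat count, and seed are placeholder values:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(100)

# Deprecated fused form:
# ds = ds.apply(tf.data.experimental.shuffle_and_repeat(
#     buffer_size=100, count=10, seed=42))

# Replacement: shuffle followed by repeat; static tf.data optimizations
# take care of fusing the two when beneficial.
ds = ds.shuffle(buffer_size=100, seed=42)
ds = ds.repeat(10)
```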

@@ -26,7 +26,6 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:filter_for_shard_ops",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/experimental/ops:stats_options",
"//tensorflow/python/data/experimental/ops:threading_options",


@@ -368,7 +368,6 @@ py_library(
srcs = ["input_ops.py"],
deps = [
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/experimental/ops:filter_for_shard_ops",
"//tensorflow/python/data/util:nest",
],
)


@@ -40,10 +40,9 @@ def auto_shard_dataset(dataset, num_shards, index):
dataset: A `tf.data.Dataset` instance, typically the result of a bunch of
dataset transformations.
num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
shards operating in parallel. Same usage as in
`tf.data.experimental.filter_for_shard`.
shards operating in parallel. Same usage as in `tf.data.Dataset.shard`.
index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
Same usage as in `Dataset.shard`.
Same usage as in `tf.data.Dataset.shard`.
Returns:
A modified `Dataset` obtained by updating the pipeline sharded by the


@@ -112,10 +112,6 @@ tf_module {
name: "enumerate_dataset"
argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
}
member_method {
name: "filter_for_shard"
argspec: "args=[\'num_shards\', \'shard_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_next_as_optional"
argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"


@@ -112,10 +112,6 @@ tf_module {
name: "enumerate_dataset"
argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
}
member_method {
name: "filter_for_shard"
argspec: "args=[\'num_shards\', \'shard_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_next_as_optional"
argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"