[tf.data]: Deprecate dataset.shard.

In preparation for a distributed-aware tf.data system and for TF 2.0, this CL
deprecates the legacy `.shard` method on `tf.data.Dataset` objects. The old
behavior is preserved under an experimental `filter_for_shard` dataset
transformation, as well as in the V1 API (albeit now marked with a deprecation
warning).

PiperOrigin-RevId: 223362539
Brennan Saeta 2018-11-29 09:53:22 -08:00, committed by TensorFlower Gardener
parent 6afcfdfb75
commit f6e8f7a1fb
16 changed files with 194 additions and 106 deletions
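
For pipelines migrating off the deprecated method, a minimal before/after sketch (assuming the API as it stands after this CL; the shard counts and indices are placeholders):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(16)

# Legacy method, now deprecated in the V1 API.
legacy = dataset.shard(4, 2)

# Replacement introduced by this CL: an experimental transformation
# applied via Dataset.apply.
preferred = dataset.apply(
    tf.data.experimental.filter_for_shard(num_shards=4, shard_index=2))

# Both pipelines yield the elements 2, 6, 10, 14.
```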

View File

@@ -39,6 +39,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@copy_to_device
@@dense_to_sparse_batch
@@enumerate_dataset
@@filter_for_shard
@@get_next_as_optional
@@get_single_element
@@group_by_reducer
@@ -74,6 +75,7 @@ from tensorflow.python.data.experimental.ops.batching import unbatch
from tensorflow.python.data.experimental.ops.counter import Counter
from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset
from tensorflow.python.data.experimental.ops.error_ops import ignore_errors
from tensorflow.python.data.experimental.ops.filter_for_shard_ops import filter_for_shard
from tensorflow.python.data.experimental.ops.get_single_element import get_single_element
from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_length
from tensorflow.python.data.experimental.ops.grouping import group_by_reducer

View File

@@ -139,6 +139,18 @@ py_library(
],
)
py_library(
name = "filter_for_shard_ops",
srcs = ["filter_for_shard_ops.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
"//tensorflow/python:ops",
"//tensorflow/python:tensor_util",
],
)
py_library(
name = "error_ops",
srcs = ["error_ops.py"],
@@ -403,6 +415,7 @@ py_library(
":counter",
":enumerate_ops",
":error_ops",
":filter_for_shard_ops",
":get_single_element",
":grouping",
":indexed_dataset_ops",

View File

@@ -0,0 +1,106 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Naive shard dataset transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.filter_for_shard")
def filter_for_shard(num_shards, shard_index):
"""Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
This dataset operator is very useful when running distributed training, as
it allows each worker to read a unique subset.
When reading a single input file, you can skip elements as follows:
```python
d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.apply(tf.data.experimental.filter_for_shard(FLAGS.num_workers,
                                                  FLAGS.worker_index))
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
```
Important caveats:
- Be sure to shard before you use any randomizing operator (such as
shuffle).
- Generally it is best if the shard operator is used early in the dataset
pipeline. For example, when reading from a set of TFRecord files, shard
before converting the dataset to input samples. This avoids reading every
file on every worker. The following is an example of an efficient
sharding strategy within a complete pipeline:
```python
d = Dataset.list_files(FLAGS.pattern)
d = d.apply(tf.data.experimental.filter_for_shard(FLAGS.num_workers,
                                                  FLAGS.worker_index))
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
cycle_length=FLAGS.num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
```
Args:
num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
shards operating in parallel.
shard_index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
Returns:
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
Raises:
ValueError: if `num_shards` or `shard_index` are illegal values. Note: error
checking is done on a best-effort basis, and errors aren't guaranteed to
be caught upon dataset creation. (e.g. passing a placeholder tensor
bypasses the early checks and instead results in an error during a
`session.run` call.)
"""
num_shards = ops.convert_to_tensor(
num_shards, name="num_shards", dtype=dtypes.int64)
num_shards_static = tensor_util.constant_value(num_shards)
shard_index = ops.convert_to_tensor(shard_index, name="shard_index",
dtype=dtypes.int64)
shard_index_static = tensor_util.constant_value(shard_index)
if num_shards_static is not None and num_shards_static < 1:
raise ValueError("num_shards must be >= 1; got: %s" % num_shards_static)
if shard_index_static is not None and shard_index_static < 0:
raise ValueError("shard_index must be >= 0; got: %s" % shard_index_static)
if (shard_index_static is not None and num_shards_static is not None and
shard_index_static >= num_shards_static):
raise ValueError("shard_index must be < num_shards; %s is not < %s" %
(shard_index_static, num_shards_static))
def filter_fn(elem_index, _):
mod_result = math_ops.mod(elem_index, num_shards)
return math_ops.equal(mod_result, shard_index)
def _apply_fn(dataset):
# pylint: disable=protected-access
return dataset._enumerate().filter(filter_fn).map(lambda _, elem: elem)
return _apply_fn
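
The transformation is a thin composition of enumerate, filter, and map. Because `_enumerate` is private, here is a roughly equivalent sketch built only from public ops of the same release (the shard constants are illustrative, not from the CL):

```python
import tensorflow as tf

num_shards, shard_index = 3, 1

dataset = tf.data.Dataset.range(10)
sharded = (dataset
           .apply(tf.data.experimental.enumerate_dataset())
           .filter(lambda i, elem: tf.equal(i % num_shards, shard_index))
           .map(lambda i, elem: elem))

# Keeps only the positions where position % num_shards == shard_index,
# i.e. the elements 1, 4, 7.
```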

View File

@@ -26,6 +26,7 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:filter_for_shard_ops",
"//tensorflow/python/data/experimental/ops:stats_options",
"//tensorflow/python/data/experimental/ops:threading_options",
"//tensorflow/python/data/util:nest",

View File

@@ -26,6 +26,7 @@ import numpy as np
import six
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import filter_for_shard_ops
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.ops import iterator_ops
@@ -821,78 +822,6 @@ class DatasetV2(object):
"""
return SkipDataset(self, count)
def shard(self, num_shards, index):
"""Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
This dataset operator is very useful when running distributed training, as
it allows each worker to read a unique subset.
When reading a single input file, you can skip elements as follows:
```python
d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
```
Important caveats:
- Be sure to shard before you use any randomizing operator (such as
shuffle).
- Generally it is best if the shard operator is used early in the dataset
pipeline. For example, when reading from a set of TFRecord files, shard
before converting the dataset to input samples. This avoids reading every
file on every worker. The following is an example of an efficient
sharding strategy within a complete pipeline:
```python
d = Dataset.list_files(FLAGS.pattern)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
cycle_length=FLAGS.num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
```
Args:
num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
shards operating in parallel.
index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
Returns:
Dataset: A `Dataset`.
Raises:
ValueError: if `num_shards` or `index` are illegal values. Note: error
checking is done on a best-effort basis, and errors aren't guaranteed
to be caught upon dataset creation. (e.g. providing in a placeholder
tensor bypasses the early checking, and will instead result in an error
during a session.run call.)
"""
num_shards = ops.convert_to_tensor(
num_shards, name="num_shards", dtype=dtypes.int64)
num_shards_static = tensor_util.constant_value(num_shards)
index = ops.convert_to_tensor(index, name="index", dtype=dtypes.int64)
index_static = tensor_util.constant_value(index)
if num_shards_static is not None and num_shards_static < 1:
raise ValueError("num_shards must be >= 1; got: %s" % num_shards_static)
if index_static is not None and index_static < 0:
raise ValueError("index must be >= 0; got: %s" % index_static)
if (index_static is not None and num_shards_static is not None and
index_static >= num_shards_static):
raise ValueError("index must be <= num_shards; %s is not < %s" %
(index_static, num_shards_static))
def filter_fn(elem_index, _):
mod_result = math_ops.mod(elem_index, num_shards)
return math_ops.equal(mod_result, index)
return self._enumerate().filter(filter_fn).map(lambda _, elem: elem)
def batch(self, batch_size, drop_remainder=False):
"""Combines consecutive elements of this dataset into batches.
@@ -1486,9 +1415,60 @@ class DatasetV1(DatasetV2):
def skip(self, count):
return DatasetV1Adapter(super(DatasetV1, self).skip(count))
@functools.wraps(DatasetV2.shard)
@deprecation.deprecated(
None, "Use `dataset.apply(tf.data.experimental.filter_for_shard(...))`.")
def shard(self, num_shards, index):
return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index))
"""Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
This dataset operator is very useful when running distributed training, as
it allows each worker to read a unique subset.
When reading a single input file, you can skip elements as follows:
```python
d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
```
Important caveats:
- Be sure to shard before you use any randomizing operator (such as
shuffle).
- Generally it is best if the shard operator is used early in the dataset
pipeline. For example, when reading from a set of TFRecord files, shard
before converting the dataset to input samples. This avoids reading every
file on every worker. The following is an example of an efficient
sharding strategy within a complete pipeline:
```python
d = Dataset.list_files(FLAGS.pattern)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
cycle_length=FLAGS.num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
```
Args:
num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
shards operating in parallel.
index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
Returns:
Dataset: A `Dataset`.
Raises:
ValueError: if `num_shards` or `index` are illegal values. Note: error
checking is done on a best-effort basis, and errors aren't guaranteed
to be caught upon dataset creation. (e.g. passing a placeholder
tensor bypasses the early checks and instead results in an error
during a `session.run` call.)
"""
return self.apply(filter_for_shard_ops.filter_for_shard(num_shards, index))
@functools.wraps(DatasetV2.batch)
def batch(self, batch_size, drop_remainder=False):
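
Because the deprecated V1 method now forwards to `filter_for_shard`, its best-effort static checks fire when the call is constructed. A small sketch, assuming the behavior introduced by this CL:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
try:
  # filter_for_shard rejects shard_index >= num_shards when both values
  # are statically known, before any session is involved.
  dataset.shard(4, 7)
except ValueError as err:
  print(err)  # e.g. "shard_index must be < num_shards; 7 is not < 4"
```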

View File

@@ -245,6 +245,7 @@ py_library(
srcs = ["input_ops.py"],
deps = [
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/experimental/ops:filter_for_shard_ops",
"//tensorflow/python/data/util:nest",
],
)

View File

@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import filter_for_shard_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
@@ -41,7 +42,8 @@ def auto_shard_dataset(dataset, num_shards, index):
dataset: A `tf.data.Dataset` instance, typically the result of a bunch of
dataset transformations.
num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
shards operating in parallel. Same usage as in `Dataset.shard`.
shards operating in parallel. Same usage as in
`tf.data.experimental.filter_for_shard`.
index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
Same usage as in `Dataset.shard`.
@@ -74,9 +76,11 @@ def auto_shard_dataset(dataset, num_shards, index):
# constructor. Eventually we will change all cases to clone datasets
# instead of updating in-place.
return dataset._clone(
filenames=dataset._filenames.shard(num_shards, index))
filenames=dataset._filenames.apply(
filter_for_shard_ops.filter_for_shard(num_shards, index)))
elif isinstance(dataset, dataset_ops.RangeDataset):
return dataset.shard(num_shards, index)
return dataset.apply(
filter_for_shard_ops.filter_for_shard(num_shards, index))
elif hasattr(dataset, "_map_func"):
# TODO(priyag): Make this check more robust by enforcing some common
# property on all map/flatmap/interleave datasets.
@@ -142,6 +146,7 @@ def auto_shard_dataset(dataset, num_shards, index):
# TODO(priyag): This will shard the filenames before any shuffling of the
# filename dataset. It might be desirable to shard after shuffling
# filenames? If so, how do we achieve that?
return dataset.shard(num_shards, index)
return dataset.apply(
filter_for_shard_ops.filter_for_shard(num_shards, index))
return _auto_shard_impl(dataset=dataset, found_reader_op=False)
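
The rewritten branches shard at the filename level before readers are interleaved, which matches the efficient strategy described in the docstring caveats above. A standalone sketch of that pattern (the glob pattern and worker parameters are placeholders):

```python
import tensorflow as tf

filenames = tf.data.Dataset.list_files("/data/train-*.tfrecord", shuffle=False)
filenames = filenames.apply(
    tf.data.experimental.filter_for_shard(num_shards=4, shard_index=0))
records = filenames.interleave(
    tf.data.TFRecordDataset, cycle_length=4, block_length=1)

# Each worker opens only the files assigned to its shard instead of
# reading every file and discarding most of the records.
```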

View File

@@ -68,6 +68,10 @@ tf_module {
name: "enumerate_dataset"
argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
}
member_method {
name: "filter_for_shard"
argspec: "args=[\'num_shards\', \'shard_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_next_as_optional"
argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"

View File

@@ -97,10 +97,6 @@ tf_class {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@@ -100,10 +100,6 @@ tf_class {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@@ -99,10 +99,6 @@ tf_class {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@@ -100,10 +100,6 @@ tf_class {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@@ -100,10 +100,6 @@ tf_class {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@@ -100,10 +100,6 @@ tf_class {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@@ -100,10 +100,6 @@ tf_class {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@@ -68,6 +68,10 @@ tf_module {
name: "enumerate_dataset"
argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
}
member_method {
name: "filter_for_shard"
argspec: "args=[\'num_shards\', \'shard_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_next_as_optional"
argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"