STT-tensorflow/tensorflow/python/data/experimental/ops/distribute_options.py
Jiri Simsa ca9d421ddd [tf.data] This CL changes how the shuffle seed generator is managed, making it possible for the shuffle dataset to support both a) sharing of the seed generator across iterators and b) serialization. As a consequence, this CL enables reshuffling across iterations for tf.distribute and tf.data service use cases (which require both sharing of the seed generator across iterators and serialization support).
This CL in itself is a fairly large refactoring of the shuffle dataset implementation, unifying the implementation of the different op kernels for shuffle with fixed seeds, shuffle with pseudorandom seeds, and fused shuffle and repeat.

This CL also removes the `make_stateless` graph rewrite as it is no longer needed.

PiperOrigin-RevId: 308064029
Change-Id: I2f1d7916fe9958cf99d4e1b197da95c46b5d8b5f
2020-04-23 09:08:56 -07:00

86 lines
3.2 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for controlling distribution in `tf.data` pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.AutoShardPolicy")
class AutoShardPolicy(enum.IntEnum):
  """Represents the type of auto-sharding we enable.

  Please see the DistributeOptions.auto_shard_policy documentation for more
  information on each type of autosharding.

  Members:
    OFF: No auto-sharding; each worker receives a copy of the full dataset.
    AUTO: Attempt to shard by FILE first and fall back to sharding by DATA if
      no set of files to shard can be found.
    FILE: Shard by files, so each worker gets a set of files to process; an
      error is raised when there is not at least one file per worker.
    DATA: Shard by elements produced by the dataset; each worker processes
      the whole dataset and discards the portion that is not for itself.
  """
  OFF = -1
  AUTO = 0
  FILE = 1
  DATA = 2
class ExternalStatePolicy(enum.Enum):
  """Enumerates the WARN, IGNORE and FAIL policies.

  NOTE(review): presumably selects how tf.data reacts when a pipeline
  contains external state (warn about it, ignore it, or fail) -- confirm
  against the option that consumes this enum; it is not visible here.
  """
  WARN = 0
  IGNORE = 1
  FAIL = 2
@tf_export("data.experimental.DistributeOptions")
class DistributeOptions(options.OptionsBase):
  """Represents options for distributed data processing.

  You can set the distribution options of a dataset through the
  `experimental_distribute` property of `tf.data.Options`; the property is
  an instance of `tf.data.experimental.DistributeOptions`.

  ```python
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.OFF)
  dataset = dataset.with_options(options)
  ```
  """

  # Sharding strategy used by auto-shard; the default factory yields AUTO.
  auto_shard_policy = options.create_option(
      name="auto_shard_policy",
      ty=AutoShardPolicy,
      docstring="The type of sharding that auto-shard should attempt. If this "
      "is set to FILE, then we will attempt to shard by files (each worker "
      "will get a set of files to process). If we cannot find a set of files "
      "to shard for at least one file per worker, we will error out. When this "
      "option is selected, make sure that you have enough files so that each "
      "worker gets at least one file. There will be a runtime error thrown if "
      "there are insufficient files. "
      "If this is set to DATA, then we will shard by elements produced by the "
      "dataset, and each worker will process the whole dataset and discard the "
      "portion that is not for itself. "
      "If this is set to OFF, then we will not autoshard, and each worker will "
      "receive a copy of the full dataset. "
      "This option is set to AUTO by default, AUTO will attempt to first shard "
      "by FILE, and fall back to sharding by DATA if we cannot find a set of "
      "files to shard.",
      default_factory=lambda: AutoShardPolicy.AUTO)

  # Integer device count for this pipeline; no default factory is passed here.
  num_devices = options.create_option(
      name="num_devices",
      ty=int,
      docstring=
      "The number of devices attached to this input pipeline. This will be "
      "automatically set by MultiDeviceIterator.")