Consolidate DistributeOptions.auto_shard into DistributeOptions.auto_shard_policy.

PiperOrigin-RevId: 275930249
Change-Id: Icccbf6530cd4477153c625ef22f4f1e560d86088
Frank Chen authored 2019-10-21 14:56:25 -07:00; committed by TensorFlower Gardener
parent 54343f1433
commit 5f7e805916
11 changed files with 49 additions and 30 deletions

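For users migrating, the change is mechanical: the boolean `auto_shard` option is folded into the `AutoShardPolicy` enum, with `auto_shard=False` becoming `AutoShardPolicy.OFF`. A minimal before/after sketch, grounded in the docstring example updated below (the internal import path is the one the diffs use; the options object is reachable from the public `tf.data.Options`):

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy

options = tf.data.Options()
# Before this change:
#   options.experimental_distribute.auto_shard = False
# After this change, the boolean is expressed as a policy:
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF

dataset = tf.data.Dataset.range(8).with_options(options)
```
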
View File

@@ -386,7 +386,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
 Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
                      AutoShardPolicy policy, GraphDef* output) {
-  if (num_workers == 1 && index == 0) {
+  if (policy == AutoShardPolicy::OFF || (num_workers == 1 && index == 0)) {
     return Status::OK();
   }
@@ -407,6 +407,9 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
   // occurrences from randomness from before that point in the graph (e.g. things
   // like ShuffleDataset) to ensure that `shard` returns a sensible result.
   switch (policy) {
+    case AutoShardPolicy::OFF:
+      return Status::OK();
     case AutoShardPolicy::FILE:
       TF_RETURN_IF_ERROR(RecursivelyHandleOp(*sink_node, num_workers, index,
                                              &flib, &graph, &nodes_to_delete));
@@ -458,7 +461,8 @@ Status AutoShard::Init(
   auto_shard_policy_ =
       AutoShardPolicy(config->parameter_map().at(kAutoShardPolicyAttrName).i());
-  if (auto_shard_policy_ != AutoShardPolicy::AUTO &&
+  if (auto_shard_policy_ != AutoShardPolicy::OFF &&
+      auto_shard_policy_ != AutoShardPolicy::AUTO &&
       auto_shard_policy_ != AutoShardPolicy::DATA &&
       auto_shard_policy_ != AutoShardPolicy::FILE) {
     return errors::InvalidArgument(kAutoShardPolicyAttrName, " is invalid.");

View File

@@ -21,7 +21,7 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
-enum class AutoShardPolicy { AUTO = 0, FILE = 1, DATA = 2 };
+enum class AutoShardPolicy { OFF = -1, AUTO = 0, FILE = 1, DATA = 2 };
 
 // AutoShard takes a Dataset graph and tries to insert a shard node
 // automatically before a ReaderDataset (e.g. a CSVDataset or a TFRecordDataset)

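The Python `AutoShardPolicy` IntEnum added later in this commit must stay value-for-value in sync with this C++ enum, because `AutoShard::Init` above reads the policy back as a raw integer from the rewriter config (`parameter_map().at(kAutoShardPolicyAttrName).i()`). A hypothetical sanity check of that correspondence, not part of the change:

```python
import enum

class AutoShardPolicy(enum.IntEnum):
  # Must mirror the C++ `enum class AutoShardPolicy` above, since the policy
  # crosses the Python/C++ boundary as a plain integer attribute.
  OFF = -1
  AUTO = 0
  FILE = 1
  DATA = 2

# Definition order matches the C++ declaration: OFF, AUTO, FILE, DATA.
assert [p.value for p in AutoShardPolicy] == [-1, 0, 1, 2]
```
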
View File

@@ -253,6 +253,24 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProduces(dataset, expected)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testAutoshardPolicyOff(self):
+    options = dataset_ops.Options()
+    options.experimental_distribute.auto_shard_policy = (
+        distribute_options.AutoShardPolicy.OFF)
+
+    dataset = core_readers._TFRecordDataset(self.test_filenames)
+    dataset = dataset.with_options(options)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    # Should return every record in every file since autosharding is turned off.
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in range(0, 10)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, expected)
+
   @combinations.generate(test_base.default_test_combinations())
   def testFileShardingWithoutReaderDatasetOp(self):
     options = dataset_ops.Options()

View File

@@ -30,6 +30,7 @@ class AutoShardPolicy(enum.IntEnum):
   Please see the DistributeOptions.auto_shard_policy documentation for more
   information on each type of autosharding.
   """
+  OFF = -1
   AUTO = 0
   FILE = 1
   DATA = 2
@@ -45,22 +46,11 @@ class DistributeOptions(options.OptionsBase):
   ```python
   options = tf.data.Options()
-  options.experimental_distribute.auto_shard = False
+  options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
   dataset = dataset.with_options(options)
   ```
   """
 
-  auto_shard = options.create_option(
-      name="auto_shard",
-      ty=bool,
-      docstring="Whether the dataset should be automatically sharded when "
-      "processed in a distributed fashion. This is applicable when using Keras "
-      "with multi-worker/TPU distribution strategy, and by "
-      "using strategy.experimental_distribute_dataset(). You can control the "
-      "behavior of the auto sharder via the `auto_shard_policy` option. In "
-      "other cases, this option does nothing. If None, defaults to True.",
-      default_factory=lambda: True)
-
   auto_shard_policy = options.create_option(
       name="auto_shard_policy",
       ty=AutoShardPolicy,
@@ -70,10 +60,12 @@ class DistributeOptions(options.OptionsBase):
       "to shard for at least one file per worker, we will error out. When this "
       "option is selected, make sure that you have enough files so that each "
       "worker gets at least one file. There will be a runtime error thrown if "
-      "there are insufficient files."
+      "there are insufficient files. "
+      "If this is set to DATA, then we will shard by elements produced by the "
+      "dataset, and each worker will process the whole dataset and discard the "
+      "portion that is not for itself. "
+      "If this is set to OFF, then we will not autoshard, and each worker will "
+      "receive a copy of the full dataset. "
       "This option is set to AUTO by default; AUTO will attempt to first shard "
       "by FILE, and fall back to sharding by DATA if we cannot find a set of "
       "files to shard.",

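The docstring above enumerates the four policies; a short sketch of what each setting means in practice, using the internal import path from this commit (the public re-export is assumed to follow the API goldens at the end of this change):

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy

options = tf.data.Options()
# FILE: each worker reads a disjoint subset of the input files; raises a
#       runtime error if there are fewer files than workers.
# DATA: every worker walks the full dataset and keeps only its elements.
# OFF:  no autosharding; every worker receives the full dataset.
# AUTO: (default) try FILE first, fall back to DATA if no file set is found.
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.DATA

dataset = tf.data.Dataset.range(100).with_options(options)
```
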
View File

@@ -25,6 +25,7 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import tf2
+from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import combinations
@@ -633,11 +634,11 @@ class DistributedIteratorMultiWorkerTest(
           input_type=["dataset"],
           api_type=["wrap_into_iterator", "wrap_into_dataset"],
           iteration_type=["get_next", "for_loop"],
-          autoshard=[True, False]))
+          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
   def testAutoshardingOption(self, input_type, api_type, iteration_type,
-                             autoshard):
+                             auto_shard_policy):
     ds_option = dataset_ops.Options()
-    ds_option.experimental_distribute.auto_shard = autoshard
+    ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy
     if tf2.enabled():
       dataset_fn = (
           lambda _: dataset_ops.DatasetV2.range(4).with_options(ds_option))
@@ -653,7 +654,7 @@ class DistributedIteratorMultiWorkerTest(
             ["/job:worker/task:0", "/job:worker/task:1"], 1))
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      if autoshard:
+      if auto_shard_policy == AutoShardPolicy.AUTO:
         expected_values = [[0, 1], [2, 3]]
       else:
         expected_values = [[0, 0], [1, 1], [2, 2], [3, 3]]

View File

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.data.experimental.ops import distribute
+from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import traverse
 from tensorflow.python.framework import op_def_registry
@@ -42,7 +43,8 @@ def auto_shard_dataset(dataset, num_shards, index):
     files. The input dataset will be returned if we cannot automatically
     determine a good way to shard the input dataset.
   """
-  if dataset.options().experimental_distribute.auto_shard:
+  if (dataset.options().experimental_distribute.auto_shard_policy !=
+      AutoShardPolicy.OFF):
     if isinstance(dataset, dataset_ops.DatasetV1):
       return distribute._AutoShardDatasetV1(dataset, num_shards, index)
     else:

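With the policy set to OFF, `auto_shard_dataset` now returns the input dataset untouched. A worker that still wants deterministic element-level sharding can do it explicitly with the public `Dataset.shard` API; a minimal sketch, where the worker count and index are hypothetical placeholders:

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy

NUM_WORKERS = 4   # hypothetical cluster size
WORKER_INDEX = 0  # hypothetical index of this worker

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF

# With autosharding OFF, shard manually: worker i keeps every
# NUM_WORKERS-th element, starting at offset i.
dataset = (tf.data.Dataset.range(16)
           .with_options(options)
           .shard(NUM_WORKERS, WORKER_INDEX))
```
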
View File

@@ -31,6 +31,7 @@ from six.moves import zip  # pylint: disable=redefined-builtin
 
 from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import cardinality
+from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.ops import readers
@@ -1647,7 +1648,8 @@ def infer_steps_for_dataset(model,
   """
   assert isinstance(dataset, dataset_ops.DatasetV2)
   if (model._in_multi_worker_mode() and
-      dataset.options().experimental_distribute.auto_shard):
+      (dataset.options().experimental_distribute.auto_shard_policy !=
+       AutoShardPolicy.OFF)):
     # If the dataset would be auto-sharded, we should not infer a local
     # steps_per_epoch due to the possible imbalanced sharding between workers.
     return None

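The comment's reasoning is easy to see with concrete numbers: once a dataset is sharded per worker, the global cardinality no longer matches any worker's local step count. A hypothetical illustration using plain `Dataset.shard`, not code from this change:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)  # 10 elements globally

# DATA-style sharding across 3 workers yields 4, 3, and 3 local elements,
# so a steps_per_epoch inferred from the unsharded dataset (10) would be
# wrong on every worker.
for worker in range(3):
  shard = dataset.shard(num_shards=3, index=worker)
  print(worker, sum(1 for _ in shard))  # 0 -> 4, 1 -> 3, 2 -> 3
```
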
View File

@@ -13,4 +13,8 @@ tf_class {
     name: "FILE"
     mtype: "<enum \'AutoShardPolicy\'>"
   }
+  member {
+    name: "OFF"
+    mtype: "<enum \'AutoShardPolicy\'>"
+  }
 }

View File

@@ -3,10 +3,6 @@ tf_class {
   is_instance: "<class \'tensorflow.python.data.experimental.ops.distribute_options.DistributeOptions\'>"
   is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
   is_instance: "<type \'object\'>"
-  member {
-    name: "auto_shard"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "auto_shard_policy"
     mtype: "<type \'property\'>"

View File

@@ -13,4 +13,8 @@ tf_class {
     name: "FILE"
     mtype: "<enum \'AutoShardPolicy\'>"
   }
+  member {
+    name: "OFF"
+    mtype: "<enum \'AutoShardPolicy\'>"
+  }
 }

View File

@@ -3,10 +3,6 @@ tf_class {
   is_instance: "<class \'tensorflow.python.data.experimental.ops.distribute_options.DistributeOptions\'>"
   is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
   is_instance: "<type \'object\'>"
-  member {
-    name: "auto_shard"
-    mtype: "<type \'property\'>"
-  }
   member {
     name: "auto_shard_policy"
     mtype: "<type \'property\'>"