Consolidate DistributeOptions.auto_shard into DistributeOptions.auto_shard_policy.
PiperOrigin-RevId: 275930249 Change-Id: Icccbf6530cd4477153c625ef22f4f1e560d86088
This commit is contained in:
parent
54343f1433
commit
5f7e805916
tensorflow
core/grappler/optimizers/data
python
data/experimental
distribute
keras/engine
tools/api/golden
@ -386,7 +386,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
|
||||
|
||||
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
|
||||
AutoShardPolicy policy, GraphDef* output) {
|
||||
if (num_workers == 1 && index == 0) {
|
||||
if (policy == AutoShardPolicy::OFF || (num_workers == 1 && index == 0)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -407,6 +407,9 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
|
||||
// occurences from randomness from before that point in the graph (e.g. things
|
||||
// like ShuffleDataset) to ensure that `shard` returns a sensible result.
|
||||
switch (policy) {
|
||||
case AutoShardPolicy::OFF:
|
||||
return Status::OK();
|
||||
|
||||
case AutoShardPolicy::FILE:
|
||||
TF_RETURN_IF_ERROR(RecursivelyHandleOp(*sink_node, num_workers, index,
|
||||
&flib, &graph, &nodes_to_delete));
|
||||
@ -458,7 +461,8 @@ Status AutoShard::Init(
|
||||
auto_shard_policy_ =
|
||||
AutoShardPolicy(config->parameter_map().at(kAutoShardPolicyAttrName).i());
|
||||
|
||||
if (auto_shard_policy_ != AutoShardPolicy::AUTO &&
|
||||
if (auto_shard_policy_ != AutoShardPolicy::OFF &&
|
||||
auto_shard_policy_ != AutoShardPolicy::AUTO &&
|
||||
auto_shard_policy_ != AutoShardPolicy::DATA &&
|
||||
auto_shard_policy_ != AutoShardPolicy::FILE) {
|
||||
return errors::InvalidArgument(kAutoShardPolicyAttrName, " is invalid.");
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
enum class AutoShardPolicy { AUTO = 0, FILE = 1, DATA = 2 };
|
||||
enum class AutoShardPolicy { OFF = -1, AUTO = 0, FILE = 1, DATA = 2 };
|
||||
|
||||
// AutoShard takes a Dataset graph and tries to insert a shard node
|
||||
// automatically before a ReaderDataset (e.g. a CSVDataset or a TFRecordDataset)
|
||||
|
@ -253,6 +253,24 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||
]
|
||||
self.assertDatasetProduces(dataset, expected)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAutoshardPolicyOff(self):
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_distribute.auto_shard_policy = (
|
||||
distribute_options.AutoShardPolicy.OFF)
|
||||
|
||||
dataset = core_readers._TFRecordDataset(self.test_filenames)
|
||||
dataset = dataset.with_options(options)
|
||||
dataset = distribute._AutoShardDataset(dataset, 5, 0)
|
||||
|
||||
# Should return every record in every file since autosharding is turned off.
|
||||
expected = [
|
||||
b"Record %d of file %d" % (r, f) # pylint:disable=g-complex-comprehension
|
||||
for f in range(0, 10)
|
||||
for r in range(0, 10)
|
||||
]
|
||||
self.assertDatasetProduces(dataset, expected)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFileShardingWithoutReaderDatasetOp(self):
|
||||
options = dataset_ops.Options()
|
||||
|
@ -30,6 +30,7 @@ class AutoShardPolicy(enum.IntEnum):
|
||||
Please see the DistributeOptions.auto_shard_policy documentation for more
|
||||
information on each type of autosharding.
|
||||
"""
|
||||
OFF = -1
|
||||
AUTO = 0
|
||||
FILE = 1
|
||||
DATA = 2
|
||||
@ -45,22 +46,11 @@ class DistributeOptions(options.OptionsBase):
|
||||
|
||||
```python
|
||||
options = tf.data.Options()
|
||||
options.experimental_distribute.auto_shard = False
|
||||
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
|
||||
dataset = dataset.with_options(options)
|
||||
```
|
||||
"""
|
||||
|
||||
auto_shard = options.create_option(
|
||||
name="auto_shard",
|
||||
ty=bool,
|
||||
docstring="Whether the dataset should be automatically sharded when "
|
||||
"processed in a distributed fashion. This is applicable when using Keras "
|
||||
"with multi-worker/TPU distribution strategy, and by "
|
||||
"using strategy.experimental_distribute_dataset(). You can control the "
|
||||
"behavior of the auto sharder via the `auto_shard_policy` option. In "
|
||||
"other cases, this option does nothing. If None, defaults to True.",
|
||||
default_factory=lambda: True)
|
||||
|
||||
auto_shard_policy = options.create_option(
|
||||
name="auto_shard_policy",
|
||||
ty=AutoShardPolicy,
|
||||
@ -70,10 +60,12 @@ class DistributeOptions(options.OptionsBase):
|
||||
"to shard for at least one file per worker, we will error out. When this "
|
||||
"option is selected, make sure that you have enough files so that each "
|
||||
"worker gets at least one file. There will be a runtime error thrown if "
|
||||
"there are insufficient files."
|
||||
"there are insufficient files. "
|
||||
"If this is set to DATA, then we will shard by elements produced by the "
|
||||
"dataset, and each worker will process the whole dataset and discard the "
|
||||
"portion that is not for itself. "
|
||||
"If this is set to OFF, then we will not autoshard, and each worker will "
|
||||
"receive a copy of the full dataset. "
|
||||
"This option is set to AUTO by default, AUTO will attempt to first shard "
|
||||
"by FILE, and fall back to sharding by DATA if we cannot find a set of "
|
||||
"files to shard.",
|
||||
|
@ -25,6 +25,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import combinations
|
||||
@ -633,11 +634,11 @@ class DistributedIteratorMultiWorkerTest(
|
||||
input_type=["dataset"],
|
||||
api_type=["wrap_into_iterator", "wrap_into_dataset"],
|
||||
iteration_type=["get_next", "for_loop"],
|
||||
autoshard=[True, False]))
|
||||
auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
|
||||
def testAutoshardingOption(self, input_type, api_type, iteration_type,
|
||||
autoshard):
|
||||
auto_shard_policy):
|
||||
ds_option = dataset_ops.Options()
|
||||
ds_option.experimental_distribute.auto_shard = autoshard
|
||||
ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy
|
||||
if tf2.enabled():
|
||||
dataset_fn = (
|
||||
lambda _: dataset_ops.DatasetV2.range(4).with_options(ds_option))
|
||||
@ -653,7 +654,7 @@ class DistributedIteratorMultiWorkerTest(
|
||||
["/job:worker/task:0", "/job:worker/task:1"], 1))
|
||||
worker_devices = self._cpu_devices()
|
||||
with context.graph_mode(), self.cached_session() as sess:
|
||||
if autoshard:
|
||||
if auto_shard_policy == AutoShardPolicy.AUTO:
|
||||
expected_values = [[0, 1], [2, 3]]
|
||||
else:
|
||||
expected_values = [[0, 0], [1, 1], [2, 2], [3, 3]]
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.ops import distribute
|
||||
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import traverse
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
@ -42,7 +43,8 @@ def auto_shard_dataset(dataset, num_shards, index):
|
||||
files. The input dataset will be returned if we cannot automatically
|
||||
determine a good way to shard the input dataset.
|
||||
"""
|
||||
if dataset.options().experimental_distribute.auto_shard:
|
||||
if (dataset.options().experimental_distribute.auto_shard_policy !=
|
||||
AutoShardPolicy.OFF):
|
||||
if isinstance(dataset, dataset_ops.DatasetV1):
|
||||
return distribute._AutoShardDatasetV1(dataset, num_shards, index)
|
||||
else:
|
||||
|
@ -31,6 +31,7 @@ from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
@ -1647,7 +1648,8 @@ def infer_steps_for_dataset(model,
|
||||
"""
|
||||
assert isinstance(dataset, dataset_ops.DatasetV2)
|
||||
if (model._in_multi_worker_mode() and
|
||||
dataset.options().experimental_distribute.auto_shard):
|
||||
(dataset.options().experimental_distribute.auto_shard_policy !=
|
||||
AutoShardPolicy.OFF)):
|
||||
# If the dataset would be auto-sharded, we should not infer a local
|
||||
# steps_per_epoch due to the possible inbalanced sharding between workers.
|
||||
return None
|
||||
|
@ -13,4 +13,8 @@ tf_class {
|
||||
name: "FILE"
|
||||
mtype: "<enum \'AutoShardPolicy\'>"
|
||||
}
|
||||
member {
|
||||
name: "OFF"
|
||||
mtype: "<enum \'AutoShardPolicy\'>"
|
||||
}
|
||||
}
|
||||
|
@ -3,10 +3,6 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.distribute_options.DistributeOptions\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "auto_shard"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "auto_shard_policy"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -13,4 +13,8 @@ tf_class {
|
||||
name: "FILE"
|
||||
mtype: "<enum \'AutoShardPolicy\'>"
|
||||
}
|
||||
member {
|
||||
name: "OFF"
|
||||
mtype: "<enum \'AutoShardPolicy\'>"
|
||||
}
|
||||
}
|
||||
|
@ -3,10 +3,6 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.distribute_options.DistributeOptions\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "auto_shard"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "auto_shard_policy"
|
||||
mtype: "<type \'property\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user