Consolidate DistributeOptions.auto_shard into DistributeOptions.auto_shard_policy.
PiperOrigin-RevId: 275930249
Change-Id: Icccbf6530cd4477153c625ef22f4f1e560d86088
commit 5f7e805916 (parent 54343f1433)
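For context, a minimal before/after sketch of the user-facing migration (hedged: it assumes the public tf.data API at this revision, where `AutoShardPolicy` is exported under `tf.data.experimental` per the golden API files below):

```python
import tensorflow as tf

options = tf.data.Options()

# Before this change, autosharding was toggled with a boolean:
#   options.experimental_distribute.auto_shard = False

# After this change, the boolean is folded into the policy enum; OFF disables
# autosharding, so every worker receives the full dataset.
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)

dataset = tf.data.Dataset.range(10).with_options(options)
```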
@@ -386,7 +386,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
 
 Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
                      AutoShardPolicy policy, GraphDef* output) {
-  if (num_workers == 1 && index == 0) {
+  if (policy == AutoShardPolicy::OFF || (num_workers == 1 && index == 0)) {
     return Status::OK();
   }
 
@@ -407,6 +407,9 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
   // occurences from randomness from before that point in the graph (e.g. things
   // like ShuffleDataset) to ensure that `shard` returns a sensible result.
   switch (policy) {
+    case AutoShardPolicy::OFF:
+      return Status::OK();
+
     case AutoShardPolicy::FILE:
       TF_RETURN_IF_ERROR(RecursivelyHandleOp(*sink_node, num_workers, index,
                                              &flib, &graph, &nodes_to_delete));
@@ -458,7 +461,8 @@ Status AutoShard::Init(
   auto_shard_policy_ =
       AutoShardPolicy(config->parameter_map().at(kAutoShardPolicyAttrName).i());
 
-  if (auto_shard_policy_ != AutoShardPolicy::AUTO &&
+  if (auto_shard_policy_ != AutoShardPolicy::OFF &&
+      auto_shard_policy_ != AutoShardPolicy::AUTO &&
       auto_shard_policy_ != AutoShardPolicy::DATA &&
       auto_shard_policy_ != AutoShardPolicy::FILE) {
     return errors::InvalidArgument(kAutoShardPolicyAttrName, " is invalid.");
@@ -21,7 +21,7 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
-enum class AutoShardPolicy { AUTO = 0, FILE = 1, DATA = 2 };
+enum class AutoShardPolicy { OFF = -1, AUTO = 0, FILE = 1, DATA = 2 };
 
 // AutoShard takes a Dataset graph and tries to insert a shard node
 // automatically before a ReaderDataset (e.g. a CSVDataset or a TFRecordDataset)
@@ -253,6 +253,24 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProduces(dataset, expected)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testAutoshardPolicyOff(self):
+    options = dataset_ops.Options()
+    options.experimental_distribute.auto_shard_policy = (
+        distribute_options.AutoShardPolicy.OFF)
+
+    dataset = core_readers._TFRecordDataset(self.test_filenames)
+    dataset = dataset.with_options(options)
+    dataset = distribute._AutoShardDataset(dataset, 5, 0)
+
+    # Should return every record in every file since autosharding is turned off.
+    expected = [
+        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
+        for f in range(0, 10)
+        for r in range(0, 10)
+    ]
+    self.assertDatasetProduces(dataset, expected)
+
   @combinations.generate(test_base.default_test_combinations())
   def testFileShardingWithoutReaderDatasetOp(self):
     options = dataset_ops.Options()
@@ -30,6 +30,7 @@ class AutoShardPolicy(enum.IntEnum):
   Please see the DistributeOptions.auto_shard_policy documentation for more
   information on each type of autosharding.
   """
+  OFF = -1
   AUTO = 0
   FILE = 1
   DATA = 2
@@ -45,22 +46,11 @@ class DistributeOptions(options.OptionsBase):
 
   ```python
   options = tf.data.Options()
-  options.experimental_distribute.auto_shard = False
+  options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
   dataset = dataset.with_options(options)
   ```
   """
 
-  auto_shard = options.create_option(
-      name="auto_shard",
-      ty=bool,
-      docstring="Whether the dataset should be automatically sharded when "
-      "processed in a distributed fashion. This is applicable when using Keras "
-      "with multi-worker/TPU distribution strategy, and by "
-      "using strategy.experimental_distribute_dataset(). You can control the "
-      "behavior of the auto sharder via the `auto_shard_policy` option. In "
-      "other cases, this option does nothing. If None, defaults to True.",
-      default_factory=lambda: True)
-
   auto_shard_policy = options.create_option(
       name="auto_shard_policy",
       ty=AutoShardPolicy,
@@ -70,10 +60,12 @@ class DistributeOptions(options.OptionsBase):
       "to shard for at least one file per worker, we will error out. When this "
       "option is selected, make sure that you have enough files so that each "
       "worker gets at least one file. There will be a runtime error thrown if "
-      "there are insufficient files."
+      "there are insufficient files. "
       "If this is set to DATA, then we will shard by elements produced by the "
      "dataset, and each worker will process the whole dataset and discard the "
       "portion that is not for itself. "
+      "If this is set to OFF, then we will not autoshard, and each worker will "
+      "receive a copy of the full dataset. "
       "This option is set to AUTO by default, AUTO will attempt to first shard "
       "by FILE, and fall back to sharding by DATA if we cannot find a set of "
       "files to shard.",
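The docstring above spells out the policy semantics; as a quick sketch of the mapping (grounded in the enum values added by this change, using the internal module path the diff itself imports):

```python
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy

# OFF  (-1): no autosharding; each worker receives a copy of the full dataset.
# AUTO ( 0): default; try FILE-based sharding first, fall back to DATA.
# FILE ( 1): shard by input files; errors if any worker would get no file.
# DATA ( 2): each worker reads the whole dataset and keeps only its elements.
assert AutoShardPolicy.OFF == -1 and AutoShardPolicy.AUTO == 0
assert AutoShardPolicy.FILE == 1 and AutoShardPolicy.DATA == 2
```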
@@ -25,6 +25,7 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import tf2
+from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import combinations
@@ -633,11 +634,11 @@ class DistributedIteratorMultiWorkerTest(
           input_type=["dataset"],
           api_type=["wrap_into_iterator", "wrap_into_dataset"],
           iteration_type=["get_next", "for_loop"],
-          autoshard=[True, False]))
+          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
   def testAutoshardingOption(self, input_type, api_type, iteration_type,
-                             autoshard):
+                             auto_shard_policy):
     ds_option = dataset_ops.Options()
-    ds_option.experimental_distribute.auto_shard = autoshard
+    ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy
     if tf2.enabled():
       dataset_fn = (
           lambda _: dataset_ops.DatasetV2.range(4).with_options(ds_option))
@@ -653,7 +654,7 @@ class DistributedIteratorMultiWorkerTest(
             ["/job:worker/task:0", "/job:worker/task:1"], 1))
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      if autoshard:
+      if auto_shard_policy == AutoShardPolicy.AUTO:
         expected_values = [[0, 1], [2, 3]]
       else:
         expected_values = [[0, 0], [1, 1], [2, 2], [3, 3]]
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.data.experimental.ops import distribute
+from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import traverse
 from tensorflow.python.framework import op_def_registry
@@ -42,7 +43,8 @@ def auto_shard_dataset(dataset, num_shards, index):
     files. The input dataset will be returned if we cannot automatically
     determine a good way to shard the input dataset.
   """
-  if dataset.options().experimental_distribute.auto_shard:
+  if (dataset.options().experimental_distribute.auto_shard_policy !=
+      AutoShardPolicy.OFF):
     if isinstance(dataset, dataset_ops.DatasetV1):
       return distribute._AutoShardDatasetV1(dataset, num_shards, index)
     else:
@@ -31,6 +31,7 @@ from six.moves import zip  # pylint: disable=redefined-builtin
 
 from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import cardinality
+from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.ops import readers
@@ -1647,7 +1648,8 @@ def infer_steps_for_dataset(model,
   """
   assert isinstance(dataset, dataset_ops.DatasetV2)
   if (model._in_multi_worker_mode() and
-      dataset.options().experimental_distribute.auto_shard):
+      (dataset.options().experimental_distribute.auto_shard_policy !=
+       AutoShardPolicy.OFF)):
     # If the dataset would be auto-sharded, we should not infer a local
     # steps_per_epoch due to the possible inbalanced sharding between workers.
     return None
@@ -13,4 +13,8 @@ tf_class {
     name: "FILE"
     mtype: "<enum \'AutoShardPolicy\'>"
   }
+  member {
+    name: "OFF"
+    mtype: "<enum \'AutoShardPolicy\'>"
+  }
 }
|
@ -3,10 +3,6 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.distribute_options.DistributeOptions\'>"
|
is_instance: "<class \'tensorflow.python.data.experimental.ops.distribute_options.DistributeOptions\'>"
|
||||||
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
|
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
|
||||||
name: "auto_shard"
|
|
||||||
mtype: "<type \'property\'>"
|
|
||||||
}
|
|
||||||
member {
|
member {
|
||||||
name: "auto_shard_policy"
|
name: "auto_shard_policy"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@@ -13,4 +13,8 @@ tf_class {
     name: "FILE"
     mtype: "<enum \'AutoShardPolicy\'>"
   }
+  member {
+    name: "OFF"
+    mtype: "<enum \'AutoShardPolicy\'>"
+  }
 }
|
@ -3,10 +3,6 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.distribute_options.DistributeOptions\'>"
|
is_instance: "<class \'tensorflow.python.data.experimental.ops.distribute_options.DistributeOptions\'>"
|
||||||
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
|
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
|
||||||
name: "auto_shard"
|
|
||||||
mtype: "<type \'property\'>"
|
|
||||||
}
|
|
||||||
member {
|
member {
|
||||||
name: "auto_shard_policy"
|
name: "auto_shard_policy"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|