model.predict for MultiWorkerMirroredStrategy.

PiperOrigin-RevId: 335892290
Change-Id: I9202fcecae62ed78948c4c93d9f19d077c6928a7
Xinyi Wang 2020-10-07 10:01:01 -07:00 committed by TensorFlower Gardener
parent 811248e122
commit 37995ed712
5 changed files with 210 additions and 33 deletions
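What this commit enables, as a minimal end-user sketch (assuming a TF_CONFIG-configured multi-worker cluster and the public tf.distribute/tf.keras API of this era; the model and data here are illustrative, not part of the commit):

    import numpy as np
    import tensorflow as tf

    # Every worker runs the same script; TF_CONFIG in the environment is
    # assumed to describe the cluster (set by whatever launches the job).
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

    with strategy.scope():
      model = tf.keras.Sequential([
          tf.keras.layers.Dense(4, activation='relu', input_shape=(3,)),
          tf.keras.layers.Dense(1),
      ])
      model.compile(tf.keras.optimizers.SGD(0.001), loss='mse')

    dataset = tf.data.Dataset.from_tensor_slices(
        np.zeros((20, 3), dtype=np.float32)).batch(10)

    # Previously blocked by the @disable_multi_worker decorator; after this
    # commit the input is re-sharded with AutoShardPolicy.DATA and per-worker
    # outputs are gathered back into a single, correctly ordered array.
    predictions = model.predict(dataset)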

@@ -337,6 +337,7 @@ distribute_py_test(
srcs = ["distribute_strategy_test.py"],
full_precision = True,
main = "distribute_strategy_test.py",
python_version = "PY3",
shard_count = 10,
tags = [
"multi_and_single_gpu",

@@ -17,15 +17,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.experimental.ops import cardinality
+from tensorflow.python.data.experimental.ops import writers
+from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
from tensorflow.python.distribute import central_storage_strategy
+from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
@@ -36,6 +44,7 @@ from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import testing_utils
@@ -50,6 +59,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import loss_reduction
from tensorflow.python.ops.ragged import ragged_tensor
@@ -152,8 +162,7 @@ def batch_wrapper(dataset, batch_size, distribution, repeat=None):
dataset = dataset.repeat(repeat)
# TPUs currently require fully defined input shapes, drop_remainder ensures
# the input will have fully defined shapes.
-if isinstance(distribution,
-(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
+if _is_tpu_strategy(distribution):
return dataset.batch(batch_size, drop_remainder=True)
else:
return dataset.batch(batch_size)
@@ -237,11 +246,18 @@ strategies_minus_tpu = [
strategy_combinations.central_storage_strategy_with_gpu_and_cpu
]
+multi_worker_mirrored_strategies = [
+strategy_combinations.multi_worker_mirrored_2x1_cpu,
+strategy_combinations.multi_worker_mirrored_2x1_gpu,
+strategy_combinations.multi_worker_mirrored_2x2_gpu,
+]
tpu_strategies = [
strategy_combinations.tpu_strategy,
]
-all_strategies = strategies_minus_tpu + tpu_strategies
+all_strategies = (
+strategies_minus_tpu + tpu_strategies + multi_worker_mirrored_strategies)
def strategy_minus_tpu_combinations():
@@ -258,8 +274,14 @@ def tpu_strategy_combinations_graph_only():
return combinations.combine(distribution=tpu_strategies, mode=['graph'])
+def multi_worker_strategy_combinations_eager_only():
+return combinations.combine(
+distribution=multi_worker_mirrored_strategies, mode=['eager'])
def all_strategy_combinations():
-return strategy_minus_tpu_combinations() + tpu_strategy_combinations()
+return strategy_minus_tpu_combinations() + tpu_strategy_combinations(
+) + multi_worker_strategy_combinations_eager_only()
def all_strategy_minus_default_and_tpu_combinations():
@@ -275,7 +297,8 @@ def all_strategy_minus_default_and_tpu_combinations():
def all_strategy_combinations_minus_default():
return (all_strategy_minus_default_and_tpu_combinations() +
-tpu_strategy_combinations())
+tpu_strategy_combinations() +
+multi_worker_strategy_combinations_eager_only())
def strategy_and_optimizer_combinations():
@@ -318,7 +341,21 @@ def strategy_and_optimizer_combinations():
optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,
optimizer_combinations.rmsprop_optimizer_keras_v2_fn
])
-return non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph
+multi_worker_eager = combinations.combine(
+distribution=multi_worker_mirrored_strategies,
+mode=['eager'],
+optimizer=[
+optimizer_combinations.adadelta_optimizer_keras_v2_fn,
+optimizer_combinations.adagrad_optimizer_keras_v2_fn,
+optimizer_combinations.adam_optimizer_keras_v2_fn,
+optimizer_combinations.adamax_optimizer_keras_v2_fn,
+optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,
+optimizer_combinations.nadam_optimizer_keras_v2_fn,
+optimizer_combinations.rmsprop_optimizer_keras_v2_fn,
+optimizer_combinations.ftrl_optimizer_keras_v2_fn
+])
+return (non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph +
+multi_worker_eager)
class BatchCountingCB(keras.callbacks.Callback):
@@ -494,8 +531,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
if isinstance(distribution.extended,
parameter_server_strategy.ParameterServerStrategyExtended):
self.skipTest('b/152097775')
-if isinstance(distribution,
-(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
+if _is_tpu_strategy(distribution):
policy_name = 'mixed_bfloat16'
else:
policy_name = 'mixed_float16'
@@ -545,8 +581,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
parameter_server_strategy.ParameterServerStrategyExtended):
self.skipTest('b/152097775')
-if isinstance(distribution,
-(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
+if _is_tpu_strategy(distribution):
policy_name = 'mixed_bfloat16'
else:
policy_name = 'mixed_float16'
@@ -627,8 +662,9 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
@ds_combinations.generate(
combinations.combine(
-distribution=strategies_minus_tpu,
-mode=['graph', 'eager']))
+distribution=strategies_minus_tpu, mode=['graph', 'eager']) +
+combinations.combine(
+distribution=multi_worker_mirrored_strategies, mode=['eager']))
def test_numpy_with_sample_weights(self, distribution):
with self.cached_session(), distribution.scope():
model = get_sample_weights_model()
@@ -989,8 +1025,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
def test_fit_with_dictionary_in_the_dataset_b135161171(
self, distribution):
-if isinstance(distribution,
-(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
+if _is_tpu_strategy(distribution):
self.skipTest('b/142805125')
def custom_loss(predict, label, weight):
@@ -1069,19 +1104,56 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
self.assertAllClose(
predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4)
+@ds_combinations.generate(all_strategy_combinations())
+def test_predict_on_dataset_with_unknown_cardinality_without_steps(
+self, distribution, mode):
+# TODO(b/155867206): Investigate why this test occasionally segfaults on TPU
+# in eager mode.
+if mode == 'eager' and _is_tpu_strategy(distribution):
+self.skipTest('caused segfault with TPU in eager mode.')
+if mode == 'graph' and _is_tpu_strategy(distribution):
+self.skipTest('partial batch not supported with TPU in graph mode.')
+with self.cached_session():
+with distribution.scope():
+optimizer_fn = gradient_descent_keras.SGD
+optimizer = optimizer_fn(0.001)
+model = get_model()
+loss = 'mse'
+metrics = ['mae', keras.metrics.CategoricalAccuracy()]
+model.compile(optimizer, loss, metrics=metrics)
+inputs = np.zeros((20, 3), dtype=np.float32)
+# steps/steps_per_epoch are calculated when using numpy arrays as
+# input data.
+predict_with_numpy = model.predict(inputs, batch_size=10)
+predict_dataset = convert_numpy_to_dataset_with_unknown_cardinality(
+inputs)
+self.assertEqual(
+keras.backend.get_value(cardinality.cardinality(predict_dataset)),
+cardinality.UNKNOWN)
+predict_with_ds = model.predict(predict_dataset)
+self.assertAllClose(
+predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4)
@ds_combinations.generate(all_strategy_combinations())
def test_on_dataset_with_unknown_cardinality_without_steps(
self, distribution, mode):
# TODO(b/155867206): Investigate why this test occasionally segfaults on TPU
# in eager mode.
-if mode == 'eager' and isinstance(
-distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
+if mode == 'eager' and _is_tpu_strategy(distribution):
self.skipTest('caused segfault with TPU in eager mode.')
-if mode == 'graph' and isinstance(
-distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
+if mode == 'graph' and _is_tpu_strategy(distribution):
self.skipTest('partial batch not supported with TPU in graph mode.')
+if isinstance(distribution,
+collective_all_reduce_strategy.CollectiveAllReduceStrategy):
+self.skipTest('EOF error causes subsequent collective ops fail.')
with self.cached_session():
with distribution.scope():
optimizer_fn = gradient_descent_keras.SGD
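Both unknown-cardinality tests above depend on datasets whose size tf.data cannot determine statically. A short sketch of the mechanism using the public API (the test helper convert_numpy_to_dataset_with_unknown_cardinality presumably relies on the same filter() trick, since filtering discards the static element count):

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices(tf.zeros((20, 3))).batch(10)
    print(tf.data.experimental.cardinality(ds).numpy())  # 2

    # filter() may drop arbitrarily many elements, so the count becomes unknown.
    ds = ds.filter(lambda x: True)
    card = tf.data.experimental.cardinality(ds)
    print(card == tf.data.experimental.UNKNOWN_CARDINALITY)  # True
    # model.predict must then discover the end of the dataset at run time
    # instead of computing `steps` up front.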
@@ -1514,8 +1586,9 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
@ds_combinations.generate(
combinations.combine(
-distribution=strategies_minus_tpu,
-mode=['graph', 'eager']))
+distribution=strategies_minus_tpu, mode=['graph', 'eager']) +
+combinations.combine(
+distribution=multi_worker_mirrored_strategies, mode=['eager']))
def test_dataset_with_sample_weights(self, distribution):
with self.cached_session(), distribution.scope():
model = get_sample_weights_model()
@@ -1551,6 +1624,57 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
self.assertAllClose(result, 13.5)
+def _is_tpu_strategy(strategy):
+if isinstance(strategy,
+(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
+return True
+return False
+class TestDistributionStrategyWithDatasetsFile(test.TestCase,
+parameterized.TestCase):
+def setUp(self):
+super(TestDistributionStrategyWithDatasetsFile, self).setUp()
+self.input_file_name = os.path.join(self.get_temp_dir(), 'input.tfrecord')
+inputs = np.zeros((20, 3), dtype=np.float32)
+input_dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
+input_dataset = input_dataset.map(parsing_ops.serialize_tensor)
+writer = writers.TFRecordWriter(self.input_file_name)
+writer.write(input_dataset)
+# TODO(wxinyi): add a multi-worker test for TPU
+@ds_combinations.generate(multi_worker_strategy_combinations_eager_only())
+def test_predict_on_dataset_shard_options_file_multi_worker_mirrored(
+self, distribution, mode):
+# This test verifies that we successfully switch the auto_shard_policy of an
+# input dataset inside model.predict with MultiWorkerMirroredStrategy to
+# AutoShardPolicy.DATA. Since there is only one input file for multiple
+# workers, AutoShardPolicy.AUTO or AutoShardPolicy.FILE would lead to an
+# error. However, since we switch to AutoShardPolicy.DATA in model.predict,
+# no error is raised.
+del mode
+with distribution.scope():
+optimizer_fn = gradient_descent_keras.SGD
+optimizer = optimizer_fn(0.001)
+model = get_model()
+loss = 'mse'
+model.compile(optimizer, loss)
+dataset = readers.TFRecordDataset(self.input_file_name)
+dataset = dataset.map(lambda x: parsing_ops.parse_tensor(x, dtypes.float32))
+dummy_op = lambda inp: True
+dataset = dataset.filter(dummy_op).batch(8, drop_remainder=True)
+options = dataset_ops.Options()
+options.experimental_distribute.auto_shard_policy = AutoShardPolicy.FILE
+dataset = dataset.with_options(options)
+model.predict(dataset, steps=1)
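The policy switch this test relies on can be reproduced with the public tf.data options API; a small sketch (the file path is hypothetical):

    import tensorflow as tf

    dataset = tf.data.TFRecordDataset('/tmp/input.tfrecord')  # hypothetical file
    options = tf.data.Options()
    # FILE sharding assigns whole input files to workers, so a single file
    # cannot feed two workers; DATA sharding splits at the element level and
    # always works, which is why predict() overrides the policy as shown below
    # in training.py.
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA)
    dataset = dataset.with_options(options)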
class TestRegularizerLoss(test.TestCase, parameterized.TestCase):
class IdentityRegularizer(keras.regularizers.Regularizer):
@@ -2196,7 +2320,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
@ds_combinations.generate(
combinations.combine(
-distribution=strategies_minus_tpu,
+distribution=strategies_minus_tpu + multi_worker_mirrored_strategies,
mode=['eager']))
def test_sparse_tensor_outputs(self, distribution):
@@ -2225,7 +2349,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
@ds_combinations.generate(
combinations.combine(
-distribution=strategies_minus_tpu,
+distribution=strategies_minus_tpu + multi_worker_mirrored_strategies,
mode=['eager']))
def test_ragged_tensor_outputs(self, distribution):
@@ -2252,7 +2376,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
@ds_combinations.generate(
combinations.combine(
-distribution=strategies_minus_default_minus_tpu + tpu_strategies,
+distribution=strategies_minus_default_minus_tpu + tpu_strategies +
+multi_worker_mirrored_strategies,
mode=['eager']))
def test_correctness_of_add_loss_with_merge_call(self, distribution):
batch_size = 32
@@ -2568,4 +2693,4 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
if __name__ == '__main__':
base_layer_utils.enable_v2_dtype_behavior()
-test.main()
+multi_process_runner.test_main()

@@ -24,7 +24,6 @@ import numpy as np
import six
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations
@@ -287,12 +286,7 @@ def fit_eval_and_predict(initial_weights,
result['weights_1'] = model.get_weights()
-# TODO(b/157924053): Now model.predict() doesn't support
-# MultiWorkerMirroredStrategy. Enable model.predict() after it's supported.
-if predict_inputs is not None and not isinstance(
-distribution,
-(collective_all_reduce_strategy.CollectiveAllReduceStrategy,
-collective_all_reduce_strategy.CollectiveAllReduceStrategyV1)):
+if predict_inputs is not None:
# Check correctness of the result of predict() invoked
# multiple times -- as for stateful models, result of
# predict may differ for each batch.

@@ -23,9 +23,11 @@ import tempfile
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
+from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
@@ -575,4 +577,4 @@ class TestDistributionStrategyWithStaticShapes(test.TestCase,
if __name__ == '__main__':
-test.main()
+multi_process_runner.test_main()

@@ -23,9 +23,13 @@ import itertools
import json
import os
+import warnings
import six
from tensorflow.python.autograph.lang import directives
+from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import backprop
@@ -1471,7 +1475,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
self.predict_function = predict_function
return self.predict_function
-@disable_multi_worker
def predict(self,
x,
batch_size=None,
@@ -1554,6 +1557,20 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
outputs = None
with self.distribute_strategy.scope():
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
+dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2)
+if (self._in_multi_worker_mode() or _is_tpu_multi_host(
+self.distribute_strategy)) and isinstance(x, dataset_types):
+try:
+options = dataset_ops.Options()
+data_option = AutoShardPolicy.DATA
+options.experimental_distribute.auto_shard_policy = data_option
+x = x.with_options(options)
+except ValueError:
+warnings.warn('Using Model.predict with '
+'MultiWorkerDistributionStrategy or TPUStrategy and '
+'AutoShardPolicy.FILE might lead to out-of-order result'
+'. Consider setting it to AutoShardPolicy.DATA.')
data_handler = data_adapter.DataHandler(
x=x,
batch_size=batch_size,
@@ -2658,6 +2675,8 @@ def reduce_per_replica(values, strategy, reduction='first'):
def _reduce(v):
"""Reduce a single `PerReplica` object."""
+if reduction == 'concat' and _collective_all_reduce_multi_worker(strategy):
+return _multi_worker_concat(v, strategy)
if not isinstance(v, ds_values.PerReplica):
return v
elif reduction == 'first':
@@ -2702,6 +2721,42 @@ def _tpu_multi_host_concat(v, strategy):
return concat(ordered_replicas)
+def _collective_all_reduce_multi_worker(strategy):
+return (isinstance(strategy,
+collective_all_reduce_strategy.CollectiveAllReduceStrategy)
+) and strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
+# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather
+# for all strategies
+def _multi_worker_concat(v, strategy):
+"""Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
+replicas = strategy._gather(v, axis=0) # pylint: disable=protected-access
+# v might not have the same shape on different replicas
+if isinstance(v, ds_values.PerReplica):
+shapes = array_ops.concat([
+array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
+for single_value in v.values
+],
+axis=0)
+all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access
+else:
+# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
+all_shapes = strategy._gather( # pylint: disable=protected-access
+array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
+axis=0)
+replicas = array_ops.split(
+replicas,
+num_or_size_splits=all_shapes,
+num=strategy.num_replicas_in_sync)
+ordered_replicas = []
+num_replicas_per_worker = len(strategy.extended.worker_devices)
+for replica_id in range(num_replicas_per_worker):
+ordered_replicas += replicas[replica_id::num_replicas_per_worker]
+return concat(ordered_replicas)
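The slicing in the loop above is the crux: strategy._gather returns values grouped worker-major (all replicas of worker 0, then all of worker 1, and so on), and taking the same replica slot from every worker interleaves them back toward the order in which elements were dealt out. A toy illustration with plain Python lists (the worker/replica labels are hypothetical):

    # Gathered order for a 2-worker x 2-replica job: worker-major.
    replicas = ['w0r0', 'w0r1', 'w1r0', 'w1r1']
    num_replicas_per_worker = 2

    ordered_replicas = []
    for replica_id in range(num_replicas_per_worker):
      # Pick the same replica slot from every worker, striding by the
      # number of replicas each worker holds.
      ordered_replicas += replicas[replica_id::num_replicas_per_worker]

    print(ordered_replicas)  # ['w0r0', 'w1r0', 'w0r1', 'w1r1']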
def _is_scalar(x):
return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0