model.predict for MultiWorkerMirroredStrategy.

PiperOrigin-RevId: 335892290
Change-Id: I9202fcecae62ed78948c4c93d9f19d077c6928a7
parent 811248e122
commit 37995ed712
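
This commit enables Model.predict under MultiWorkerMirroredStrategy: it drops the @disable_multi_worker guard from predict, switches the input dataset's auto_shard_policy to AutoShardPolicy.DATA inside predict, and adds a multi-worker gather-and-concat path for per-replica outputs. As a rough usage sketch of what this allows (illustrative only; the model, data, and cluster setup below are assumptions, not part of the diff):

# Run one copy per worker, with TF_CONFIG describing the cluster (assumed).
import numpy as np
import tensorflow as tf

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
  model.compile(optimizer='sgd', loss='mse')

dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((20, 3), dtype=np.float32)).batch(8)
# With this change, predict() re-shards the dataset element-wise
# (AutoShardPolicy.DATA) and concatenates per-replica outputs across workers.
predictions = model.predict(dataset)
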
tensorflow/python/keras
@@ -337,6 +337,7 @@ distribute_py_test(
    srcs = ["distribute_strategy_test.py"],
    full_precision = True,
    main = "distribute_strategy_test.py",
    python_version = "PY3",
    shard_count = 10,
    tags = [
        "multi_and_single_gpu",
@@ -17,15 +17,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops import writers
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
@@ -36,6 +44,7 @@ from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import testing_utils
@@ -50,6 +59,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import loss_reduction
from tensorflow.python.ops.ragged import ragged_tensor
@@ -152,8 +162,7 @@ def batch_wrapper(dataset, batch_size, distribution, repeat=None):
    dataset = dataset.repeat(repeat)
  # TPUs currently require fully defined input shapes, drop_remainder ensures
  # the input will have fully defined shapes.
  if isinstance(distribution,
                (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
  if _is_tpu_strategy(distribution):
    return dataset.batch(batch_size, drop_remainder=True)
  else:
    return dataset.batch(batch_size)
@@ -237,11 +246,18 @@ strategies_minus_tpu = [
    strategy_combinations.central_storage_strategy_with_gpu_and_cpu
]

multi_worker_mirrored_strategies = [
    strategy_combinations.multi_worker_mirrored_2x1_cpu,
    strategy_combinations.multi_worker_mirrored_2x1_gpu,
    strategy_combinations.multi_worker_mirrored_2x2_gpu,
]

tpu_strategies = [
    strategy_combinations.tpu_strategy,
]

all_strategies = strategies_minus_tpu + tpu_strategies
all_strategies = (
    strategies_minus_tpu + tpu_strategies + multi_worker_mirrored_strategies)


def strategy_minus_tpu_combinations():
@@ -258,8 +274,14 @@ def tpu_strategy_combinations_graph_only():
  return combinations.combine(distribution=tpu_strategies, mode=['graph'])


def multi_worker_strategy_combinations_eager_only():
  return combinations.combine(
      distribution=multi_worker_mirrored_strategies, mode=['eager'])


def all_strategy_combinations():
  return strategy_minus_tpu_combinations() + tpu_strategy_combinations()
  return strategy_minus_tpu_combinations() + tpu_strategy_combinations(
  ) + multi_worker_strategy_combinations_eager_only()


def all_strategy_minus_default_and_tpu_combinations():
@@ -275,7 +297,8 @@ def all_strategy_minus_default_and_tpu_combinations():

def all_strategy_combinations_minus_default():
  return (all_strategy_minus_default_and_tpu_combinations() +
          tpu_strategy_combinations())
          tpu_strategy_combinations() +
          multi_worker_strategy_combinations_eager_only())


def strategy_and_optimizer_combinations():
@@ -318,7 +341,21 @@ def strategy_and_optimizer_combinations():
          optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,
          optimizer_combinations.rmsprop_optimizer_keras_v2_fn
      ])
  return non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph
  multi_worker_eager = combinations.combine(
      distribution=multi_worker_mirrored_strategies,
      mode=['eager'],
      optimizer=[
          optimizer_combinations.adadelta_optimizer_keras_v2_fn,
          optimizer_combinations.adagrad_optimizer_keras_v2_fn,
          optimizer_combinations.adam_optimizer_keras_v2_fn,
          optimizer_combinations.adamax_optimizer_keras_v2_fn,
          optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,
          optimizer_combinations.nadam_optimizer_keras_v2_fn,
          optimizer_combinations.rmsprop_optimizer_keras_v2_fn,
          optimizer_combinations.ftrl_optimizer_keras_v2_fn
      ])
  return (non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph +
          multi_worker_eager)


class BatchCountingCB(keras.callbacks.Callback):
@@ -494,8 +531,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest('b/152097775')
    if isinstance(distribution,
                  (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
    if _is_tpu_strategy(distribution):
      policy_name = 'mixed_bfloat16'
    else:
      policy_name = 'mixed_float16'
@@ -545,8 +581,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest('b/152097775')

    if isinstance(distribution,
                  (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
    if _is_tpu_strategy(distribution):
      policy_name = 'mixed_bfloat16'
    else:
      policy_name = 'mixed_float16'
@@ -627,8 +662,9 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,

  @ds_combinations.generate(
      combinations.combine(
          distribution=strategies_minus_tpu,
          mode=['graph', 'eager']))
          distribution=strategies_minus_tpu, mode=['graph', 'eager']) +
      combinations.combine(
          distribution=multi_worker_mirrored_strategies, mode=['eager']))
  def test_numpy_with_sample_weights(self, distribution):
    with self.cached_session(), distribution.scope():
      model = get_sample_weights_model()
@@ -989,8 +1025,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
  def test_fit_with_dictionary_in_the_dataset_b135161171(
      self, distribution):

    if isinstance(distribution,
                  (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
    if _is_tpu_strategy(distribution):
      self.skipTest('b/142805125')

    def custom_loss(predict, label, weight):
@@ -1069,19 +1104,56 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
      self.assertAllClose(
          predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4)

  @ds_combinations.generate(all_strategy_combinations())
  def test_predict_on_dataset_with_unknown_cardinality_without_steps(
      self, distribution, mode):
    # TODO(b/155867206): Investigate why this test occasionally segfaults on TPU
    # in eager mode.
    if mode == 'eager' and _is_tpu_strategy(distribution):
      self.skipTest('caused segfault with TPU in eager mode.')

    if mode == 'graph' and _is_tpu_strategy(distribution):
      self.skipTest('partial batch not supported with TPU in graph mode.')

    with self.cached_session():
      with distribution.scope():
        optimizer_fn = gradient_descent_keras.SGD
        optimizer = optimizer_fn(0.001)
        model = get_model()
        loss = 'mse'
        metrics = ['mae', keras.metrics.CategoricalAccuracy()]
        model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((20, 3), dtype=np.float32)
      # steps/steps_per_epoch are calculated when using numpy arrays as
      # input data.
      predict_with_numpy = model.predict(inputs, batch_size=10)

      predict_dataset = convert_numpy_to_dataset_with_unknown_cardinality(
          inputs)

      self.assertEqual(
          keras.backend.get_value(cardinality.cardinality(predict_dataset)),
          cardinality.UNKNOWN)

      predict_with_ds = model.predict(predict_dataset)
      self.assertAllClose(
          predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4)

  @ds_combinations.generate(all_strategy_combinations())
  def test_on_dataset_with_unknown_cardinality_without_steps(
      self, distribution, mode):
    # TODO(b/155867206): Investigate why this test occasionally segfaults on TPU
    # in eager mode.
    if mode == 'eager' and isinstance(
        distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
    if mode == 'eager' and _is_tpu_strategy(distribution):
      self.skipTest('caused segfault with TPU in eager mode.')

    if mode == 'graph' and isinstance(
        distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
    if mode == 'graph' and _is_tpu_strategy(distribution):
      self.skipTest('partial batch not supported with TPU in graph mode.')

    if isinstance(distribution,
                  collective_all_reduce_strategy.CollectiveAllReduceStrategy):
      self.skipTest('EOF error causes subsequent collective ops fail.')
    with self.cached_session():
      with distribution.scope():
        optimizer_fn = gradient_descent_keras.SGD
@@ -1514,8 +1586,9 @@ class TestDistributionStrategyWithDatasets(test.TestCase,

  @ds_combinations.generate(
      combinations.combine(
          distribution=strategies_minus_tpu,
          mode=['graph', 'eager']))
          distribution=strategies_minus_tpu, mode=['graph', 'eager']) +
      combinations.combine(
          distribution=multi_worker_mirrored_strategies, mode=['eager']))
  def test_dataset_with_sample_weights(self, distribution):
    with self.cached_session(), distribution.scope():
      model = get_sample_weights_model()
@@ -1551,6 +1624,57 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
      self.assertAllClose(result, 13.5)


def _is_tpu_strategy(strategy):
  if isinstance(strategy,
                (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
    return True
  return False


class TestDistributionStrategyWithDatasetsFile(test.TestCase,
                                               parameterized.TestCase):

  def setUp(self):
    super(TestDistributionStrategyWithDatasetsFile, self).setUp()
    self.input_file_name = os.path.join(self.get_temp_dir(), 'input.tfrecord')
    inputs = np.zeros((20, 3), dtype=np.float32)
    input_dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
    input_dataset = input_dataset.map(parsing_ops.serialize_tensor)
    writer = writers.TFRecordWriter(self.input_file_name)
    writer.write(input_dataset)

  # TODO(wxinyi): add a multi-worker test for TPU
  @ds_combinations.generate(multi_worker_strategy_combinations_eager_only())
  def test_predict_on_dataset_shard_options_file_multi_worker_mirrored(
      self, distribution, mode):
    # This test verifies that we successfully switch the auto_shard_policy of
    # an input dataset inside model.predict with MultiWorkerMirroredStrategy to
    # AutoShardPolicy.DATA. Since there is only one input file for multiple
    # workers, AutoShardPolicy.AUTO or AutoShardPolicy.FILE would lead to an
    # error. However, since we switch to AutoShardPolicy.DATA in model.predict,
    # no error is raised.
    del mode
    with distribution.scope():
      optimizer_fn = gradient_descent_keras.SGD
      optimizer = optimizer_fn(0.001)
      model = get_model()
      loss = 'mse'
      model.compile(optimizer, loss)

    dataset = readers.TFRecordDataset(self.input_file_name)
    dataset = dataset.map(lambda x: parsing_ops.parse_tensor(x, dtypes.float32))

    dummy_op = lambda inp: True

    dataset = dataset.filter(dummy_op).batch(8, drop_remainder=True)

    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = AutoShardPolicy.FILE
    dataset = dataset.with_options(options)

    model.predict(dataset, steps=1)


class TestRegularizerLoss(test.TestCase, parameterized.TestCase):

  class IdentityRegularizer(keras.regularizers.Regularizer):
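
The new file-based test above exercises exactly the case the policy switch fixes: two workers share a single TFRecord file, so FILE-based auto-sharding cannot assign each worker its own file, while DATA sharding splits the element stream instead. For intuition only, DATA sharding behaves roughly like an explicit element-wise shard per worker; a sketch with the public tf.data API (num_workers and worker_index are hypothetical values, not from the diff):

import tensorflow as tf

num_workers, worker_index = 2, 0  # hypothetical 2x1 multi-worker cluster
dataset = tf.data.TFRecordDataset('input.tfrecord')
# Each worker keeps every num_workers-th element of the same single file.
worker_shard = dataset.shard(num_workers, worker_index)
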
@@ -2196,7 +2320,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,

  @ds_combinations.generate(
      combinations.combine(
          distribution=strategies_minus_tpu,
          distribution=strategies_minus_tpu + multi_worker_mirrored_strategies,
          mode=['eager']))
  def test_sparse_tensor_outputs(self, distribution):

@@ -2225,7 +2349,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,

  @ds_combinations.generate(
      combinations.combine(
          distribution=strategies_minus_tpu,
          distribution=strategies_minus_tpu + multi_worker_mirrored_strategies,
          mode=['eager']))
  def test_ragged_tensor_outputs(self, distribution):

@@ -2252,7 +2376,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,

  @ds_combinations.generate(
      combinations.combine(
          distribution=strategies_minus_default_minus_tpu + tpu_strategies,
          distribution=strategies_minus_default_minus_tpu + tpu_strategies +
          multi_worker_mirrored_strategies,
          mode=['eager']))
  def test_correctness_of_add_loss_with_merge_call(self, distribution):
    batch_size = 32
@@ -2568,4 +2693,4 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):

if __name__ == '__main__':
  base_layer_utils.enable_v2_dtype_behavior()
  test.main()
  multi_process_runner.test_main()

@@ -24,7 +24,6 @@ import numpy as np
import six
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations
@@ -287,12 +286,7 @@ def fit_eval_and_predict(initial_weights,

  result['weights_1'] = model.get_weights()

  # TODO(b/157924053): Now model.predict() doesn't support
  # MultiWorkerMirroredStrategy. Enable model.predict() after it's supported.
  if predict_inputs is not None and not isinstance(
      distribution,
      (collective_all_reduce_strategy.CollectiveAllReduceStrategy,
       collective_all_reduce_strategy.CollectiveAllReduceStrategyV1)):
  if predict_inputs is not None:
    # Check correctness of the result of predict() invoked
    # multiple times -- as for stateful models, result of
    # predict may differ for each batch.

@@ -23,9 +23,11 @@ import tempfile

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
@@ -575,4 +577,4 @@ class TestDistributionStrategyWithStaticShapes(test.TestCase,


if __name__ == '__main__':
  test.main()
  multi_process_runner.test_main()

@@ -23,9 +23,13 @@ import itertools
import json
import os
import warnings

import six

from tensorflow.python.autograph.lang import directives
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import backprop
@@ -1471,7 +1475,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
    self.predict_function = predict_function
    return self.predict_function

  @disable_multi_worker
  def predict(self,
              x,
              batch_size=None,
@@ -1554,6 +1557,20 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
    outputs = None
    with self.distribute_strategy.scope():
      # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
      dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2)
      if (self._in_multi_worker_mode() or _is_tpu_multi_host(
          self.distribute_strategy)) and isinstance(x, dataset_types):
        try:
          options = dataset_ops.Options()
          data_option = AutoShardPolicy.DATA
          options.experimental_distribute.auto_shard_policy = data_option
          x = x.with_options(options)
        except ValueError:
          warnings.warn('Using Model.predict with '
                        'MultiWorkerDistributionStrategy or TPUStrategy and '
                        'AutoShardPolicy.FILE might lead to out-of-order result'
                        '. Consider setting it to AutoShardPolicy.DATA.')

      data_handler = data_adapter.DataHandler(
          x=x,
          batch_size=batch_size,
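
The try/except above covers the case where the incoming dataset already carries a conflicting auto_shard_policy, in which case with_options raises ValueError and predict only emits the warning. A user can also request the same policy explicitly before calling predict; a minimal sketch using the public tf.data API (the dataset and model variables are assumed to exist):

import tensorflow as tf

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)
dataset = dataset.with_options(options)
predictions = model.predict(dataset)
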
@@ -2658,6 +2675,8 @@ def reduce_per_replica(values, strategy, reduction='first'):

  def _reduce(v):
    """Reduce a single `PerReplica` object."""
    if reduction == 'concat' and _collective_all_reduce_multi_worker(strategy):
      return _multi_worker_concat(v, strategy)
    if not isinstance(v, ds_values.PerReplica):
      return v
    elif reduction == 'first':
@@ -2702,6 +2721,42 @@ def _tpu_multi_host_concat(v, strategy):
  return concat(ordered_replicas)


def _collective_all_reduce_multi_worker(strategy):
  return (isinstance(strategy,
                     collective_all_reduce_strategy.CollectiveAllReduceStrategy)
         ) and strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access


# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather
# for all strategies
def _multi_worker_concat(v, strategy):
  """Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
  replicas = strategy._gather(v, axis=0)  # pylint: disable=protected-access
  # v might not have the same shape on different replicas
  if isinstance(v, ds_values.PerReplica):
    shapes = array_ops.concat([
        array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
        for single_value in v.values
    ],
                              axis=0)
    all_shapes = strategy._gather(shapes, axis=0)  # pylint: disable=protected-access
  else:
    # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
    all_shapes = strategy._gather(  # pylint: disable=protected-access
        array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
        axis=0)

  replicas = array_ops.split(
      replicas,
      num_or_size_splits=all_shapes,
      num=strategy.num_replicas_in_sync)
  ordered_replicas = []
  num_replicas_per_worker = len(strategy.extended.worker_devices)
  for replica_id in range(num_replicas_per_worker):
    ordered_replicas += replicas[replica_id::num_replicas_per_worker]
  return concat(ordered_replicas)


def _is_scalar(x):
  return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0
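
_multi_worker_concat gathers the per-replica values from all workers and then reorders the split pieces by stride before concatenating them. A toy, plain-Python illustration of just the slicing step (the labels and their grouping are hypothetical; the real ordering depends on what strategy._gather returns):

replicas = ['w0r0', 'w1r0', 'w0r1', 'w1r1']  # hypothetical gathered order
num_replicas_per_worker = 2
ordered_replicas = []
for replica_id in range(num_replicas_per_worker):
  ordered_replicas += replicas[replica_id::num_replicas_per_worker]
print(ordered_replicas)  # ['w0r0', 'w0r1', 'w1r0', 'w1r1']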