Drop experimental and v2 qualifiers from Strategy experimental_run_v2 method.

- experimental_run_v2 -> run

PiperOrigin-RevId: 300574367
Change-Id: I5d82ea5450a4d32aea6d05ed3db4f02b8edb2eea
Author: Ken Franko 2020-03-12 10:23:35 -07:00 (committed by TensorFlower Gardener)
Parent: fbbb83b995
Commit: 0b8f0a5b84
56 changed files with 322 additions and 244 deletions
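
For orientation, here is a minimal sketch of what the rename means for user code. The strategy, dataset, and `train_step` below are illustrative placeholders, not code taken from this change; `experimental_run_v2` remains available for now as a deprecated alias of `run`:

```python
import tensorflow as tf

# Hypothetical setup: any tf.distribute strategy and a toy distributed dataset.
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices([1., 2., 3., 4.]).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(x):
  # A trivial per-replica computation standing in for a real training step.
  return tf.reduce_sum(x)

for x in dist_dataset:
  # Before this change: strategy.experimental_run_v2(train_step, args=(x,))
  # After this change:
  per_replica = strategy.run(train_step, args=(x,))
  total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
```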


@@ -92,7 +92,7 @@ class DistributedDumpingCallbackTest(
caught_error = None
try:
- distribution.experimental_run_v2(train_step)
+ distribution.run(train_step)
except errors.InvalidArgumentError as error:
caught_error = error
self.assertTrue(caught_error)
@@ -128,7 +128,7 @@ class DistributedDumpingCallbackTest(
grads_and_vars = zip(grads, mini_model.weights)
optimizer.apply_gradients(grads_and_vars)
- distribution.experimental_run_v2(train_step)
+ distribution.run(train_step)
updated_var_values = self.evaluate(mini_model.variables)
num_devices = len(distribution.extended.worker_devices)


@@ -67,7 +67,7 @@ def train_step(iterator):
grads = tape.gradient(loss, model.variables)
return grads
- return tpu_strategy.experimental_run_v2(
+ return tpu_strategy.run(
step_fn, args=(next(iterator),))
# Run the loop body once on at dataset.


@@ -48,7 +48,7 @@ class CentralStorageStrategy(distribute_lib.Strategy):
# Iterate over the distributed dataset
for x in dist_dataset:
# process dataset elements
- strategy.experimental_run_v2(train_step, args=(x,))
+ strategy.run(train_step, args=(x,))
```
"""
@@ -125,7 +125,7 @@ class CentralStorageStrategy(distribute_lib.Strategy):
inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn)
for batch in inputs:
- replica_results = strategy.experimental_run_v2(replica_fn, args=(batch,))
+ replica_results = strategy.run(replica_fn, args=(batch,))
```
IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
@@ -152,8 +152,8 @@ class CentralStorageStrategy(distribute_lib.Strategy):
will be all the values on that worker.
Args:
- value: A value returned by `experimental_run()`, `experimental_run_v2()`,
- `extended.call_for_each_replica()`, or a variable created in `scope`.
+ value: A value returned by `run()`, `extended.call_for_each_replica()`,
+ or a variable created in `scope`.
Returns:
A tuple of values contained in `value`. If `value` represents a single
@@ -161,7 +161,7 @@ class CentralStorageStrategy(distribute_lib.Strategy):
"""
return super(CentralStorageStrategy, self).experimental_local_results(value)
- def experimental_run_v2(self, fn, args=(), kwargs=None, options=None): # pylint: disable=useless-super-delegation
+ def run(self, fn, args=(), kwargs=None, options=None): # pylint: disable=useless-super-delegation
"""Run `fn` on each replica, with the given arguments.
In `CentralStorageStrategy`, `fn` is called on each of the compute
@@ -177,13 +177,12 @@ class CentralStorageStrategy(distribute_lib.Strategy):
Returns:
Return value from running `fn`.
"""
- return super(CentralStorageStrategy,
- self).experimental_run_v2(fn, args, kwargs, options)
+ return super(CentralStorageStrategy, self).run(fn, args, kwargs, options)
def reduce(self, reduce_op, value, axis): # pylint: disable=useless-super-delegation
"""Reduce `value` across replicas.
- Given a per-replica value returned by `experimental_run_v2`, say a
+ Given a per-replica value returned by `run`, say a
per-example loss, the batch will be divided across all the replicas. This
function allows you to aggregate across replicas and optionally also across
batch elements. For example, if you have a global batch size of 8 and 2
@@ -221,7 +220,7 @@ class CentralStorageStrategy(distribute_lib.Strategy):
# Iterate over the distributed dataset
for x in dist_dataset:
- result = strategy.experimental_run_v2(train_step, args=(x,))
+ result = strategy.run(train_step, args=(x,))
result = strategy.reduce(tf.distribute.ReduceOp.SUM, result,
axis=None).numpy()
@@ -234,7 +233,7 @@ class CentralStorageStrategy(distribute_lib.Strategy):
Args:
reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
be combined.
- value: A "per replica" value, e.g. returned by `experimental_run_v2` to
+ value: A "per replica" value, e.g. returned by `run` to
be combined into a single tensor.
axis: Specifies the dimension to reduce along within each
replica's tensor. Should typically be set to the batch dimension, or


@@ -118,7 +118,8 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
loss = v + v
gradients = tape.gradient(loss, [v])
opt.apply_gradients(zip(gradients, [v]))
- distribution.experimental_run_v2(f)
+ distribution.run(f)
return v, opt, step


@@ -125,8 +125,7 @@ def iteration_inside_func(initial_weights, dataset, optimizer_fn,
if iteration_type == 'dataset':
for x in dist_input:
if strategy:
- per_replica_losses = strategy.experimental_run_v2(step_fn,
- args=(x,))
+ per_replica_losses = strategy.run(step_fn, args=(x,))
total_loss += strategy.reduce(reduce_util.ReduceOp.SUM,
per_replica_losses,
axis=None)
@@ -137,8 +136,7 @@ def iteration_inside_func(initial_weights, dataset, optimizer_fn,
iterator = iter(dist_input)
for _ in range(_STEPS_PER_EPOCH):
if strategy:
- per_replica_losses = strategy.experimental_run_v2(
- step_fn, args=(next(iterator),))
+ per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
total_loss += strategy.reduce(reduce_util.ReduceOp.SUM,
per_replica_losses,
axis=None)
@@ -184,8 +182,7 @@ def iteration_outside_func(initial_weights, dataset, optimizer_fn,
return loss
if strategy:
- per_replica_losses = strategy.experimental_run_v2(
- step_fn, args=(dist_inputs,))
+ per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))
return strategy.reduce(reduce_util.ReduceOp.SUM,
per_replica_losses,
axis=None)


@@ -87,7 +87,7 @@ class GradientTapeTest(test.TestCase, parameterized.TestCase,
results = []
for x in dist_dataset:
output = distribution.experimental_local_results(
- distribution.experimental_run_v2(train_step, args=(x,)))
+ distribution.run(train_step, args=(x,)))
results.append(output)
self.assert_equal_flattened([[10., 12.], [14., 16.]], results)
@@ -110,7 +110,7 @@ class GradientTapeTest(test.TestCase, parameterized.TestCase,
grads = tape.gradient(y, x)
return grads
return distribution.experimental_local_results(
- distribution.experimental_run_v2(train_step, args=(x,)))
+ distribution.run(train_step, args=(x,)))
dist_dataset = distribution.experimental_distribute_dataset(dataset)
results = []
@@ -141,7 +141,7 @@ class GradientTapeTest(test.TestCase, parameterized.TestCase,
with backprop.GradientTape() as tape:
y = model(x)
return tape.gradient(y, x)
- return distribution.experimental_run_v2(replica_step)
+ return distribution.run(replica_step)
grads = distribution.experimental_local_results(train_step())
self.assertLen(grads, distribution.num_replicas_in_sync)


@@ -87,7 +87,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
return math_ops.square(x)
outputs = distribution.experimental_local_results(
- distribution.experimental_run_v2(computation, args=(x,)))
+ distribution.run(computation, args=(x,)))
return outputs
self.assertAllEqual(
@@ -110,7 +110,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
def assign_add():
v.assign_add(1.0)
- distribution.experimental_run_v2(assign_add)
+ distribution.run(assign_add)
return array_ops.zeros([])
train_step()
@@ -130,7 +130,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
results = []
for x in dist_dataset:
output = distribution.experimental_local_results(
- distribution.experimental_run_v2(train_step, args=(x,)))
+ distribution.run(train_step, args=(x,)))
results.append(output)
self.assert_equal_flattened([[25., 36.], [49., 64.]], results)
@@ -148,7 +148,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
with self.assertRaisesRegexp(NotImplementedError,
"does not support pure eager execution"):
- distribution.experimental_run_v2(train_step, args=(next(input_iterator),))
+ distribution.run(train_step, args=(next(input_iterator),))
@combinations.generate(
combinations.combine(
@@ -166,7 +166,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
results = []
for x in dist_dataset:
output = distribution.experimental_local_results(
- distribution.experimental_run_v2(train_step, args=(x,)))
+ distribution.run(train_step, args=(x,)))
results.append(output)
self.assert_equal_flattened([[25., 36.], [49., 64.]], results)
@@ -184,7 +184,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
@def_function.function
def f_train_step(input_data):
return distribution.experimental_local_results(
- distribution.experimental_run_v2(train_step, args=(input_data,)))
+ distribution.run(train_step, args=(input_data,)))
dist_dataset = distribution.experimental_distribute_dataset(dataset)
results = []
@@ -214,7 +214,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
}]
inputs = next(iterator)
- outputs = distribution.experimental_run_v2(computation, args=(inputs,))
+ outputs = distribution.run(computation, args=(inputs,))
return nest.map_structure(distribution.experimental_local_results,
outputs)
@@ -238,7 +238,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
@def_function.function
def f_train_step(input_data):
return distribution.experimental_local_results(
- distribution.experimental_run_v2(train_step, args=(input_data,)))
+ distribution.run(train_step, args=(input_data,)))
dist_dataset = distribution.experimental_distribute_dataset(dataset)
results = []
@@ -270,7 +270,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
distribution.reduce("MEAN", x, axis=0), product_of_means.dtype)
for y in dist_dataset: # loop with no intermediate state
- distribution.experimental_run_v2(train_step, args=(y,))
+ distribution.run(train_step, args=(y,))
return number_of_steps, product_of_means
@@ -308,7 +308,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
for _ in range(2):
elem = next(iterator)
output = distribution.experimental_local_results(
- distribution.experimental_run_v2(step_fn, args=(elem,)))
+ distribution.run(step_fn, args=(elem,)))
results.append(output)
return results
@@ -454,7 +454,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
return math_ops.reduce_mean(x)
inputs = next(iterator)
outputs = distribution.experimental_local_results(
- distribution.experimental_run_v2(computation, args=(inputs,)))
+ distribution.run(computation, args=(inputs,)))
return outputs
# This assumes that there are exactly 2 replicas
@@ -478,7 +478,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
inputs = next(iterator)
outputs = distribution.experimental_local_results(
- distribution.experimental_run_v2(
+ distribution.run(
computation, args=(inputs,), options=options))
return outputs
@@ -499,7 +499,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
def computation(x):
return math_ops.reduce_mean(x)
outputs = distribution.experimental_local_results(
- distribution.experimental_run_v2(computation, args=(inputs,)))
+ distribution.run(computation, args=(inputs,)))
return outputs
# This assumes that there are exactly 2 replicas
@@ -552,7 +552,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
def computation(x):
return array_ops.size_v2(x)
outputs = distribution.experimental_local_results(
- distribution.experimental_run_v2(computation, args=(inputs,)))
+ distribution.run(computation, args=(inputs,)))
return outputs
# This assumes that there are exactly 2 replicas
@@ -580,7 +580,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
def computation(x):
return math_ops.reduce_mean(x)
outputs = distribution.experimental_local_results(
- distribution.experimental_run_v2(computation, args=(inputs,)))
+ distribution.run(computation, args=(inputs,)))
return outputs
# This assumes that there are exactly 2 replicas
@@ -669,7 +669,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
# Fixed size output with a dynamic sized output.
return array_ops.zeros([3]), math_ops.square(x)
- return distribution.experimental_run_v2(
+ return distribution.run(
computation, args=(next(iterator),))
results = run(input_iterator)
@@ -707,7 +707,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
for _ in range(2):
elem = next(iterator)
output = distribution.experimental_local_results(
- distribution.experimental_run_v2(step_fn, args=(elem,)))
+ distribution.run(step_fn, args=(elem,)))
results.append(output)
return results
@@ -729,7 +729,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
@def_function.function
def f_train_step(input_data):
return distribution.experimental_local_results(
- distribution.experimental_run_v2(train_step, args=(input_data,)))
+ distribution.run(train_step, args=(input_data,)))
dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
dist_dataset = distribution.experimental_distribute_dataset(dataset)
@@ -761,12 +761,12 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
def func(inputs):
return math_ops.square(inputs) + var
- per_replica_outputs = distribution.experimental_run_v2(
+ per_replica_outputs = distribution.run(
func, (next(input_iterator),))
mean = distribution.reduce(
reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
for _ in dataset_ops.Dataset.range(1):
- per_replica_outputs = distribution.experimental_run_v2(
+ per_replica_outputs = distribution.run(
func, (next(input_iterator),))
mean = distribution.reduce(
reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
@@ -793,7 +793,7 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
@def_function.function
def f_train_step(iterator):
- distribution.experimental_run_v2(train_step, args=(next(iterator),))
+ distribution.run(train_step, args=(next(iterator),))
return a
dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)


@@ -49,7 +49,7 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
loss_metric.update_state(loss)
loss_metric_2.update_state(loss)
- distribution.experimental_run_v2(step_fn)
+ distribution.run(step_fn)
train_step()
self.assertEqual(loss_metric.result().numpy(),
@@ -73,7 +73,7 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
metric.update_state(i)
for i in dataset:
- distribution.experimental_run_v2(step_fn, args=(i,))
+ distribution.run(step_fn, args=(i,))
# This should be the mean of integers 0-9 which has a sum of 45 and a count
# of 10 resulting in mean of 4.5.


@@ -75,7 +75,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
grads = tape.gradient(loss, model.variables)
return grads
- outputs = distribution.experimental_run_v2(
+ outputs = distribution.run(
step_fn, args=(next(iterator),))
return nest.map_structure(distribution.experimental_local_results,
outputs)
@@ -104,7 +104,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
grads = tape.gradient(loss, model.variables)
return grads
- outputs = distribution.experimental_run_v2(
+ outputs = distribution.run(
step_fn, args=(next(iterator),))
return nest.map_structure(distribution.experimental_local_results,
outputs)
@@ -135,7 +135,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
optimizer.apply_gradients(zip(grads, model.variables))
return loss
- outputs = distribution.experimental_run_v2(
+ outputs = distribution.run(
step_fn, args=(next(iterator),))
return nest.map_structure(distribution.experimental_local_results,
outputs)
@@ -178,7 +178,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
optimizer.apply_gradients(zip(grads, model.variables))
return loss
- outputs = distribution.experimental_run_v2(
+ outputs = distribution.run(
step_fn, args=(next(iterator),))
return nest.map_structure(distribution.experimental_local_results,
outputs)
@@ -210,7 +210,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
return loss
for _ in range(5):
- distribution.experimental_run_v2(step_fn, args=(next(iterator),))
+ distribution.run(step_fn, args=(next(iterator),))
train_step(input_iterator)
@@ -261,7 +261,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
optimizer.apply_gradients(zip(grads, model.variables))
return loss
- outputs = distribution.experimental_run_v2(
+ outputs = distribution.run(
step_fn, args=(next(input_iterator),))
return distribution.experimental_local_results(outputs)
@@ -314,7 +314,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
grads = tape.gradient(loss, model.variables)
optimizer.apply_gradients(zip(grads, model.variables))
- distribution.experimental_run_v2(step_fn, args=(inputs,))
+ distribution.run(step_fn, args=(inputs,))
@def_function.function
def compute_loss2(images, targets):
@@ -331,7 +331,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
grads = tape.gradient(loss, model2.variables)
optimizer.apply_gradients(zip(grads, model2.variables))
- distribution.experimental_run_v2(step_fn, args=(inputs,))
+ distribution.run(step_fn, args=(inputs,))
inputs = next(input_iterator)
@@ -365,7 +365,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
grads = tape.gradient(loss, model.variables)
return grads
- outputs = distribution.experimental_run_v2(
+ outputs = distribution.run(
step_fn, args=(next(iterator),))
return nest.map_structure(distribution.experimental_local_results,
outputs)
@@ -408,7 +408,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
grads = tape.gradient(loss, model.variables)
return grads
- outputs = distribution.experimental_run_v2(
+ outputs = distribution.run(
step_fn, args=(next(iterator),))
return nest.map_structure(distribution.experimental_local_results,
outputs)


@@ -66,7 +66,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
return v.read_value()
return distribution.experimental_local_results(
- distribution.experimental_run_v2(step_fn, args=(grads,)))
+ distribution.run(step_fn, args=(grads,)))
self.assertAllClose(optimize(), expected)
@@ -92,7 +92,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
return v.read_value()
return distribution.experimental_local_results(
- distribution.experimental_run_v2(step_fn, args=(grads,)))
+ distribution.run(step_fn, args=(grads,)))
self.assertAllClose(optimize(), [[-0.1, -0.1]])


@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- # pylint: disable=line-too-long
"""Library for running a computation across multiple devices.
See the guide for overview and examples:
[TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training),
- [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb).
+ [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb). # pylint: disable=line-too-long
The intent of this library is that you can write an algorithm in a stylized way
and it will be usable with a variety of different `tf.distribute.Strategy`
@@ -130,6 +129,7 @@ from tensorflow.python.ops.losses import loss_reduction
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
+ from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated
@@ -485,7 +485,7 @@ class RunOptions(
"experimental_enable_dynamic_batch_size",
"experimental_bucketizing_dynamic_shape",
])):
- """Run options for `strategy.experimental_run_v2`.
+ """Run options for `strategy.run`.
This can be used to hold some strategy specific configs.
@@ -496,7 +496,7 @@ class RunOptions(
shape inputs are allowed.
experimental_bucketizing_dynamic_shape: Boolean. Only applies to
TPUStrategy. Default to False. If True, TPUStrategy will automatic
- bucketize inputs passed into `experimental_run_v2` if the input shape is
+ bucketize inputs passed into `run` if the input shape is
dynamic. This is a performance optimization to reduce XLA recompilation,
which should not have impact on correctness.
"""
@@ -548,7 +548,7 @@ class StrategyBase(object):
across replicas, use
`tf.distribute.Strategy.experimental_distribute_datasets_from_function`
instead.
- * Use `tf.distribute.Strategy.experimental_run_v2` to run a function
+ * Use `tf.distribute.Strategy.run` to run a function
once per replica, taking values that may be "per-replica" (e.g.
from a distributed dataset) and returning "per-replica" values.
This function is executed in "replica context", which means each
@@ -568,8 +568,7 @@ class StrategyBase(object):
total_result = 0
for x in dataset:
- per_replica_result = my_strategy.experimental_run_v2(replica_fn,
- args=(x,))
+ per_replica_result = my_strategy.run(replica_fn, args=(x,))
total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
per_replica_result, axis=None)
return total_result
@@ -711,7 +710,7 @@ class StrategyBase(object):
"""DEPRECATED TF 1.x ONLY."""
with self.scope():
args = (input_iterator.get_next(),) if input_iterator is not None else ()
- return self.experimental_run_v2(fn, args=args)
+ return self.run(fn, args=args)
def experimental_distribute_dataset(self, dataset):
"""Distributes a tf.data.Dataset instance provided via `dataset`.
@@ -736,7 +735,7 @@ class StrategyBase(object):
# Iterate over the distributed dataset
for x in dist_dataset:
# process dataset elements
- strategy.experimental_run_v2(train_step, args=(x,))
+ strategy.run(train_step, args=(x,))
```
We will assume that the input dataset is batched by the
@@ -792,7 +791,7 @@ class StrategyBase(object):
# Iterate over the distributed dataset
for x in dist_dataset:
# process dataset elements
- strategy.experimental_run_v2(train_step, args=(x,))
+ strategy.run(train_step, args=(x,))
```
Args:
@@ -836,7 +835,7 @@ class StrategyBase(object):
inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn)
for batch in inputs:
- replica_results = strategy.experimental_run_v2(replica_fn, args=(batch,))
+ replica_results = strategy.run(replica_fn, args=(batch,))
```
IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
@@ -860,7 +859,7 @@ class StrategyBase(object):
return
for _ in range(steps):
- strategy.experimental_run_v2(replica_fn_with_signature,
+ strategy.run(replica_fn_with_signature,
args=(next(iterator),))
```
@@ -875,24 +874,56 @@ class StrategyBase(object):
return self._extended._experimental_distribute_datasets_from_function( # pylint: disable=protected-access
dataset_fn)
- def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
+ def run(self, fn, args=(), kwargs=None, options=None):
"""Run `fn` on each replica, with the given arguments.
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
- "per-replica" values, such as those produced by a "distributed `Dataset`",
+ `tf.distribute.DistributedValues`, such as those produced by a
+ "distributed `Dataset`" or `experimental_distribute_values_from_function`
when `fn` is executed on a particular replica, it will be executed with the
- component of those "per-replica" values that correspond to that replica.
+ component of `tf.distribute.DistributedValues` that correspond to that
+ replica.
`fn` may call `tf.distribute.get_replica_context()` to access members such
as `all_reduce`.
All arguments in `args` or `kwargs` should either be nest of tensors or
- per-replica objects containing tensors or composite tensors.
+ `tf.distribute.DistributedValues` containing tensors or composite tensors.
IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
whether eager execution is enabled, `fn` may be called one or more times (
once for each replica).
+ Example usage:
+ 1. Constant tensor input.
+ >>> strategy = tf.distribute.MirroredStrategy()
+ >>> tensor_input = tf.constant(3.0)
+ >>> @tf.function
+ ... def replica_fn(input):
+ ... return input*2.0
+ >>> result = strategy.run(replica_fn, args=(tensor_input,))
+ >>> result
+ <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
+ 2. DistributedValues input.
+ >>> strategy = tf.distribute.MirroredStrategy()
+ >>> @tf.function
+ ... def run():
+ ... def value_fn(value_context):
+ ... return value_context.num_replicas_in_sync
+ ... distributed_values = (
+ ... strategy.experimental_distribute_values_from_function(
+ ... value_fn))
+ ... def replica_fn2(input):
+ ... return input*2
+ ... return strategy.run(replica_fn2, args=(distributed_values,))
+ >>> result = run()
+ >>> result
+ <tf.Tensor: shape=(), dtype=int32, numpy=2>
Args:
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
args: (Optional) Positional arguments to `fn`.
@@ -903,8 +934,8 @@ class StrategyBase(object):
Returns:
Merged return value of `fn` across replicas. The structure of the return
value is the same as the return value from `fn`. Each element in the
- structure can either be "per-replica" `Tensor` objects or `Tensor`s
- (for example, if running on a single replica).
+ structure can either be `tf.distribute.DistributedValues`, `Tensor`
+ objects, or `Tensor`s (for example, if running on a single replica).
"""
del options
@@ -919,10 +950,16 @@ class StrategyBase(object):
fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
+ # TODO(b/151224785): Remove deprecated alias.
+ @doc_controls.do_not_doc_inheritable # DEPRECATED
+ @deprecation.deprecated(None, "renamed to `run`")
+ def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
+ return self.run(fn, args=args, kwargs=kwargs, options=options)
def reduce(self, reduce_op, value, axis):
"""Reduce `value` across replicas.
- Given a per-replica value returned by `experimental_run_v2`, say a
+ Given a per-replica value returned by `run`, say a
per-example loss, the batch will be divided across all the replicas. This
function allows you to aggregate across replicas and optionally also across
batch elements. For example, if you have a global batch size of 8 and 2
@@ -947,7 +984,7 @@ class StrategyBase(object):
Args:
reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
be combined.
- value: A "per replica" value, e.g. returned by `experimental_run_v2` to
+ value: A "per replica" value, e.g. returned by `run` to
be combined into a single tensor.
axis: Specifies the dimension to reduce along within each
replica's tensor. Should typically be set to the batch dimension, or
@@ -964,7 +1001,7 @@ class StrategyBase(object):
if axis is None:
return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access
if reduce_op == reduce_util.ReduceOp.SUM:
- value = self.experimental_run_v2(
+ value = self.run(
lambda v: math_ops.reduce_sum(v, axis=axis), args=(value,))
return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access
if reduce_op != reduce_util.ReduceOp.MEAN:
@@ -1011,7 +1048,7 @@ class StrategyBase(object):
# reduce is complete?
return numer, denom
- numer, denom = self.experimental_run_v2(mean_reduce_helper, args=(value,))
+ numer, denom = self.run(mean_reduce_helper, args=(value,))
# TODO(josh11b): Should batch reduce here instead of doing two.
numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access
denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access
@@ -1050,7 +1087,7 @@ class StrategyBase(object):
computed on that worker.
Args:
- value: A value returned by `experimental_run()`, `experimental_run_v2()`,
+ value: A value returned by `experimental_run()`, `run()`,
`extended.call_for_each_replica()`, or a variable created in `scope`.
Returns:
@@ -1146,7 +1183,7 @@ class Strategy(StrategyBase):
output = strategy.experimental_assign_to_logical_device(output, 0)
return output
- strategy.experimental_run_v2(step_fn, args=(next(iterator),))
+ strategy.run(step_fn, args=(next(iterator),))
```
Args:
@@ -1204,7 +1241,7 @@ class Strategy(StrategyBase):
output = model(inputs)
return output
- strategy.experimental_run_v2(step_fn, args=(next(iterator),))
+ strategy.run(step_fn, args=(next(iterator),))
```
Args:
tensor: Input tensor to annotate.
@@ -1266,7 +1303,7 @@ class Strategy(StrategyBase):
return loss
- strategy.experimental_run_v2(step_fn, args=(next(iterator),))
+ strategy.run(step_fn, args=(next(iterator),))
```
Args:
tensor: Input tensor to annotate.
@@ -1280,7 +1317,7 @@ class Strategy(StrategyBase):
"""Generates `tf.distribute.DistributedValues` from `value_fn`.
This function is to generate `tf.distribute.DistributedValues` to pass
- into `experimental_run_v2`, `reduce`, or other methods that take
+ into `run`, `reduce`, or other methods that take
distributed values when not using datasets.
Args:
@@ -1468,7 +1505,7 @@ class StrategyV1(StrategyBase):
"""Runs ops in `fn` on each replica, with inputs from `input_iterator`.
DEPRECATED: This method is not available in TF 2.x. Please switch
- to using `experimental_run_v2` instead.
+ to using `run` instead.
When eager execution is enabled, executes ops specified by `fn` on each
replica. Otherwise, builds a graph to execute the ops on each replica.
@@ -1540,10 +1577,10 @@ class StrategyExtendedV2(object):
from replica id to values. "PerReplica" is used when the value may be
different across replicas, and "Mirrored" when the value are the same.
* Unwrapping and merging: Consider calling a function `fn` on multiple
- replicas, like `experimental_run_v2(fn, args=[w])` with an
+ replicas, like `run(fn, args=[w])` with an
argument `w` that is a wrapped value. This means `w` will have a map taking
replica id `0` to `w0`, replica id `11` to `w1`, etc.
- `experimental_run_v2()` unwraps `w` before calling `fn`, so
+ `run()` unwraps `w` before calling `fn`, so
it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc. It then merges the return
values from `fn()`, which can possibly result in wrapped values. For
example, let's say `fn()` returns a tuple with three components: `(x, a,
@@ -1573,7 +1610,7 @@ class StrategyExtendedV2(object):
* `tf.distribute.Strategy.scope`: enters cross-replica context when
no other strategy is in scope.
- * `tf.distribute.Strategy.experimental_run_v2`: calls a function in
+ * `tf.distribute.Strategy.run`: calls a function in
replica context.
* `tf.distribute.ReplicaContext.merge_call`: transitions from replica
context to cross-replica context.
@@ -1615,7 +1652,7 @@ class StrategyExtendedV2(object):
returned by `tf.distribute.Strategy.experimental_distribute_dataset` and
`tf.distribute.Strategy.experimental_distribute_datasets_from_function`. They
are also the typical result returned by
- `tf.distribute.Strategy.experimental_run_v2`. You typically can't use a
+ `tf.distribute.Strategy.run`. You typically can't use a
per-replica value directly in a cross-replica context, without first resolving
how to aggregate the values across replicas, for instance by using
`tf.distribute.Strategy.reduce`.
@@ -1653,7 +1690,7 @@ class StrategyExtendedV2(object):
The standard pattern for updating variables is to:
- 1. In your function passed to `tf.distribute.Strategy.experimental_run_v2`,
+ 1. In your function passed to `tf.distribute.Strategy.run`,
compute a list of (update, variable) pairs. For example, the update might
be a the gradient of the loss with respect to the variable.
2. Switch to cross-replica mode by calling
@@ -2011,8 +2048,7 @@ class StrategyExtendedV2(object):
"""Returns the container that this per-replica `value` belongs to.
Args:
- value: A value returned by `experimental_run_v2()` or a variable
- created in `scope()`.
+ value: A value returned by `run()` or a variable created in `scope()`.
Returns:
A container that `value` belongs to.
@@ -2157,7 +2193,7 @@ class StrategyExtendedV1(StrategyExtendedV2):
iterator,
iterations=1,
initial_loop_values=None):
- """DEPRECATED: please use `experimental_run_v2` instead.
+ """DEPRECATED: please use `run` instead.
Run `fn` with input from `iterator` for `iterations` times.
@@ -2233,7 +2269,7 @@ class StrategyExtendedV1(StrategyExtendedV2):
with distribution.scope():
# in "cross-replica" context
...
- merged_results = distribution.experimental_run_v2(fn, args=[3])
+ merged_results = distribution.run(fn, args=[3])
# merged_results has the values from every replica execution of `fn`.
# This statement prints a list:
print(distribution.experimental_local_results(merged_results))
@@ -2300,7 +2336,7 @@ class StrategyExtendedV1(StrategyExtendedV2):
# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
# (defined above) used by `tf.distribute.Strategy.scope()`:
#
- # * a ReplicaContext is only present during a `experimental_run_v2()`
+ # * a ReplicaContext is only present during a `run()`
# call (except during a `merge_run` call) and in such a scope it
# will be returned by calls to `get_replica_context()`. Implementers of new
# Strategy descendants will frequently also need to
@@ -2321,7 +2357,7 @@ class ReplicaContext(object):
You can use `tf.distribute.get_replica_context` to get an instance of
`ReplicaContext`. This should be inside your replicated step function, such
- as in a `tf.distribute.Strategy.experimental_run_v2` call.
+ as in a `tf.distribute.Strategy.run` call.
"""
def __init__(self, strategy, replica_id_in_sync_group):
@@ -2353,11 +2389,9 @@ class ReplicaContext(object):
"""Merge args across replicas and run `merge_fn` in a cross-replica context.
This allows communication and coordination when there are multiple calls
- to the step_fn triggered by a call to
- `strategy.experimental_run_v2(step_fn, ...)`.
+ to the step_fn triggered by a call to `strategy.run(step_fn, ...)`.
- See `tf.distribute.Strategy.experimental_run_v2` for an
- explanation.
+ See `tf.distribute.Strategy.run` for an explanation.
If not inside a distributed scope, this is equivalent to:


@@ -510,7 +510,7 @@ class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
return input_data
for _ in range(2):
- default_strategy.experimental_run_v2(train_step, args=(next_val,))
+ default_strategy.run(train_step, args=(next_val,))
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
def testDistributedDatasets(self):


@@ -99,8 +99,7 @@ def get_replica_context():
will return the default `ReplicaContext` object);
2. switches to cross-replica context (in which case this will return
`None`) when entering a `with tf.distribute.Strategy.scope():` block;
- 3. switches to a (non-default) replica context inside
- `strategy.experimental_run_v2(fn, ...)`;
+ 3. switches to a (non-default) replica context inside `strategy.run(fn, ...)`;
4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
inside `merge_fn` you are back in the cross-replica context (and again
this function will return `None`).
@@ -121,7 +120,7 @@ def get_replica_context():
tf.print("Replica id: ", replica_context.replica_id_in_sync_group,
" of ", replica_context.num_replicas_in_sync)
- strategy.experimental_run_v2(f)
+ strategy.run(f)
```
Returns:
@@ -166,7 +165,7 @@ def in_cross_replica_context():
def f():
assert not tf.distribute.in_cross_replica_context()
- strategy.experimental_run_v2(f)
+ strategy.run(f)
```
Returns:


@@ -585,7 +585,7 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
"""Sums the `PerReplica` values in the `per_replica_features` map."""
def map_fn(per_replica_values):
- per_replica_sums = distribution.experimental_run_v2(
+ per_replica_sums = distribution.run(
(lambda x: math_ops.reduce_sum(x.values)) if all(
map(sparse_tensor.is_sparse, per_replica_values.values)) else
math_ops.reduce_sum, (per_replica_values,))
@@ -1048,7 +1048,7 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
@def_function.function(input_signature=[type_spec])
def process_inputs(inputs):
- distribution.experimental_run_v2(lambda inputs: inputs, args=(inputs,))
+ distribution.run(lambda inputs: inputs, args=(inputs,))
for x in ds:
process_inputs(x)
@@ -1073,7 +1073,7 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
@def_function.function(input_signature=[dist_dataset.element_spec])
def process_inputs(inputs):
- distribution.experimental_run_v2(lambda inputs: inputs, args=(inputs,))
+ distribution.run(lambda inputs: inputs, args=(inputs,))
for x in dist_dataset:
process_inputs(x)


@@ -97,8 +97,7 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
updates = distribution.experimental_local_results(
- distribution.experimental_run_v2(
- metric, args=(iterator.get_next(),)))
+ distribution.run(metric, args=(iterator.get_next(),)))
batches_per_update = distribution.num_replicas_in_sync
self.evaluate(iterator.initializer)


@@ -543,7 +543,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
opt.minimize(lambda: constant_op.constant(1.), [])
opt.apply_gradients([])
- distribution.experimental_run_v2(run_fn)
+ distribution.run(run_fn)
if __name__ == "__main__":


@@ -52,7 +52,7 @@ class MirroredFunctionStrategyTest(test.TestCase):
one = constant_op.constant(1)
self.assertLen(f_traces, 0)
- result1 = self._strategy.experimental_run_v2(f, args=(one,))
+ result1 = self._strategy.run(f, args=(one,))
self.assertLen(f_traces, 1) # Function traced once, not for each replica.
# Returns a per-replica value.
self.assertIsInstance(result1, values.PerReplica)
@@ -60,7 +60,7 @@ class MirroredFunctionStrategyTest(test.TestCase):
self._strategy.experimental_local_results(result1))
# Try passing a per-replica value as an argument.
- result2 = self._strategy.experimental_run_v2(f, args=(result1,))
+ result2 = self._strategy.run(f, args=(result1,))
self.assertLen(f_traces, 1)
self.assertIsInstance(result2, values.PerReplica)
self.assertAllEqual([1, 3],
@@ -88,7 +88,7 @@ class MirroredFunctionStrategyTest(test.TestCase):
one = constant_op.constant(1)
self.assertLen(f_traces, 0)
self.assertLen(g_traces, 0)
- result = self._strategy.experimental_run_v2(f, args=(one,))
+ result = self._strategy.run(f, args=(one,))
# Functions traced once, not for each replica.
self.assertLen(f_traces, 1)
self.assertLen(g_traces, 1)


@ -403,8 +403,7 @@ class MirroredStrategy(distribute_lib.Strategy):
total_result = 0 total_result = 0
for x in dataset: for x in dataset:
per_replica_result = my_strategy.experimental_run_v2(replica_fn, per_replica_result = my_strategy.run(replica_fn, args=(x,))
args=(x,))
total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
per_replica_result, axis=None) per_replica_result, axis=None)
return total_result return total_result
@ -752,13 +751,13 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
return wrapped(args, kwargs) return wrapped(args, kwargs)
if context.executing_eagerly(): if context.executing_eagerly():
logging.log_first_n(logging.WARN, "Using %s eagerly has significant " logging.log_first_n(
"overhead currently. We will be working on improving " logging.WARN, "Using %s eagerly has significant "
"this in the future, but for now please wrap " "overhead currently. We will be working on improving "
"`call_for_each_replica` or `experimental_run` or " "this in the future, but for now please wrap "
"`experimental_run_v2` inside a tf.function to get " "`call_for_each_replica` or `experimental_run` or "
"the best performance." % "`run` inside a tf.function to get the best performance." %
self._container_strategy().__class__.__name__, 5) self._container_strategy().__class__.__name__, 5)
else: else:
# When a tf.function is wrapped to trigger _call_for_each_replica (see # When a tf.function is wrapped to trigger _call_for_each_replica (see
# the other branch above), AutoGraph stops conversion at # the other branch above), AutoGraph stops conversion at

View File

@ -1368,7 +1368,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
return t.gradient(loss, [w, b]) return t.gradient(loss, [w, b])
def step_fn(): def step_fn():
return distribution.experimental_run_v2(replica_fn) return distribution.run(replica_fn)
context.enable_run_metadata() context.enable_run_metadata()
g1, g2 = step_fn() g1, g2 = step_fn()
@ -1399,7 +1399,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def replica_fn(): def replica_fn():
return f() return f()
distribution.experimental_run_v2(replica_fn) distribution.run(replica_fn)
def _replica_id(): def _replica_id():

View File

@ -192,7 +192,7 @@ class ExponentialMovingAverageTest(test.TestCase, parameterized.TestCase):
ema.apply([w]) ema.apply([w])
return ema.average(w) return ema.average(w)
return distribution.experimental_run_v2(_ema_replica_fn_eager) return distribution.run(_ema_replica_fn_eager)
if use_function: if use_function:
fn = def_function.function(fn) fn = def_function.function(fn)
@ -238,7 +238,7 @@ class ExponentialMovingAverageTest(test.TestCase, parameterized.TestCase):
self.skipTest("b/139550827: Cannot do variable.assign in replica context " self.skipTest("b/139550827: Cannot do variable.assign in replica context "
"of TPUStrategy") "of TPUStrategy")
with distribution.scope(): with distribution.scope():
w_assign, w_apply, ema_w = distribution.experimental_run_v2( w_assign, w_apply, ema_w = distribution.run(
self._ema_replica_fn_graph) self._ema_replica_fn_graph)
self.assertEqual(ema_w.name, "w/ExponentialMovingAverage:0") self.assertEqual(ema_w.name, "w/ExponentialMovingAverage:0")
with self.cached_session(): with self.cached_session():

View File

@ -44,7 +44,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
Using this strategy will place any variables created in its scope on the Using this strategy will place any variables created in its scope on the
specified device. Input distributed through this strategy will be specified device. Input distributed through this strategy will be
prefetched to the specified device. Moreover, any functions called via prefetched to the specified device. Moreover, any functions called via
`strategy.experimental_run_v2` will also be placed on the specified device `strategy.run` will also be placed on the specified device
as well. as well.
Typical usage of this strategy could be testing your code with the Typical usage of this strategy could be testing your code with the
@ -64,7 +64,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
result = 0 result = 0
for i in range(10): for i in range(10):
result += strategy.experimental_run_v2(step_fn, args=(i,)) result += strategy.run(step_fn, args=(i,))
print(result) # 90 print(result) # 90
``` ```
""" """
@ -127,7 +127,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn) inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn)
for batch in inputs: for batch in inputs:
replica_results = strategy.experimental_run_v2(replica_fn, args=(batch,)) replica_results = strategy.run(replica_fn, args=(batch,))
``` ```
IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
@ -154,7 +154,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
value, so the result is just the value in a tuple. value, so the result is just the value in a tuple.
Args: Args:
value: A value returned by `experimental_run()`, `experimental_run_v2()`, value: A value returned by `experimental_run()`, `run()`,
`extended.call_for_each_replica()`, or a variable created in `scope`. `extended.call_for_each_replica()`, or a variable created in `scope`.
Returns: Returns:
@ -163,7 +163,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
""" """
return super(OneDeviceStrategy, self).experimental_local_results(value) return super(OneDeviceStrategy, self).experimental_local_results(value)
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None): # pylint: disable=useless-super-delegation def run(self, fn, args=(), kwargs=None, options=None): # pylint: disable=useless-super-delegation
"""Run `fn` on each replica, with the given arguments. """Run `fn` on each replica, with the given arguments.
In `OneDeviceStrategy`, `fn` is simply called within a device scope for the In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
@ -179,8 +179,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
Returns: Returns:
Return value from running `fn`. Return value from running `fn`.
""" """
return super(OneDeviceStrategy, return super(OneDeviceStrategy, self).run(fn, args, kwargs, options)
self).experimental_run_v2(fn, args, kwargs, options)
def reduce(self, reduce_op, value, axis): # pylint: disable=useless-super-delegation def reduce(self, reduce_op, value, axis): # pylint: disable=useless-super-delegation
"""Reduce `value` across replicas. """Reduce `value` across replicas.
@ -203,7 +202,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
Args: Args:
reduce_op: A `tf.distribute.ReduceOp` value specifying how values should reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
be combined. be combined.
value: A "per replica" value, e.g. returned by `experimental_run_v2` to value: A "per replica" value, e.g. returned by `run` to
be combined into a single tensor. be combined into a single tensor.
axis: Specifies the dimension to reduce along within each axis: Specifies the dimension to reduce along within each
replica's tensor. Should typically be set to the batch dimension, or replica's tensor. Should typically be set to the batch dimension, or
@ -309,7 +308,7 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
def _experimental_distribute_values_from_function(self, value_fn): def _experimental_distribute_values_from_function(self, value_fn):
# TODO(b/137795644): This should return a PerReplica value but other # TODO(b/137795644): This should return a PerReplica value but other
# methods like experimental_run_v2 in OneDeviceStrategy need to be modified # methods like run in OneDeviceStrategy need to be modified
# to do the same. # to do the same.
return value_fn(distribute_lib.ValueContext()) return value_fn(distribute_lib.ValueContext())

View File

@ -85,8 +85,7 @@ class SavedModelTFModuleTest(test_base.TestSavedModelBase):
dist_predict_dataset = distribution.experimental_distribute_dataset( dist_predict_dataset = distribution.experimental_distribute_dataset(
predict_dataset) predict_dataset)
per_replica_predict_data = next(iter(dist_predict_dataset)) per_replica_predict_data = next(iter(dist_predict_dataset))
result = distribution.experimental_run_v2( result = distribution.run(model, args=(per_replica_predict_data,))
model, args=(per_replica_predict_data,))
# Convert the per_replica value to a list, then concatenate them # Convert the per_replica value to a list, then concatenate them
reduced = distribution.experimental_local_results(result) reduced = distribution.experimental_local_results(result)
concat = array_ops.concat(reduced, 0) concat = array_ops.concat(reduced, 0)

View File

@ -112,7 +112,7 @@ def load_and_run_with_saved_model_api(distribution, saved_dir, predict_dataset,
dist_predict_dataset = distribution.experimental_distribute_dataset( dist_predict_dataset = distribution.experimental_distribute_dataset(
predict_dataset) predict_dataset)
per_replica_predict_data = next(iter(dist_predict_dataset)) per_replica_predict_data = next(iter(dist_predict_dataset))
result = distribution.experimental_run_v2( result = distribution.run(
func.signatures[_DEFAULT_FUNCTION_KEY], func.signatures[_DEFAULT_FUNCTION_KEY],
args=(per_replica_predict_data,)) args=(per_replica_predict_data,))
result = result[output_name] result = result[output_name]

View File

@ -56,8 +56,7 @@ class StrategyCombinationsTest(test.TestCase, parameterized.TestCase):
mode=["graph", "eager"])) mode=["graph", "eager"]))
def testMirrored2CPUs(self, distribution): def testMirrored2CPUs(self, distribution):
with distribution.scope(): with distribution.scope():
one_per_replica = distribution.experimental_run_v2( one_per_replica = distribution.run(lambda: constant_op.constant(1))
lambda: constant_op.constant(1))
num_replicas = distribution.reduce( num_replicas = distribution.reduce(
reduce_util.ReduceOp.SUM, one_per_replica, axis=None) reduce_util.ReduceOp.SUM, one_per_replica, axis=None)
self.assertEqual(2, self.evaluate(num_replicas)) self.assertEqual(2, self.evaluate(num_replicas))

View File

@ -453,16 +453,15 @@ class OneDeviceDistributionTestBase(test.TestCase):
"""Some tests that should work with any one-device DistributionStrategy.""" """Some tests that should work with any one-device DistributionStrategy."""
def _test_run(self, strategy): def _test_run(self, strategy):
out1 = strategy.experimental_run_v2(lambda: constant_op.constant(4.)) out1 = strategy.run(lambda: constant_op.constant(4.))
self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1))) self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1)))
out2 = strategy.experimental_run_v2( out2 = strategy.run(lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2)) out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
self.assertAllEqual([8.], out2_vals["a"]) self.assertAllEqual([8.], out2_vals["a"])
self.assertAllEqual([16.], out2_vals["b"]) self.assertAllEqual([16.], out2_vals["b"])
out3 = strategy.experimental_run_v2(lambda b, a: a + 2 * b + 2, kwargs=out2) out3 = strategy.run(lambda b, a: a + 2 * b + 2, kwargs=out2)
self.assertAllEqual([42.], self.evaluate(strategy.unwrap(out3))) self.assertAllEqual([42.], self.evaluate(strategy.unwrap(out3)))
def _test_all_reduce_sum(self, strategy): def _test_all_reduce_sum(self, strategy):
@ -575,17 +574,16 @@ class TwoDeviceDistributionTestBase(test.TestCase):
"""Some tests that should work with any two-device DistributionStrategy.""" """Some tests that should work with any two-device DistributionStrategy."""
def _test_run(self, strategy): def _test_run(self, strategy):
out1 = strategy.experimental_run_v2( out1 = strategy.run(
lambda: ds_context.get_replica_context().replica_id_in_sync_group + 1) lambda: ds_context.get_replica_context().replica_id_in_sync_group + 1)
self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1))) self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1)))
out2 = strategy.experimental_run_v2( out2 = strategy.run(lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2)) out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
self.assertAllEqual([2, 4], out2_vals["a"]) self.assertAllEqual([2, 4], out2_vals["a"])
self.assertAllEqual([1, 4], out2_vals["b"]) self.assertAllEqual([1, 4], out2_vals["b"])
out3 = strategy.experimental_run_v2(lambda b, a: a + 2 * b + 2, kwargs=out2) out3 = strategy.run(lambda b, a: a + 2 * b + 2, kwargs=out2)
self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3))) self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3)))
def _test_all_reduce_sum(self, strategy): def _test_all_reduce_sum(self, strategy):

View File

@ -85,28 +85,28 @@ def maybe_init_scope():
yield yield
def validate_experimental_run_function(fn): def validate_run_function(fn):
"""Validate the function passed into strategy.experimental_run_v2.""" """Validate the function passed into strategy.run."""
# We allow three types of functions/objects passed into TPUStrategy # We allow three types of functions/objects passed into TPUStrategy
# experimental_run_v2 in eager mode: # run in eager mode:
# 1. a user annotated tf.function # 1. a user annotated tf.function
# 2. a ConcreteFunction, this is mostly what you get from loading a saved # 2. a ConcreteFunction, this is mostly what you get from loading a saved
# model. # model.
# 3. a callable object and the `__call__` method itself is a tf.function. # 3. a callable object and the `__call__` method itself is a tf.function.
# #
# Otherwise we return an error, because we don't support eagerly running # Otherwise we return an error, because we don't support eagerly running
# experimental_run_v2 in TPUStrategy. # run in TPUStrategy.
if context.executing_eagerly() and not isinstance( if context.executing_eagerly() \
fn, def_function.Function) and not isinstance( and not isinstance(fn, def_function.Function) \
fn, function.ConcreteFunction) and not (callable(fn) and isinstance( and not isinstance(fn, function.ConcreteFunction) \
fn.__call__, def_function.Function)): and not (callable(fn) and isinstance(fn.__call__, def_function.Function)):
raise NotImplementedError( raise NotImplementedError(
"TPUStrategy.experimental_run_v2(fn, ...) does not support pure eager " "TPUStrategy.run(fn, ...) does not support pure eager "
"execution. please make sure the function passed into " "execution. please make sure the function passed into "
"`strategy.experimental_run_v2` is a `tf.function` or " "`strategy.run` is a `tf.function` or "
"`strategy.experimental_run_v2` is called inside a `tf.function` if " "`strategy.run` is called inside a `tf.function` if "
"eager behavior is enabled.") "eager behavior is enabled.")
@ -135,10 +135,10 @@ class TPUStrategy(distribute_lib.Strategy):
To run TF2 programs on TPUs, you can either use `.compile` and To run TF2 programs on TPUs, you can either use `.compile` and
`.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
training loop by calling `strategy.experimental_run_v2` directly. Note that training loop by calling `strategy.run` directly. Note that
TPUStrategy doesn't support pure eager execution, so please make sure the TPUStrategy doesn't support pure eager execution, so please make sure the
function passed into `strategy.experimental_run_v2` is a `tf.function` or function passed into `strategy.run` is a `tf.function` or
`strategy.experimental_run_v2` is called inside a `tf.function` if eager `strategy.run` is called inside a `tf.function` if eager
behavior is enabled. behavior is enabled.
Args: Args:
@ -159,9 +159,9 @@ class TPUStrategy(distribute_lib.Strategy):
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
# can use the default implementation. # can use the default implementation.
# This implementation runs a single step. It does not use infeed or outfeed. # This implementation runs a single step. It does not use infeed or outfeed.
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None): def run(self, fn, args=(), kwargs=None, options=None):
"""See base class.""" """See base class."""
validate_experimental_run_function(fn) validate_run_function(fn)
# Note: the target function is converted to graph even when in Eager mode, # Note: the target function is converted to graph even when in Eager mode,
# so autograph is on by default here. # so autograph is on by default here.
@ -208,7 +208,7 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
# can use the default implementation. # can use the default implementation.
# This implementation runs a single step. It does not use infeed or outfeed. # This implementation runs a single step. It does not use infeed or outfeed.
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None): def run(self, fn, args=(), kwargs=None, options=None):
"""Run `fn` on each replica, with the given arguments. """Run `fn` on each replica, with the given arguments.
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
@ -223,7 +223,7 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
per-replica objects containing tensors or composite tensors. per-replica objects containing tensors or composite tensors.
Users can pass strategy specific options to `options` argument. An example Users can pass strategy specific options to `options` argument. An example
to enable bucketizing dynamic shapes in `TPUStrategy.experimental_run_v2` to enable bucketizing dynamic shapes in `TPUStrategy.run`
is: is:
```python ```python
@ -242,7 +242,7 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
output = tf.reduce_sum(inputs) output = tf.reduce_sum(inputs)
return output return output
strategy.experimental_run_v2(step_fn, args=(next(iterator),), strategy.run(step_fn, args=(next(iterator),),
options=options) options=options)
``` ```
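For context, a sketch of how the `options` object referenced above is typically constructed and threaded through `run`. The `experimental_bucketizing_dynamic_shape` field is an assumption about which `tf.distribute.RunOptions` flag the docstring means, and `MirroredStrategy` merely stands in for a real TPU setup:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # stand-in; the option matters on TPU
options = tf.distribute.RunOptions(experimental_bucketizing_dynamic_shape=True)

dataset = tf.data.Dataset.from_tensor_slices(tf.range(8.0)).batch(4)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def step_fn(inputs):
  return tf.reduce_sum(inputs)

per_replica = strategy.run(step_fn, args=(next(iterator),), options=options)
```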
@ -259,7 +259,7 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
structure can either be "per-replica" `Tensor` objects or `Tensor`s structure can either be "per-replica" `Tensor` objects or `Tensor`s
(for example, if running on a single replica). (for example, if running on a single replica).
""" """
validate_experimental_run_function(fn) validate_run_function(fn)
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
options = options or distribute_lib.RunOptions() options = options or distribute_lib.RunOptions()

View File

@ -90,8 +90,8 @@ class TPUStrategyTest(test.TestCase):
@def_function.function @def_function.function
def train_step(): def train_step():
outputs = strategy.experimental_local_results( outputs = strategy.experimental_local_results(
strategy.experimental_run_v2(computation, args=([2., 2.],))) strategy.run(computation, args=([2., 2.],)))
outputs2 = strategy2.experimental_run_v2( outputs2 = strategy2.run(
computation, args=([outputs[0]],)) computation, args=([outputs[0]],))
return outputs2 return outputs2
@ -181,9 +181,9 @@ class TPUStrategyTest(test.TestCase):
def step_fn(): def step_fn():
return v + 1.0 return v + 1.0
all_core_strategy.experimental_run_v2(step_fn) all_core_strategy.run(step_fn)
r1 = first_core_strategy.experimental_run_v2(step_fn) r1 = first_core_strategy.run(step_fn)
r2 = second_core_strategy.experimental_run_v2(step_fn) r2 = second_core_strategy.run(step_fn)
return r1 + r2 return r1 + r2
train_step() train_step()

View File

@ -61,13 +61,13 @@ class DistributedValues(object):
A subclass instance of DistributedValues is created when creating variables A subclass instance of DistributedValues is created when creating variables
within a distribution strategy, iterating a `tf.Dataset` or through within a distribution strategy, iterating a `tf.Dataset` or through
`strategy.experimental_run_v2`. This base class should never be instantiated `strategy.run`. This base class should never be instantiated
directly. DistributedValues contains a value per replica. Depending on directly. DistributedValues contains a value per replica. Depending on
the subclass, the values could either be synced on update, synced on demand, the subclass, the values could either be synced on update, synced on demand,
or never synced. or never synced.
DistributedValues can be reduced to obtain single value across replicas, DistributedValues can be reduced to obtain single value across replicas,
as input into `experimental_run_v2` or the per replica values inspected as input into `run` or the per replica values inspected
using `experimental_local_results`. using `experimental_local_results`.
Example usage: Example usage:
@ -79,16 +79,16 @@ class DistributedValues(object):
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset)) >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator) >>> distributed_values = next(dataset_iterator)
2. Returned by `experimental_run_v2`: 2. Returned by `run`:
>>> strategy = tf.distribute.MirroredStrategy() >>> strategy = tf.distribute.MirroredStrategy()
>>> @tf.function >>> @tf.function
... def run(): ... def run():
... ctx = tf.distribute.get_replica_context() ... ctx = tf.distribute.get_replica_context()
... return ctx.replica_id_in_sync_group ... return ctx.replica_id_in_sync_group
>>> distributed_values = strategy.experimental_run_v2(run) >>> distributed_values = strategy.run(run)
3. As input into `experimental_run_v2`: 3. As input into `run`:
>>> strategy = tf.distribute.MirroredStrategy() >>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2) >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset)) >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
@ -96,7 +96,7 @@ class DistributedValues(object):
>>> @tf.function >>> @tf.function
... def run(input): ... def run(input):
... return input + 1.0 ... return input + 1.0
>>> updated_value = strategy.experimental_run_v2(run, >>> updated_value = strategy.run(run,
... args=(distributed_values,)) ... args=(distributed_values,))
4. Reduce value 4. Reduce value

View File

@ -215,8 +215,8 @@ class DistributedValuesTest(test.TestCase, parameterized.TestCase):
return math_ops.square(x) return math_ops.square(x)
outputs = distribution.experimental_local_results( outputs = distribution.experimental_local_results(
distribution.experimental_run_v2(computation, distribution.run(computation,
args=(distributed_values,))) args=(distributed_values,)))
return outputs return outputs
local_results = run() local_results = run()
@ -740,7 +740,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
results = self.evaluate( results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(f))) distribution.run(f)))
for value in results: for value in results:
self.assertEqual(2., value) self.assertEqual(2., value)
@ -798,7 +798,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
"Cannot update non-float variables"): "Cannot update non-float variables"):
self.evaluate( self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(assign))) distribution.run(assign)))
# allow assign() with same value in replica context. # allow assign() with same value in replica context.
@def_function.function @def_function.function
@ -807,7 +807,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
self.evaluate( self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(assign_same))) distribution.run(assign_same)))
self.assertEqual(self.evaluate(v.read_value()), 2) self.assertEqual(self.evaluate(v.read_value()), 2)
# allow assign() with mirrored variable in replica context. # allow assign() with mirrored variable in replica context.
@ -824,7 +824,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
self.evaluate( self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(assign_mirrored))) distribution.run(assign_mirrored)))
self.assertEqual(self.evaluate(v.read_value()), 3) self.assertEqual(self.evaluate(v.read_value()), 3)
# allow assign() in cross replica context. # allow assign() in cross replica context.
@ -912,7 +912,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
def f(): def f():
if v[0] is None: if v[0] is None:
v[0] = variables_lib.Variable(random_ops.random_normal([])) v[0] = variables_lib.Variable(random_ops.random_normal([]))
distribution.experimental_run_v2(f)
distribution.run(f)
context.set_global_seed(None) context.set_global_seed(None)
step() step()
@ -953,7 +954,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
@def_function.function @def_function.function
def foo(): def foo():
distribution.experimental_run_v2(replica_fn) distribution.run(replica_fn)
foo() foo()
@ -980,7 +981,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
replica_id = ctx.replica_id_in_sync_group replica_id = ctx.replica_id_in_sync_group
return v.assign(math_ops.cast(replica_id, dtypes.float32)) return v.assign(math_ops.cast(replica_id, dtypes.float32))
per_replica_results = self.evaluate(distribution.experimental_local_results( per_replica_results = self.evaluate(distribution.experimental_local_results(
distribution.experimental_run_v2(assign))) distribution.run(assign)))
# The per-replica values should always match the first replicas value. # The per-replica values should always match the first replicas value.
self.assertAllEqual( self.assertAllEqual(
array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32), array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
@ -1006,7 +1007,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
per_replica_results = self.evaluate( per_replica_results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(assign))) distribution.run(assign)))
# The per-replica values should always match the first replicas value. # The per-replica values should always match the first replicas value.
self.assertAllEqual([3, 3], per_replica_results) self.assertAllEqual([3, 3], per_replica_results)
@ -1037,7 +1038,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
per_replica_results = self.evaluate( per_replica_results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(scatter_sub))) distribution.run(scatter_sub)))
self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results) self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results)
@combinations.generate( @combinations.generate(
@ -1064,7 +1065,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
per_replica_results = self.evaluate( per_replica_results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(scatter_add))) distribution.run(scatter_add)))
self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results) self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results)
@combinations.generate( @combinations.generate(
@ -1091,7 +1092,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
per_replica_results = self.evaluate( per_replica_results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(scatter_div))) distribution.run(scatter_div)))
self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results) self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results)
@combinations.generate( @combinations.generate(
@ -1119,7 +1120,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
per_replica_results = self.evaluate( per_replica_results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(scatter_mul))) distribution.run(scatter_mul)))
self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results) self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results)
@combinations.generate( @combinations.generate(
@ -1148,11 +1149,11 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"): with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"):
self.evaluate( self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(scatter_min, args=(v1,)))) distribution.run(scatter_min, args=(v1,))))
per_replica_results = self.evaluate( per_replica_results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(scatter_min, args=(v2,)))) distribution.run(scatter_min, args=(v2,))))
self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results) self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results)
@combinations.generate( @combinations.generate(
@ -1181,11 +1182,11 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"): with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"):
self.evaluate( self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(scatter_max, args=(v1,)))) distribution.run(scatter_max, args=(v1,))))
per_replica_results = self.evaluate( per_replica_results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(scatter_max, args=(v2,)))) distribution.run(scatter_max, args=(v2,))))
self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results) self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results)
@combinations.generate( @combinations.generate(
@ -1214,11 +1215,11 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"): with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"):
self.evaluate( self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(scatter_update, args=(v1,)))) distribution.run(scatter_update, args=(v1,))))
per_replica_results = self.evaluate( per_replica_results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(scatter_update, args=(v2,)))) distribution.run(scatter_update, args=(v2,))))
self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results) self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results)
@combinations.generate( @combinations.generate(
@ -1314,7 +1315,7 @@ def mirrored_and_tpu_strategy_combinations():
# tests. # tests.
def strategy_and_run_tf_function_combinations(): def strategy_and_run_tf_function_combinations():
# Test the combination of different strategies and whether a tf.function # Test the combination of different strategies and whether a tf.function
# is passed into strategy.experimental_run_v2.""" # is passed into strategy.run."""
return combinations.combine( return combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -1538,7 +1539,8 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
if experimental_run_tf_function: if experimental_run_tf_function:
update_fn = def_function.function(update_fn) update_fn = def_function.function(update_fn)
return distribution.experimental_local_results( return distribution.experimental_local_results(
distribution.experimental_run_v2(update_fn)) distribution.run(update_fn))
updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)] updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
aggregations = [ aggregations = [
variables_lib.VariableAggregation.NONE, variables_lib.VariableAggregation.NONE,
@ -1574,7 +1576,8 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
if experimental_run_tf_function: if experimental_run_tf_function:
update_fn = def_function.function(update_fn) update_fn = def_function.function(update_fn)
return distribution.experimental_local_results( return distribution.experimental_local_results(
distribution.experimental_run_v2(update_fn)) distribution.run(update_fn))
updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)] updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
aggregations = [ aggregations = [
variables_lib.VariableAggregation.NONE, variables_lib.VariableAggregation.NONE,
@ -1648,7 +1651,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
read_var_fn = v.read_value read_var_fn = v.read_value
results = self.evaluate( results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(
distribution.experimental_run_v2(read_var_fn))) distribution.run(read_var_fn)))
for component, value in zip(v._values, results): for component, value in zip(v._values, results):
self.assertAllEqual(self.evaluate(component.read_value()), value) self.assertAllEqual(self.evaluate(component.read_value()), value)
@ -1679,8 +1682,8 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
if experimental_run_tf_function: if experimental_run_tf_function:
assign = def_function.function(assign) assign = def_function.function(assign)
self.evaluate(distribution.experimental_local_results( self.evaluate(
distribution.experimental_run_v2(assign))) distribution.experimental_local_results(distribution.run(assign)))
num_replicas = distribution.num_replicas_in_sync num_replicas = distribution.num_replicas_in_sync
sum_of_replica_values = num_replicas * (num_replicas - 1) / 2. sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
if aggregation == variables_lib.VariableAggregation.SUM: if aggregation == variables_lib.VariableAggregation.SUM:
@ -1717,8 +1720,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
all_reduce = def_function.function(all_reduce) all_reduce = def_function.function(all_reduce)
per_replica_results = self.evaluate( per_replica_results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(distribution.run(all_reduce)))
distribution.experimental_run_v2(all_reduce)))
expected_result = [] expected_result = []
for i in range(distribution.num_replicas_in_sync): for i in range(distribution.num_replicas_in_sync):
expected_result.append(2.0 * distribution.num_replicas_in_sync + expected_result.append(2.0 * distribution.num_replicas_in_sync +
@ -1750,8 +1752,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
assign = def_function.function(assign) assign = def_function.function(assign)
per_replica_results = self.evaluate( per_replica_results = self.evaluate(
distribution.experimental_local_results( distribution.experimental_local_results(distribution.run(assign)))
distribution.experimental_run_v2(assign)))
expected_result = [] expected_result = []
for i in range(distribution.num_replicas_in_sync): for i in range(distribution.num_replicas_in_sync):
expected_result.append(1.0 * i) expected_result.append(1.0 * i)
@ -1781,7 +1782,8 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
v[0] = variables_lib.Variable( v[0] = variables_lib.Variable(
random_ops.random_normal([]), random_ops.random_normal([]),
synchronization=variables_lib.VariableSynchronization.ON_READ) synchronization=variables_lib.VariableSynchronization.ON_READ)
distribution.experimental_run_v2(f)
distribution.run(f)
context.set_global_seed(None) context.set_global_seed(None)
step() step()

View File

@ -134,8 +134,7 @@ class NormalizationTest(test.TestCase, parameterized.TestCase):
optimizer.apply_gradients(zip(grads, bn.variables)) optimizer.apply_gradients(zip(grads, bn.variables))
return loss return loss
return distribution.experimental_run_v2( return distribution.run(step_fn, args=(inputs, targets))
step_fn, args=(inputs, targets))
for _ in range(100): for _ in range(100):
np_output = train_step().numpy() np_output = train_step().numpy()
@ -153,8 +152,7 @@ class NormalizationTest(test.TestCase, parameterized.TestCase):
outputs = bn.apply(inputs, training=False) outputs = bn.apply(inputs, training=False)
return outputs return outputs
return distribution.experimental_run_v2( return distribution.run(step_fn, args=(inputs,))
step_fn, args=(inputs,))
# Test inference. # Test inference.
self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32), self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32),

View File

@ -958,8 +958,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
strategy = mirrored_strategy.MirroredStrategy() strategy = mirrored_strategy.MirroredStrategy()
with strategy.scope(): with strategy.scope():
v = variables.Variable([1., 2., 3.]) v = variables.Variable([1., 2., 3.])
strategy.experimental_run_v2( strategy.run(_replicated, args=(constant_op.constant([.1, -.2, .3]),))
_replicated, args=(constant_op.constant([.1, -.2, .3]),))
# TODO(b/141025187): Add a no_new_pyobjects decorator. # TODO(b/141025187): Add a no_new_pyobjects decorator.
def testArgumentUnused(self): def testArgumentUnused(self):

View File

@ -533,7 +533,8 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
return grad_v1, grad_v2 return grad_v1, grad_v2
if context.executing_eagerly(): if context.executing_eagerly():
run_fn = def_function.function(run_fn) run_fn = def_function.function(run_fn)
grad_v1, grad_v2 = distribution.experimental_run_v2(run_fn)
grad_v1, grad_v2 = distribution.run(run_fn)
self.assertIsNotNone(grad_v1) self.assertIsNotNone(grad_v1)
self.assertIsNotNone(grad_v2) self.assertIsNotNone(grad_v2)
@ -2057,8 +2058,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
optimizer.apply_gradients(zip(grads, model.trainable_variables)) optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss return loss
per_replica_losses = distribution.experimental_run_v2( per_replica_losses = distribution.run(step_fn, args=(dist_inputs,))
step_fn, args=(dist_inputs,))
return distribution.reduce( return distribution.reduce(
reduce_util.ReduceOp.SUM, per_replica_losses, axis=None) reduce_util.ReduceOp.SUM, per_replica_losses, axis=None)

View File

@ -863,8 +863,7 @@ def _make_execution_function_without_cloning(model, mode):
# PerReplicas as arguments. On every replica inside this call, each # PerReplicas as arguments. On every replica inside this call, each
# PerReplica object will return the value for that replica. The outputs # PerReplica object will return the value for that replica. The outputs
# are PerReplicas too. # are PerReplicas too.
outputs = strategy.experimental_run_v2( outputs = strategy.run(per_replica_function, args=(x, y, sample_weights))
per_replica_function, args=(x, y, sample_weights))
# Out of PerReplica outputs reduce or pick values to return. # Out of PerReplica outputs reduce or pick values to return.
all_outputs = unwrap_outputs( all_outputs = unwrap_outputs(
strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT)) strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))

View File

@ -792,7 +792,7 @@ class GeneratorDataAdapter(DataAdapter):
# Need to build the Model on concrete input shapes. # Need to build the Model on concrete input shapes.
if model is not None and not model.built: if model is not None and not model.built:
concrete_x, _, _ = unpack_x_y_sample_weight(peek) concrete_x, _, _ = unpack_x_y_sample_weight(peek)
model.distribute_strategy.experimental_run_v2( model.distribute_strategy.run(
lambda x: model(x, training=False), args=(concrete_x,)) lambda x: model(x, training=False), args=(concrete_x,))
self._first_batch_size = int(nest.flatten(peek)[0].shape[0]) self._first_batch_size = int(nest.flatten(peek)[0].shape[0])

View File

@ -500,7 +500,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
def train_function(iterator): def train_function(iterator):
data = next(iterator) data = next(iterator)
outputs = self.distribute_strategy.experimental_run_v2( outputs = self.distribute_strategy.run(
self.train_step, args=(data,)) self.train_step, args=(data,))
outputs = reduce_per_replica( outputs = reduce_per_replica(
outputs, self.distribute_strategy, reduction='first') outputs, self.distribute_strategy, reduction='first')
@ -873,7 +873,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
def test_function(iterator): def test_function(iterator):
data = next(iterator) data = next(iterator)
outputs = self.distribute_strategy.experimental_run_v2( outputs = self.distribute_strategy.run(
self.test_step, args=(data,)) self.test_step, args=(data,))
outputs = reduce_per_replica( outputs = reduce_per_replica(
outputs, self.distribute_strategy, reduction='first') outputs, self.distribute_strategy, reduction='first')
@ -1079,7 +1079,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
def predict_function(iterator): def predict_function(iterator):
data = next(iterator) data = next(iterator)
outputs = self.distribute_strategy.experimental_run_v2( outputs = self.distribute_strategy.run(
self.predict_step, args=(data,)) self.predict_step, args=(data,))
outputs = reduce_per_replica( outputs = reduce_per_replica(
outputs, self.distribute_strategy, reduction='concat') outputs, self.distribute_strategy, reduction='concat')

View File

@ -338,7 +338,7 @@ def experimental_tpu_test_loop(model,
return [array_ops.identity(out) for out in outputs] return [array_ops.identity(out) for out in outputs]
test_input_data = iterator.get_next() test_input_data = iterator.get_next()
per_replica_outputs = current_strategy.experimental_run_v2( per_replica_outputs = current_strategy.run(
_test_step_fn, args=(test_input_data,)) _test_step_fn, args=(test_input_data,))
output_tensors = {} output_tensors = {}
for label, output in zip(out_labels, per_replica_outputs): for label, output in zip(out_labels, per_replica_outputs):
@ -488,7 +488,7 @@ def experimental_tpu_predict_loop(model,
# use numpy arrays directly to avoid cumulating unnecessary input pipeline # use numpy arrays directly to avoid cumulating unnecessary input pipeline
# ops. # ops.
predict_input_data = iterator.get_next() predict_input_data = iterator.get_next()
per_replica_outputs = current_strategy.experimental_run_v2( per_replica_outputs = current_strategy.run(
_predict_step_fn, args=(predict_input_data,)) _predict_step_fn, args=(predict_input_data,))
output_tensors = dist_utils.flatten_per_replica_values( output_tensors = dist_utils.flatten_per_replica_values(
current_strategy, per_replica_outputs) current_strategy, per_replica_outputs)

View File

@ -46,7 +46,7 @@ class TemplateMirroredStrategyTest(test.TestCase):
strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"]) strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])
out = strategy.experimental_local_results( out = strategy.experimental_local_results(
strategy.experimental_run_v2(temp)) strategy.run(temp))
self.evaluate(variables.global_variables_initializer()) self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([42., 42.], self.evaluate(out)) self.assertAllEqual([42., 42.], self.evaluate(out))

View File

@ -57,7 +57,7 @@ class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase):
# With strategy - num replicas = 2 # With strategy - num replicas = 2
with distribution.scope(): with distribution.scope():
per_replica_losses = distribution.experimental_run_v2( per_replica_losses = distribution.run(
nn_impl.compute_average_loss, args=(per_example_loss,)) nn_impl.compute_average_loss, args=(per_example_loss,))
loss = distribution.reduce("SUM", per_replica_losses, axis=None) loss = distribution.reduce("SUM", per_replica_losses, axis=None)
self.assertAllClose(self.evaluate(loss), (2.5 + 6.2 + 5.) / 3) self.assertAllClose(self.evaluate(loss), (2.5 + 6.2 + 5.) / 3)
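The expected value in this assertion follows from `compute_average_loss` dividing by the global batch size (3 examples times the number of replicas), so summing the per-replica results yields `(2.5 + 6.2 + 5.) / 3` for any replica count. A small standalone sketch of the same arithmetic, with an assumed strategy rather than the test fixture:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
per_example_loss = tf.constant([2.5, 6.2, 5.])

# Every replica sees the same three losses; each returns their sum divided by
# the global batch size (3 * num_replicas), so the SUM reduction is 13.7 / 3.
per_replica = strategy.run(tf.nn.compute_average_loss, args=(per_example_loss,))
loss = strategy.reduce("SUM", per_replica, axis=None)
print(float(loss))  # ~4.5667
```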
@ -71,7 +71,7 @@ class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase):
def testComputeAverageLossSampleWeights(self, distribution): def testComputeAverageLossSampleWeights(self, distribution):
with distribution.scope(): with distribution.scope():
# Scalar sample weight # Scalar sample weight
per_replica_losses = distribution.experimental_run_v2( per_replica_losses = distribution.run(
nn_impl.compute_average_loss, nn_impl.compute_average_loss,
args=([2., 4., 6.],), args=([2., 4., 6.],),
kwargs={"sample_weight": 2}) kwargs={"sample_weight": 2})
@ -79,7 +79,7 @@ class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase):
self.assertAllClose(self.evaluate(loss), (2. + 4. + 6.) * 2. / 3) self.assertAllClose(self.evaluate(loss), (2. + 4. + 6.) * 2. / 3)
# Per example sample weight # Per example sample weight
per_replica_losses = distribution.experimental_run_v2( per_replica_losses = distribution.run(
nn_impl.compute_average_loss, nn_impl.compute_average_loss,
args=([2., 4., 6.],), args=([2., 4., 6.],),
kwargs={"sample_weight": [0.3, 0.5, 0.2]}) kwargs={"sample_weight": [0.3, 0.5, 0.2]})
@ -88,7 +88,7 @@ class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase):
self.evaluate(loss), (2. * 0.3 + 4. * 0.5 + 6. * 0.2) / 3) self.evaluate(loss), (2. * 0.3 + 4. * 0.5 + 6. * 0.2) / 3)
# Time-step sample weight # Time-step sample weight
per_replica_losses = distribution.experimental_run_v2( per_replica_losses = distribution.run(
nn_impl.compute_average_loss, nn_impl.compute_average_loss,
args=([[2., 0.5], [4., 1.]],), args=([[2., 0.5], [4., 1.]],),
kwargs={"sample_weight": [[0.3, 0.7], [0.2, 0.8]]}) kwargs={"sample_weight": [[0.3, 0.7], [0.2, 0.8]]})
@ -114,7 +114,7 @@ class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase):
with distribution.scope(): with distribution.scope():
per_example_loss = constant_op.constant([2., 4., 6.], per_example_loss = constant_op.constant([2., 4., 6.],
dtype=dtypes.float64) dtype=dtypes.float64)
per_replica_losses = distribution.experimental_run_v2( per_replica_losses = distribution.run(
nn_impl.compute_average_loss, nn_impl.compute_average_loss,
args=(per_example_loss,), args=(per_example_loss,),
kwargs={"sample_weight": 2}) kwargs={"sample_weight": 2})
@ -169,7 +169,7 @@ class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase):
# With strategy - num replicas = 2 # With strategy - num replicas = 2
with distribution.scope(): with distribution.scope():
per_replica_losses = distribution.experimental_run_v2( per_replica_losses = distribution.run(
nn_impl.scale_regularization_loss, args=(reg_losses,)) nn_impl.scale_regularization_loss, args=(reg_losses,))
loss = distribution.reduce("SUM", per_replica_losses, axis=None) loss = distribution.reduce("SUM", per_replica_losses, axis=None)
self.assertAllClose(self.evaluate(loss), (2.5 + 6.2 + 5.)) self.assertAllClose(self.evaluate(loss), (2.5 + 6.2 + 5.))

View File

@ -687,7 +687,7 @@ def outside_compilation(computation, *args, **kwargs):
`tf.tpu.outside_compilation()` should be called inside a function that is `tf.tpu.outside_compilation()` should be called inside a function that is
passed to `tpu.split_compile_and_replicate()` -- this is implied when passed to `tpu.split_compile_and_replicate()` -- this is implied when
outside compilation is invoked inside a function passed to TPUStrategy outside compilation is invoked inside a function passed to TPUStrategy
`experimental_run_v2()`. If invoked outside of TPUReplicateContext, `run()`. If invoked outside of TPUReplicateContext,
then this simply returns the result of `computation`, and therefore, then this simply returns the result of `computation`, and therefore,
would be a no-op. Note that outside compilation is different from would be a no-op. Note that outside compilation is different from
`tf.distribute.experimental.TPUStrategy.merge_call()` as logic in `tf.distribute.experimental.TPUStrategy.merge_call()` as logic in
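A hedged sketch of the placement rule described in this docstring: `outside_compilation` is called from inside the function handed to `strategy.run`, so the host-side print escapes the TPU computation while the squaring stays on it. Here `strategy` is assumed to be a `TPUStrategy` created elsewhere; outside a TPU replicate context the call simply runs the host function directly:

```python
import tensorflow as tf

def host_log(x):
  # Runs on the host, not on the TPU.
  tf.print("per-replica mean:", tf.reduce_mean(x))

@tf.function
def train_step(inputs):
  def step_fn(x):
    y = tf.math.square(x)
    tf.tpu.outside_compilation(host_log, y)
    return y
  return strategy.run(step_fn, args=(inputs,))
```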

View File

@ -302,10 +302,11 @@ def _compute_gradients_until_finite(
return grads return grads
# Switch to a replica-context to compute gradients once per replica. # Switch to a replica-context to compute gradients once per replica.
grads = distribution.experimental_run_v2( grads = distribution.run(
replica_fn, args=(loss_scale_gradient_tapes, target, flattened_sources, replica_fn,
output_gradients, initial_grads)) args=(loss_scale_gradient_tapes, target, flattened_sources,
# Check for non-finite gradients possibly resulting from scaling. output_gradients, initial_grads))
# Check for non-finite gradients possibly resulting from scaling
_, ready_to_update = loss_scale.update(grads) _, ready_to_update = loss_scale.update(grads)
is_first_iteration = False is_first_iteration = False
return grads, ready_to_update, is_first_iteration return grads, ready_to_update, is_first_iteration

View File

@ -54,7 +54,7 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
def _run_with_strategy(self, run_fn, strategy, use_tf_function=False): def _run_with_strategy(self, run_fn, strategy, use_tf_function=False):
"""Runs `run_fn` under the DistributionStrategy `strategy`. """Runs `run_fn` under the DistributionStrategy `strategy`.
Runs `run_fn` with `strategy.experimental_run_v2`. Returns a list of the Runs `run_fn` with `strategy.run`. Returns a list of the
return values of `run_fn`, one per replica. return values of `run_fn`, one per replica.
Args: Args:
@ -67,7 +67,7 @@ class LossScaleGradientTapeTest(test.TestCase, parameterized.TestCase):
replica. If a nested structure is returned from `run_fn`, returns a replica. If a nested structure is returned from `run_fn`, returns a
nested structure, where each element is a list of tensors. nested structure, where each element is a list of tensors.
""" """
strategy_fn = lambda: strategy.experimental_run_v2(run_fn) strategy_fn = lambda: strategy.run(run_fn)
if use_tf_function: if use_tf_function:
strategy_fn = def_function.function(strategy_fn) strategy_fn = def_function.function(strategy_fn)

View File

@ -64,6 +64,10 @@ tf_class {
name: "reduce" name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method { member_method {
name: "scope" name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -64,6 +64,10 @@ tf_class {
name: "reduce" name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method { member_method {
name: "scope" name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -63,6 +63,10 @@ tf_class {
name: "reduce" name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method { member_method {
name: "scope" name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -64,6 +64,10 @@ tf_class {
name: "reduce" name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method { member_method {
name: "scope" name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -64,6 +64,10 @@ tf_class {
name: "reduce" name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method { member_method {
name: "scope" name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -64,6 +64,10 @@ tf_class {
name: "reduce" name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method { member_method {
name: "scope" name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -68,6 +68,10 @@ tf_class {
name: "reduce" name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method { member_method {
name: "scope" name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -80,6 +80,10 @@ tf_class {
name: "reduce" name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
} }
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method { member_method {
name: "scope" name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -80,6 +80,10 @@ tf_class {
name: "reduce" name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
} }
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method { member_method {
name: "scope" name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@@ -79,6 +79,10 @@ tf_class {
    name: "reduce"
    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "run"
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
+  }
  member_method {
    name: "scope"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@@ -80,6 +80,10 @@ tf_class {
    name: "reduce"
    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "run"
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
+  }
  member_method {
    name: "scope"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@@ -80,6 +80,10 @@ tf_class {
    name: "reduce"
    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "run"
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
+  }
  member_method {
    name: "scope"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@@ -80,6 +80,10 @@ tf_class {
    name: "reduce"
    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "run"
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
+  }
  member_method {
    name: "scope"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@@ -80,6 +80,10 @@ tf_class {
    name: "reduce"
    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "run"
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
+  }
  member_method {
    name: "scope"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@@ -829,7 +829,7 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
"custom training loop, note the following changes in methods: "
"make_dataset_iterator->experimental_distribute_dataset, "
"experimental_make_numpy_iterator->experimental_make_numpy_dataset, "
-"extended.call_for_each_replica->experimental_run_v2, "
+"extended.call_for_each_replica->run, "
"reduce requires an axis argument, "
"unwrap->experimental_local_results "
"experimental_initialize and experimental_finalize no longer needed ")