tf.distribute.Strategy.reduce: Remove the default value of the `axis` argument for Strategy V2, but keep it for Strategy V1.

Change all callers to specify the argument explicitly, so their code continues to work with Strategy V2. For now, most callers use axis=None; if that remains the most common case, we can consider adding the default back in the future.

PiperOrigin-RevId: 243462195
commit 471993acd6
parent 06d2ea93fa
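A minimal, hedged sketch (not part of the diff) of the caller migration described above. The strategy and `model_fn` are illustrative, and API names are as of this revision:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def model_fn():
  # Each replica returns a scalar, so there is no batch dimension to reduce.
  return tf.constant(1.0)

per_replica = strategy.experimental_run_v2(model_fn)

# Old call (relied on the default):
#   reduced = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica)
# New call for Strategy V2 (axis passed explicitly, preserving the old behavior):
reduced = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)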
@@ -293,7 +293,7 @@ class CollectiveAllReduceStrategyTestBase(
         return array_ops.identity(x)
 
       x = distribution.extended.call_for_each_replica(model_fn)
-      reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x)
+      reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x, axis=None)
       x = distribution.experimental_local_results(x)[0]
 
       sess.run(variables.global_variables_initializer())
@@ -438,7 +438,7 @@ class Strategy(object):
     with self.scope():
       return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
 
-  def reduce(self, reduce_op, value, axis=None):
+  def reduce(self, reduce_op, value, axis):
     """Reduce `value` across replicas.
 
     Given a per-replica value returned by `experimental_run_v2`, say a
@@ -468,8 +468,10 @@ class Strategy(object):
         be combined.
       value: A "per replica" value, e.g. returned by `experimental_run_v2` to
         be combined into a single tensor.
-      axis: Optional. Specifies the dimension to reduce along within each
-        replica's tensor. Should typically be set to the batch dimension.
+      axis: Specifies the dimension to reduce along within each
+        replica's tensor. Should typically be set to the batch dimension, or
+        `None` to only reduce across replicas (e.g. if the tensor has no batch
+        dimension).
 
     Returns:
       A `Tensor`.
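A hedged illustration of the `axis` semantics documented above (assuming a strategy with one or more local replicas; the values are made up and API names are as of this revision):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step():
  # Per-example loss of shape [2] on every replica.
  return tf.constant([1.0, 2.0])

per_replica_losses = strategy.experimental_run_v2(step)

# axis=None: reduce across replicas only; the result keeps shape [2].
per_example_mean = strategy.reduce(
    tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)

# axis=0: additionally reduce over the batch dimension; the result is a scalar.
scalar_mean = strategy.reduce(
    tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=0)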
@@ -729,6 +731,11 @@ class StrategyV1(Strategy):
     return super(StrategyV1, self).experimental_run(
         fn, input_iterator)
 
+  def reduce(self, reduce_op, value, axis=None):
+    return super(StrategyV1, self).reduce(reduce_op, value, axis)
+
+  reduce.__doc__ = Strategy.reduce.__doc__
+
   def update_config_proto(self, config_proto):
     """Returns a copy of `config_proto` modified for use with this strategy.
 
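A hedged sketch of why existing TF 1.x code keeps working: StrategyV1 still supplies the default, so V1 callers may continue to omit `axis`. The compat-v1 strategy construction below is illustrative, not from this commit:

import tensorflow as tf

v1_strategy = tf.compat.v1.distribute.MirroredStrategy()

with v1_strategy.scope():
  per_replica = v1_strategy.extended.call_for_each_replica(
      lambda: tf.constant(1.0))

# Still legal for StrategyV1: axis falls back to its default of None.
reduced = v1_strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica)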
@@ -321,7 +321,7 @@ class TestStrategyTest(test.TestCase):
   @_run_in_and_out_of_scope
   def testReduce(self, dist):
     x = constant_op.constant(1.)
-    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x)
+    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x, axis=None)
     self.assertEqual(self.evaluate(x), self.evaluate(x_r))
 
   def testReductions_acceptStringOps(self):
@@ -329,7 +329,7 @@ class TestStrategyTest(test.TestCase):
     for op in ("mean", "MEAN", "sum", "SUM"):
       x = constant_op.constant(1.)
       y = constant_op.constant(1.)
-      x_r = dist.reduce(op, x)
+      x_r = dist.reduce(op, x, axis=None)
       self.assertEqual(self.evaluate(x), self.evaluate(x_r))
       x_r = dist.extended.reduce_to(op, x, "/CPU:0")
       self.assertEqual(self.evaluate(x), self.evaluate(x_r))
@@ -804,11 +804,13 @@ class MultiStepContext(object):
         self._last_step_outputs[name] = output
       else:
         distribution = distribution_strategy_context.get_strategy()
-        self._last_step_outputs[name] = distribution.reduce(reduce_op, output)
+        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
+                                                            axis=None)
     else:
       assert reduce_op is not None
       def merge_fn(distribution, value):
-        self._last_step_outputs[name] = distribution.reduce(reduce_op, value)
+        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
+                                                            axis=None)
         # Setting this inside the `merge_fn` because all replicas share the same
         # context object, so it's more robust to set it only once (even if all
         # the replicas are trying to set the same value).
@@ -514,7 +514,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     if not reduced:
       self.assertLen(distribution.experimental_local_results(loss_output),
                      distribution.num_replicas_in_sync)
-      loss_tensor = distribution.reduce(reduce_util.ReduceOp.MEAN, loss_output)
+      loss_tensor = distribution.reduce(reduce_util.ReduceOp.MEAN, loss_output,
+                                        axis=None)
     else:
       unwrapped_output = distribution.experimental_local_results(loss_output)
       self.assertLen(unwrapped_output, 1)
@@ -103,7 +103,7 @@ class MirroredTwoDeviceDistributionTest(
   def testReduceToCpu(self, distribution):
     with distribution.scope():
       result = distribution.extended.call_for_each_replica(_replica_id)
      reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result)
-      reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result)
+      reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=None)
      expected = sum(range(distribution.num_replicas_in_sync))
      self.assertEqual(expected, self.evaluate(reduced))
 
@@ -1310,7 +1310,8 @@ class SyncOnReadVariable(DistributedVariable, PerReplica):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return self.primary
    return self._distribute_strategy.reduce(
-        reduce_util.ReduceOp.from_variable_aggregation(self.aggregation), self)
+        reduce_util.ReduceOp.from_variable_aggregation(self.aggregation), self,
+        axis=None)
 
  def _as_graph_element(self):
    # pylint: disable=protected-access
@@ -107,7 +107,7 @@ def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
  if with_loss_tensor:
    # reduce loss tensor before adding it to the list of fetches
    loss = distribution_strategy.reduce(reduce_util.ReduceOp.SUM,
-                                        grouped_outputs[0])
+                                        grouped_outputs[0], axis=None)
    all_outputs = flatten_perdevice_values(distribution_strategy,
                                           grouped_outputs[1:])
    all_outputs = [loss] + all_outputs
@@ -497,7 +497,8 @@ def experimental_tpu_test_loop(model,
      # We reduce all other metrics using mean for now. This is temporary
      # workaround until new metrics are in place.
      reduce_op = ds_reduce_util.ReduceOp.MEAN
-      output_tensors[label] = current_strategy.reduce(reduce_op, output)
+      output_tensors[label] = current_strategy.reduce(reduce_op, output,
+                                                      axis=None)
    test_op = control_flow_ops.group(list(output_tensors.values()))
 
    if verbose >= 1:
@@ -289,7 +289,7 @@ class DynamicLossScale(LossScale):
      is_finite_float = distribution.extended.call_for_each_replica(
          get_is_finite, args=(grads,))
      reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
-                                                   is_finite_float)
+                                                   is_finite_float, axis=None)
      is_finite = math_ops.equal(reduced_is_finite_float,
                                 distribution.num_replicas_in_sync)
    else:
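The remaining hunks update the public API golden files. A minimal sketch, not from this commit, of how the golden argspec strings track the removed default; the two stand-in functions below are hypothetical:

import inspect

def reduce_v1(reduce_op, value, axis=None):  # stand-in for StrategyV1.reduce
  pass

def reduce_v2(reduce_op, value, axis):  # stand-in for the V2 Strategy.reduce
  pass

# (None,) -> recorded as defaults=['None'] in the V1 golden.
print(inspect.getfullargspec(reduce_v1).defaults)
# None    -> recorded as defaults=None in the V2 goldens below.
print(inspect.getfullargspec(reduce_v2).defaults)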
@@ -53,7 +53,7 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "scope"

@@ -53,7 +53,7 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "scope"

@@ -52,7 +52,7 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "scope"

@@ -53,7 +53,7 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "scope"

@@ -53,7 +53,7 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "scope"

@@ -53,7 +53,7 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "scope"

@@ -53,7 +53,7 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "scope"