Remove the wrapping of single inputs in experimental run steps
PiperOrigin-RevId: 226121383
parent 627fd023a0
commit 4e57040e94
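All of the hunks below follow from one convention change: the experimental run-steps helpers (`body`/`run_fn` in the strategy implementations) no longer tuple-wrap a single element pulled from the iterator before handing it to the step function, so call sites now wrap it themselves, e.g. `args=(inputs,)` instead of `args=inputs`. The snippet below is a minimal, framework-free sketch of why that wrapping matters; `call_for_each_replica` here is only a stand-in for the TensorFlow method of the same name, and `metric_fn`/`batch` are illustrative names, not code from this change.

# Sketch only: a stand-in that, like the updated strategies below, forwards
# `args` positionally to `fn` without any tuple-wrapping of its own.
def call_for_each_replica(fn, args=()):
  return fn(*args)


def metric_fn(features):
  # A one-argument step/metric function, as in the tests touched below.
  return sum(features)


batch = [1.0, 2.0, 3.0]  # a single (non-tuple) element from an iterator

# New convention: the caller wraps the single input in a tuple itself.
print(call_for_each_replica(metric_fn, args=(batch,)))  # -> 6.0

# Passing args=batch would splat the list into three positional arguments
# and raise TypeError for a one-argument fn, hence the call-site changes.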
@@ -100,7 +100,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
       if isinstance(distribution, tpu_strategy.TPUStrategy):
         def step_fn(ctx, inputs):
           value, update = distribution.call_for_each_replica(
-              metric_fn, args=inputs)
+              metric_fn, args=(inputs,))
           ctx.set_non_tensor_output(name="value", output=value)
           return distribution.group(update)

@@ -115,7 +115,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
             distribution.extended.steps_per_run)
       else:
         value, update = distribution.call_for_each_replica(
-            metric_fn, iterator.get_next())
+            metric_fn, args=(iterator.get_next(),))
         update = distribution.group(update)
         # TODO(josh11b): Once we switch to using a global batch size for input,
         # replace "distribution.num_replicas_in_sync" with "1".
@@ -67,7 +67,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     def step_fn(ctx, inputs):
       del ctx  # Unused
       return distribution.group(
-          distribution.call_for_each_replica(model_fn, args=inputs))
+          distribution.call_for_each_replica(model_fn, args=(inputs,)))

     iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))

@@ -161,7 +161,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     def step_fn(ctx, inputs):
      del ctx  # Unused
       return distribution.group(
-          distribution.call_for_each_replica(model_fn, args=inputs))
+          distribution.call_for_each_replica(model_fn, args=(inputs,)))

     iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))

@@ -230,7 +230,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     def step_fn(ctx, inputs):
       del ctx  # Unused
       fetches = distribution.unwrap(
-          distribution.call_for_each_replica(model_fn, args=inputs))
+          distribution.call_for_each_replica(model_fn, args=(inputs,)))
       if update_ops_in_cross_replica_mode:
         fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
       return control_flow_ops.group(fetches)
@@ -302,8 +302,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     with distribution.scope():
       all_vars = []

-      def model_fn(x, y):
-
+      def model_fn(inputs):
+        x, y = inputs
         def loss_fn():
           # Use fixed initialization to make the steps deterministic.
           w = variable_scope.get_variable("w", initializer=[[2.]])
@@ -327,7 +327,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     def step_fn(ctx, inputs):
       del ctx  # Unused
       return distribution.group(
-          distribution.call_for_each_replica(model_fn, args=inputs))
+          distribution.call_for_each_replica(model_fn, args=(inputs,)))

     iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))

@@ -413,7 +413,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):

     def step_fn(output_context, inputs):
       (train_op, loss) = distribution.call_for_each_replica(
-          model_fn, args=(output_context,) + inputs)
+          model_fn, args=(output_context, inputs))
       output_context.set_last_step_output(
           name="cross_replica_loss_reduced",
           output=loss,
@@ -101,10 +101,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
     def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
       del args
-      fn_inputs = iterator.get_next()
-      if not isinstance(fn_inputs, tuple):
-        fn_inputs = (fn_inputs,)
-      fn_result = fn(ctx, fn_inputs)
+      fn_result = fn(ctx, iterator.get_next())
       flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
       with ops.control_dependencies([fn_result]):
         return [i + 1] + flat_last_step_outputs
@@ -100,7 +100,7 @@ class StandardSingleLossStep(StandardInputStep):
       gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)

       grads_and_vars = self.distribution.call_for_each_replica(
-          gradients_fn, args=(ctx,) + inputs)
+          gradients_fn, args=(ctx, inputs))
       # If threads use layers, then we need to run the first step
       # sequentially, so that layers.build() is not executed in parallel.
       # Otherwise, multiple sets of mirrored variables are going to be
@@ -331,10 +331,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):

     def run_fn():
       """Single step on the TPU device."""
-      fn_inputs = dequeue_fn()
-      if not isinstance(fn_inputs, tuple):
-        fn_inputs = (fn_inputs,)
-      fn_result = fn(ctx, fn_inputs)
+      fn_result = fn(ctx, dequeue_fn())
       flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
       if flat_last_step_outputs:
         with ops.control_dependencies([fn_result]):
@@ -574,10 +574,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
     def body(i, *args):
       """A wrapper around `fn` to create the while loop body."""
       del args
-      fn_inputs = iterator.get_next()
-      if not isinstance(fn_inputs, tuple):
-        fn_inputs = (fn_inputs,)
-      fn_result = fn(ctx, fn_inputs)
+      fn_result = fn(ctx, iterator.get_next())
       for (name, output) in ctx.last_step_outputs.items():
         # Convert all outputs to tensors, potentially from `DistributedValues`.
         ctx.last_step_outputs[name] = self._unwrap(output)