Remove the wrapping of single inputs in experimental run steps

PiperOrigin-RevId: 226121383
Authored by Sourabh Bajaj on 2018-12-18 23:21:26 -08:00, committed by TensorFlower Gardener
parent 627fd023a0
commit 4e57040e94
6 changed files with 13 additions and 22 deletions
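
The change, in short: the strategies' experimental run loops no longer wrap a single (non-tuple) element from the input iterator in a one-element tuple before calling fn(ctx, inputs); whatever the iterator yields is passed through unchanged, and each step_fn is now responsible for tupling it when forwarding to call_for_each_replica. A minimal before/after sketch of that calling convention, assuming a distribution strategy object and a per-replica model_fn as in the hunks below:

# Before: the run loop tupled single inputs, so `inputs` arriving here was
# always a tuple and could be handed to call_for_each_replica directly.
def step_fn(ctx, inputs):
  del ctx  # Unused.
  return distribution.group(
      distribution.call_for_each_replica(model_fn, args=inputs))

# After: `inputs` is the raw iterator element, so the step function wraps it
# (or places it after any extra leading arguments) explicitly.
def step_fn(ctx, inputs):
  del ctx  # Unused.
  return distribution.group(
      distribution.call_for_each_replica(model_fn, args=(inputs,)))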

View File

@@ -100,7 +100,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
if isinstance(distribution, tpu_strategy.TPUStrategy):
def step_fn(ctx, inputs):
value, update = distribution.call_for_each_replica(
- metric_fn, args=inputs)
+ metric_fn, args=(inputs,))
ctx.set_non_tensor_output(name="value", output=value)
return distribution.group(update)
@@ -115,7 +115,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
distribution.extended.steps_per_run)
else:
value, update = distribution.call_for_each_replica(
- metric_fn, iterator.get_next())
+ metric_fn, args=(iterator.get_next(),))
update = distribution.group(update)
# TODO(josh11b): Once we switch to using a global batch size for input,
# replace "distribution.num_replicas_in_sync" with "1".

View File

@@ -67,7 +67,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def step_fn(ctx, inputs):
del ctx # Unused
return distribution.group(
- distribution.call_for_each_replica(model_fn, args=inputs))
+ distribution.call_for_each_replica(model_fn, args=(inputs,)))
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
@@ -161,7 +161,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def step_fn(ctx, inputs):
del ctx # Unused
return distribution.group(
- distribution.call_for_each_replica(model_fn, args=inputs))
+ distribution.call_for_each_replica(model_fn, args=(inputs,)))
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
@@ -230,7 +230,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def step_fn(ctx, inputs):
del ctx # Unused
fetches = distribution.unwrap(
- distribution.call_for_each_replica(model_fn, args=inputs))
+ distribution.call_for_each_replica(model_fn, args=(inputs,)))
if update_ops_in_cross_replica_mode:
fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
return control_flow_ops.group(fetches)
@@ -302,8 +302,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
with distribution.scope():
all_vars = []
- def model_fn(x, y):
+ def model_fn(inputs):
+ x, y = inputs
def loss_fn():
# Use fixed initialization to make the steps deterministic.
w = variable_scope.get_variable("w", initializer=[[2.]])
@@ -327,7 +327,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def step_fn(ctx, inputs):
del ctx # Unused
return distribution.group(
- distribution.call_for_each_replica(model_fn, args=inputs))
+ distribution.call_for_each_replica(model_fn, args=(inputs,)))
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
@@ -413,7 +413,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def step_fn(output_context, inputs):
(train_op, loss) = distribution.call_for_each_replica(
- model_fn, args=(output_context,) + inputs)
+ model_fn, args=(output_context, inputs))
output_context.set_last_step_output(
name="cross_replica_loss_reduced",
output=loss,

View File

@@ -101,10 +101,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args
- fn_inputs = iterator.get_next()
- if not isinstance(fn_inputs, tuple):
- fn_inputs = (fn_inputs,)
- fn_result = fn(ctx, fn_inputs)
+ fn_result = fn(ctx, iterator.get_next())
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs

View File

@@ -100,7 +100,7 @@ class StandardSingleLossStep(StandardInputStep):
gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
grads_and_vars = self.distribution.call_for_each_replica(
- gradients_fn, args=(ctx,) + inputs)
+ gradients_fn, args=(ctx, inputs))
# If threads use layers, then we need to run the first step
# sequentially, so that layers.build() is not executed in parallel.
# Otherwise, multiple sets of mirrored variables are going to be
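
A knock-on effect of the args=(ctx, inputs) change above: the per-replica callable that StandardSingleLossStep builds around the user-supplied loss function now receives the dataset element as a single inputs argument, instead of having the element tuple splatted after ctx. A hedged sketch of a loss function written against the new convention (the features/labels structure and the toy model are assumptions for illustration, not taken from this diff):

import tensorflow as tf

# Before: loss_fn(ctx, features, labels)  -- element splatted after ctx.
# After:  loss_fn(ctx, inputs)            -- element passed through whole.
def loss_fn(ctx, inputs):
  del ctx  # Unused in this sketch.
  features, labels = inputs  # Unpack the dataset element explicitly.
  logits = tf.layers.dense(features, 1)  # Toy model, illustration only.
  return tf.losses.mean_squared_error(labels, logits)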

View File

@@ -331,10 +331,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
def run_fn():
"""Single step on the TPU device."""
- fn_inputs = dequeue_fn()
- if not isinstance(fn_inputs, tuple):
- fn_inputs = (fn_inputs,)
- fn_result = fn(ctx, fn_inputs)
+ fn_result = fn(ctx, dequeue_fn())
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
if flat_last_step_outputs:
with ops.control_dependencies([fn_result]):

View File

@@ -574,10 +574,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args
- fn_inputs = iterator.get_next()
- if not isinstance(fn_inputs, tuple):
- fn_inputs = (fn_inputs,)
- fn_result = fn(ctx, fn_inputs)
+ fn_result = fn(ctx, iterator.get_next())
for (name, output) in ctx.last_step_outputs.items():
# Convert all outputs to tensors, potentially from `DistributedValues`.
ctx.last_step_outputs[name] = self._unwrap(output)