From 4e57040e940b4bdc2fc91e2f0b0f3dbbf1a59f6d Mon Sep 17 00:00:00 2001
From: Sourabh Bajaj
Date: Tue, 18 Dec 2018 23:21:26 -0800
Subject: [PATCH] Remove the wrapping of single inputs in experimental run steps

Step functions now receive the iterator element exactly as `get_next()` /
`dequeue_fn()` returns it; callers wrap single inputs themselves (e.g.
`args=(inputs,)`) when forwarding to `call_for_each_replica`. A sketch of the
resulting calling convention follows the diff.

PiperOrigin-RevId: 226121383
---
 .../contrib/distribute/python/metrics_v1_test.py  |  4 ++--
 .../distribute/python/minimize_loss_test.py       | 14 +++++++-------
 .../distribute/python/one_device_strategy.py      |  5 +----
 tensorflow/contrib/distribute/python/step_fn.py   |  2 +-
 .../contrib/distribute/python/tpu_strategy.py     |  5 +----
 tensorflow/python/distribute/mirrored_strategy.py |  5 +----
 6 files changed, 13 insertions(+), 22 deletions(-)

diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 8ac659abe96..32a0d199434 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -100,7 +100,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
     if isinstance(distribution, tpu_strategy.TPUStrategy):
       def step_fn(ctx, inputs):
         value, update = distribution.call_for_each_replica(
-            metric_fn, args=inputs)
+            metric_fn, args=(inputs,))
         ctx.set_non_tensor_output(name="value", output=value)
         return distribution.group(update)
 
@@ -115,7 +115,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
           distribution.extended.steps_per_run)
     else:
       value, update = distribution.call_for_each_replica(
-          metric_fn, iterator.get_next())
+          metric_fn, args=(iterator.get_next(),))
       update = distribution.group(update)
       # TODO(josh11b): Once we switch to using a global batch size for input,
       # replace "distribution.num_replicas_in_sync" with "1".
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index f09483cb56b..824c4b09371 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -67,7 +67,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     def step_fn(ctx, inputs):
       del ctx  # Unused
       return distribution.group(
-          distribution.call_for_each_replica(model_fn, args=inputs))
+          distribution.call_for_each_replica(model_fn, args=(inputs,)))
 
     iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
@@ -161,7 +161,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     def step_fn(ctx, inputs):
       del ctx  # Unused
       return distribution.group(
-          distribution.call_for_each_replica(model_fn, args=inputs))
+          distribution.call_for_each_replica(model_fn, args=(inputs,)))
 
     iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
@@ -230,7 +230,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     def step_fn(ctx, inputs):
       del ctx  # Unused
       fetches = distribution.unwrap(
-          distribution.call_for_each_replica(model_fn, args=inputs))
+          distribution.call_for_each_replica(model_fn, args=(inputs,)))
       if update_ops_in_cross_replica_mode:
         fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
       return control_flow_ops.group(fetches)
@@ -302,8 +302,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     with distribution.scope():
       all_vars = []
 
-      def model_fn(x, y):
-
+      def model_fn(inputs):
+        x, y = inputs
         def loss_fn():
           # Use fixed initialization to make the steps deterministic.
           w = variable_scope.get_variable("w", initializer=[[2.]])
@@ -327,7 +327,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     def step_fn(ctx, inputs):
       del ctx  # Unused
       return distribution.group(
-          distribution.call_for_each_replica(model_fn, args=inputs))
+          distribution.call_for_each_replica(model_fn, args=(inputs,)))
 
     iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
@@ -413,7 +413,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
 
     def step_fn(output_context, inputs):
       (train_op, loss) = distribution.call_for_each_replica(
-          model_fn, args=(output_context,) + inputs)
+          model_fn, args=(output_context, inputs))
       output_context.set_last_step_output(
           name="cross_replica_loss_reduced",
           output=loss,
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 4b60f3c786a..c9ea706b646 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -101,10 +101,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
     def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
       del args
-      fn_inputs = iterator.get_next()
-      if not isinstance(fn_inputs, tuple):
-        fn_inputs = (fn_inputs,)
-      fn_result = fn(ctx, fn_inputs)
+      fn_result = fn(ctx, iterator.get_next())
       flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
       with ops.control_dependencies([fn_result]):
         return [i + 1] + flat_last_step_outputs
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index c928b6d9f1f..faeb96bcb7c 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -100,7 +100,7 @@ class StandardSingleLossStep(StandardInputStep):
       gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
 
       grads_and_vars = self.distribution.call_for_each_replica(
-          gradients_fn, args=(ctx,) + inputs)
+          gradients_fn, args=(ctx, inputs))
       # If threads use layers, then we need to run the first step
       # sequentially, so that layers.build() is not executed in parallel.
       # Otherwise, multiple sets of mirrored variables are going to be
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index bdcad14704d..c2f62c3ca23 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -331,10 +331,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
 
     def run_fn():
       """Single step on the TPU device."""
-      fn_inputs = dequeue_fn()
-      if not isinstance(fn_inputs, tuple):
-        fn_inputs = (fn_inputs,)
-      fn_result = fn(ctx, fn_inputs)
+      fn_result = fn(ctx, dequeue_fn())
       flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
       if flat_last_step_outputs:
         with ops.control_dependencies([fn_result]):
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index b4f9761b980..60b5232e164 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -574,10 +574,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
     def body(i, *args):
       """A wrapper around `fn` to create the while loop body."""
       del args
-      fn_inputs = iterator.get_next()
-      if not isinstance(fn_inputs, tuple):
-        fn_inputs = (fn_inputs,)
-      fn_result = fn(ctx, fn_inputs)
+      fn_result = fn(ctx, iterator.get_next())
       for (name, output) in ctx.last_step_outputs.items():
         # Convert all outputs to tensors, potentially from `DistributedValues`.
         ctx.last_step_outputs[name] = self._unwrap(output)
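
A minimal, framework-free sketch of the calling convention this patch
establishes (illustrative only: `run_steps`, `call_for_each_replica`, and
`model_fn` below are simplified stand-ins, not the real TensorFlow
implementations). The runner hands the iterator element to the step function
unmodified; the step function wraps single inputs itself, mirroring the
`args=(inputs,)` changes above.

def call_for_each_replica(fn, args=()):
  """Stand-in: invokes `fn(*args)` once, as a single-replica strategy would."""
  return fn(*args)


def run_steps(step_fn, get_next, num_steps, ctx=None):
  """After this patch, the element from `get_next()` is passed through as-is.

  Previously a non-tuple element was wrapped as `(element,)` here, so the
  step function always saw a tuple.
  """
  outputs = []
  for _ in range(num_steps):
    outputs.append(step_fn(ctx, get_next()))
  return outputs


def model_fn(inputs):
  # Per-replica computation on one iterator element.
  return inputs * 2


def step_fn(ctx, inputs):
  del ctx  # Unused, mirroring the test code in the diff.
  # The caller now wraps the single element explicitly.
  return call_for_each_replica(model_fn, args=(inputs,))


if __name__ == "__main__":
  data = iter([1, 2, 3])
  print(run_steps(step_fn, lambda: next(data), num_steps=3))  # [2, 4, 6]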