From 9b37e099942b468ae062d4494c2618358c0e864c Mon Sep 17 00:00:00 2001
From: Priya Gupta
Date: Mon, 8 Jun 2020 00:33:42 -0700
Subject: [PATCH] Do not capture device from outer stack in a func graph when
 using distribution strategies in eager mode.

PiperOrigin-RevId: 315224550
Change-Id: I92e5a3ddea86e2365758ff4bf1b4f6a03946b1e4
---
 tensorflow/python/distribute/BUILD            |  16 +++
 .../python/distribute/distribute_lib.py       |   6 +-
 .../python/distribute/tf_function_test.py     | 135 ++++++++++++++++++
 tensorflow/python/framework/func_graph.py     |  26 ++--
 4 files changed, 165 insertions(+), 18 deletions(-)
 create mode 100644 tensorflow/python/distribute/tf_function_test.py

diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 26027d46c98..977452cdad1 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -1791,3 +1791,19 @@ cuda_py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+distribute_py_test(
+    name = "tf_function_test",
+    srcs = ["tf_function_test.py"],
+    main = "tf_function_test.py",
+    tags = [
+        "multi_and_single_gpu",
+    ],
+    deps = [
+        ":combinations",
+        ":strategy_combinations",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/eager:test",
+    ],
+)
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index b77163cb97a..5ea738765b7 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -620,10 +620,8 @@ class StrategyBase(object):
     if not hasattr(extended, "_retrace_functions_for_each_device"):
       # pylint: disable=protected-access
       # `extended._retrace_functions_for_each_device` dictates
-      # 1) whether all the ops created inside function will have devices
-      #    inherited from outer stack, and
-      # 2) whether the same function will be retraced when it is called on
-      #    different devices.
+      # whether the same function will be retraced when it is called on
+      # different devices.
       try:
         extended._retrace_functions_for_each_device = (
             len(extended.worker_devices) > 1)
diff --git a/tensorflow/python/distribute/tf_function_test.py b/tensorflow/python/distribute/tf_function_test.py
new file mode 100644
index 00000000000..5dc82cfd81b
--- /dev/null
+++ b/tensorflow/python/distribute/tf_function_test.py
@@ -0,0 +1,135 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.function + distribution strategies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.distribute import values
+from tensorflow.python.eager import def_function
+from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+
+
+class TFFunctionTest(test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super(TFFunctionTest, self).setUp()
+    # Clear the state for every test. `run_functions_eagerly` is a process-wide
+    # toggle, so a value set by one test would otherwise leak into the next.
+    def_function.run_functions_eagerly(False)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies,
+          mode=["eager"],
+          run_functions_eagerly=[True, False]
+      ))
+  def testDefaultDeviceInsideFunctionWithScope(
+      self, distribution, run_functions_eagerly):
+
+    def_function.run_functions_eagerly(run_functions_eagerly)
+    expected_device = (device_util.canonicalize("cpu:0")
+                       if run_functions_eagerly else "")
+    with distribution.scope():
+      with ops.device_v2("cpu:0"):
+        @def_function.function
+        def add():
+          one = array_ops.ones([])
+          self.assertEqual(expected_device, one.device)
+          return one + 1
+
+        add()
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies,
+          mode=["eager"],
+          run_functions_eagerly=[True, False]
+      ))
+  def testDefaultDeviceInsideNestedFunctionWithScope(
+      self, distribution, run_functions_eagerly):
+
+    def_function.run_functions_eagerly(run_functions_eagerly)
+    expected_device = (device_util.canonicalize("cpu:0")
+                       if run_functions_eagerly else "")
+    with distribution.scope():
+      @def_function.function
+      def foo():
+        with ops.device("cpu:0"):
+
+          @def_function.function
+          def bar():
+            one = array_ops.ones([])
+            self.assertEqual(expected_device, one.device)
+            return one + 1
+
+          bar()
+
+      foo()
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies,
+          mode=["eager"],
+          run_functions_eagerly=[True, False]
+      ))
+  def testReadVariableInsideFunction(self, distribution, run_functions_eagerly):
+    # NOTE(review): `run_functions_eagerly` is never applied in this test, so
+    # both parameterizations run with the value reset in setUp; confirm intent.
+
+    # Get devices on which variables will be placed. Default strategy does not
+    # define this, so assume cpu:0 in that case.
+    try:
+      devices = distribution.extended.parameter_devices
+    except RuntimeError:
+      devices = ["cpu:0"]
+
+    with distribution.scope():
+      v = variables.Variable(0.)
+      if isinstance(v, values.DistributedVariable):
+        for i in range(len(devices)):
+          # NOTE: Assigning manually to component variables so we can test
+          # different values on different devices. Using .assign on the
+          # mirrored variable itself will lead to a synchronization which
+          # will prohibit testing different values.
+          replica_variable = v._values[i]
+          replica_variable.assign(math_ops.cast(i, dtypes.float32))
+
+    @def_function.function
+    def read():
+      return v.read_value()
+
+    for i, d in enumerate(devices):
+      with ops.device(d):
+        # Verify that the value from each device is read, when in that device
+        # scope.
+        self.assertEqual(math_ops.cast(i, dtypes.float32), read())
+
+
+if __name__ == "__main__":
+  v2_compat.enable_v2_behavior()
+  test.main()
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index 6b358a3c51a..b0f8821b17f 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -365,35 +365,33 @@ class FuncGraph(ops.Graph):
     @tf_contextlib.contextmanager
     def inner_cm():
       """Context manager for copying distribute.Strategy scope information."""
-      graph = ops.get_default_graph()
       # pylint: disable=protected-access
       # TODO(b/112906995, nareshmodi): distribution strategy depends on
       # inheriting this stack from the default graph even in eager mode. Maybe
       # it should be part of the eager context? This would also allow us to
       # remove a get_default_graph() call from the function cache lookup.
+      graph = ops.get_default_graph()
       old_strategy_stack = self._distribution_strategy_stack
       self._distribution_strategy_stack = list(
           graph._distribution_strategy_stack)
-      uses_distribution_strategy = (
-          self._distribution_strategy_stack and
-          self._distribution_strategy_stack[-1].strategy.extended
-          ._retrace_functions_for_each_device)
+
       # We ignore device placements from any outer scopes while tracing the
       # function when possible, to avoid hard-coding them in the function
       # graph. "Default" placements come from the PartitionedCallOp's placement,
       # so that the same trace of the Python function may be placed on several
       # different devices and saved functions may be placed on new devices when
       # restored.
+      # However, we need to preserve the outer device stack in the following
+      # cases in non eager context:
+      # 1. device stack is callable
+      # 2. When using distribution strategy with legacy graph mode.
       old_device_stack = self._device_function_stack
-      if context.executing_eagerly():
-        if uses_distribution_strategy:
-          self._device_function_stack = self._device_function_stack.copy()
-          self._add_device_to_stack(context.context().device_name)
-      else:
-        if (uses_distribution_strategy or
-            device_stack_has_callable(graph._device_function_stack)):
-          # Hard-code devices from device functions in the function body
-          self._device_function_stack = graph._device_function_stack.copy()
+      if (not context.executing_eagerly() and
+          (device_stack_has_callable(graph._device_function_stack) or
+           (self._distribution_strategy_stack and
+            not ops.executing_eagerly_outside_functions()))):
+        # Hard-code devices from device functions in the function body
+        self._device_function_stack = graph._device_function_stack.copy()
 
       old_creator_stack = self._variable_creator_stack
       self._variable_creator_stack = graph._variable_creator_stack