From 6136078913c450c23cebfd9443b16abc33202b75 Mon Sep 17 00:00:00 2001 From: Yujing Zhang <yujingzhang@google.com> Date: Tue, 15 Dec 2020 22:05:53 -0800 Subject: [PATCH] Return the primary handle when it's in graph mode and not under a tpu context. PiperOrigin-RevId: 347756732 Change-Id: Iaaaf9052797d87ba768c5f65bd9a6e1adfb1b6da --- .../distribute/packed_distributed_variable.py | 4 ++++ tensorflow/python/distribute/tpu_strategy_test.py | 13 +++++++++++++ tensorflow/python/distribute/tpu_values.py | 6 +++++- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/packed_distributed_variable.py b/tensorflow/python/distribute/packed_distributed_variable.py index 4c9433dc164..a1584119e99 100644 --- a/tensorflow/python/distribute/packed_distributed_variable.py +++ b/tensorflow/python/distribute/packed_distributed_variable.py @@ -282,6 +282,10 @@ class PackedVarAndDevice(object): with ops.device(self._device): return self._var.handle + def on_device_handle(self): + with ops.device(self._device): + return self._var.get_var_on_current_device().handle + @property def op(self): with ops.device(self._device): diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index 239882c1571..9f5fdb04d1e 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -167,6 +167,19 @@ class TPUTest(test.TestCase): @parameterized.named_parameters([("PackedVar", True), ("", False)]) class TPUStrategyTest(test.TestCase, parameterized.TestCase): + def test_handle_in_cross_replica_context(self, enable_packed_var): + strategy = get_tpu_strategy(enable_packed_var) + with strategy.scope(): + v = variables.Variable(1.0) + + @def_function.function + def func(): + self.assertEndsWith(v.handle.device, "device:TPU:0") + return v + 1.0 + + ret = func() + self.assertAllEqual(ret, 2.0) + def test_function_compile_with_xla(self, enable_packed_var): strategy = get_tpu_strategy(enable_packed_var) with strategy.scope(): diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py index 3094f74372e..dbe1e1ff363 100644 --- a/tensorflow/python/distribute/tpu_values.py +++ b/tensorflow/python/distribute/tpu_values.py @@ -115,7 +115,11 @@ class TPUVariableMixin(object): # If we're in a tpu.rewrite(), return the replicated handle. tpu_context = enclosing_tpu_context() if tpu_context is None or context.executing_eagerly(): - return self._get_on_device_or_primary().handle + var = self._get_on_device_or_primary() + if isinstance(var, packed.PackedVarAndDevice): + return var.on_device_handle() + else: + return var.handle else: is_packed = self._packed_var is not None val = self._values