From 6136078913c450c23cebfd9443b16abc33202b75 Mon Sep 17 00:00:00 2001
From: Yujing Zhang <yujingzhang@google.com>
Date: Tue, 15 Dec 2020 22:05:53 -0800
Subject: [PATCH] Return the primary handle when it's in graph mode and not
 under a tpu context.

PiperOrigin-RevId: 347756732
Change-Id: Iaaaf9052797d87ba768c5f65bd9a6e1adfb1b6da
---
 .../distribute/packed_distributed_variable.py       |  4 ++++
 tensorflow/python/distribute/tpu_strategy_test.py   | 13 +++++++++++++
 tensorflow/python/distribute/tpu_values.py          |  6 +++++-
 3 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/distribute/packed_distributed_variable.py b/tensorflow/python/distribute/packed_distributed_variable.py
index 4c9433dc164..a1584119e99 100644
--- a/tensorflow/python/distribute/packed_distributed_variable.py
+++ b/tensorflow/python/distribute/packed_distributed_variable.py
@@ -282,6 +282,10 @@ class PackedVarAndDevice(object):
     with ops.device(self._device):
       return self._var.handle
 
+  def on_device_handle(self):
+    with ops.device(self._device):
+      return self._var.get_var_on_current_device().handle
+
   @property
   def op(self):
     with ops.device(self._device):
diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py
index 239882c1571..9f5fdb04d1e 100644
--- a/tensorflow/python/distribute/tpu_strategy_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_test.py
@@ -167,6 +167,19 @@ class TPUTest(test.TestCase):
 @parameterized.named_parameters([("PackedVar", True), ("", False)])
 class TPUStrategyTest(test.TestCase, parameterized.TestCase):
 
+  def test_handle_in_cross_replica_context(self, enable_packed_var):
+    strategy = get_tpu_strategy(enable_packed_var)
+    with strategy.scope():
+      v = variables.Variable(1.0)
+
+    @def_function.function
+    def func():
+      self.assertEndsWith(v.handle.device, "device:TPU:0")
+      return v + 1.0
+
+    ret = func()
+    self.assertAllEqual(ret, 2.0)
+
   def test_function_compile_with_xla(self, enable_packed_var):
     strategy = get_tpu_strategy(enable_packed_var)
     with strategy.scope():
diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py
index 3094f74372e..dbe1e1ff363 100644
--- a/tensorflow/python/distribute/tpu_values.py
+++ b/tensorflow/python/distribute/tpu_values.py
@@ -115,7 +115,11 @@ class TPUVariableMixin(object):
     # If we're in a tpu.rewrite(), return the replicated handle.
     tpu_context = enclosing_tpu_context()
     if tpu_context is None or context.executing_eagerly():
-      return self._get_on_device_or_primary().handle
+      var = self._get_on_device_or_primary()
+      if isinstance(var, packed.PackedVarAndDevice):
+        return var.on_device_handle()
+      else:
+        return var.handle
     else:
       is_packed = self._packed_var is not None
       val = self._values