Return the primary handle when it's in graph mode and not under a tpu context.

PiperOrigin-RevId: 347756732
Change-Id: Iaaaf9052797d87ba768c5f65bd9a6e1adfb1b6da
This commit is contained in:
Yujing Zhang 2020-12-15 22:05:53 -08:00 committed by TensorFlower Gardener
parent ddad8c6ecc
commit 6136078913
3 changed files with 22 additions and 1 deletions

View File

@ -282,6 +282,10 @@ class PackedVarAndDevice(object):
with ops.device(self._device):
return self._var.handle
def on_device_handle(self):
with ops.device(self._device):
return self._var.get_var_on_current_device().handle
@property
def op(self):
with ops.device(self._device):

View File

@ -167,6 +167,19 @@ class TPUTest(test.TestCase):
@parameterized.named_parameters([("PackedVar", True), ("", False)])
class TPUStrategyTest(test.TestCase, parameterized.TestCase):
def test_handle_in_cross_replica_context(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
with strategy.scope():
v = variables.Variable(1.0)
@def_function.function
def func():
self.assertEndsWith(v.handle.device, "device:TPU:0")
return v + 1.0
ret = func()
self.assertAllEqual(ret, 2.0)
def test_function_compile_with_xla(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
with strategy.scope():

View File

@ -115,7 +115,11 @@ class TPUVariableMixin(object):
# If we're in a tpu.rewrite(), return the replicated handle.
tpu_context = enclosing_tpu_context()
if tpu_context is None or context.executing_eagerly():
return self._get_on_device_or_primary().handle
var = self._get_on_device_or_primary()
if isinstance(var, packed.PackedVarAndDevice):
return var.on_device_handle()
else:
return var.handle
else:
is_packed = self._packed_var is not None
val = self._values