Return the primary handle when it's in graph mode and not under a tpu context.
PiperOrigin-RevId: 347756732 Change-Id: Iaaaf9052797d87ba768c5f65bd9a6e1adfb1b6da
This commit is contained in:
parent
ddad8c6ecc
commit
6136078913
tensorflow/python/distribute
@ -282,6 +282,10 @@ class PackedVarAndDevice(object):
|
||||
with ops.device(self._device):
|
||||
return self._var.handle
|
||||
|
||||
def on_device_handle(self):
|
||||
with ops.device(self._device):
|
||||
return self._var.get_var_on_current_device().handle
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
with ops.device(self._device):
|
||||
|
@ -167,6 +167,19 @@ class TPUTest(test.TestCase):
|
||||
@parameterized.named_parameters([("PackedVar", True), ("", False)])
|
||||
class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_handle_in_cross_replica_context(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
with strategy.scope():
|
||||
v = variables.Variable(1.0)
|
||||
|
||||
@def_function.function
|
||||
def func():
|
||||
self.assertEndsWith(v.handle.device, "device:TPU:0")
|
||||
return v + 1.0
|
||||
|
||||
ret = func()
|
||||
self.assertAllEqual(ret, 2.0)
|
||||
|
||||
def test_function_compile_with_xla(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
with strategy.scope():
|
||||
|
@ -115,7 +115,11 @@ class TPUVariableMixin(object):
|
||||
# If we're in a tpu.rewrite(), return the replicated handle.
|
||||
tpu_context = enclosing_tpu_context()
|
||||
if tpu_context is None or context.executing_eagerly():
|
||||
return self._get_on_device_or_primary().handle
|
||||
var = self._get_on_device_or_primary()
|
||||
if isinstance(var, packed.PackedVarAndDevice):
|
||||
return var.on_device_handle()
|
||||
else:
|
||||
return var.handle
|
||||
else:
|
||||
is_packed = self._packed_var is not None
|
||||
val = self._values
|
||||
|
Loading…
Reference in New Issue
Block a user