Only create TPU replicated variable handle in graph mode.
PiperOrigin-RevId: 302461289 Change-Id: I4923d3db3e59db45e95a7a52c0c60fb42b3ee911
This commit is contained in:
parent
b71e31d545
commit
b70b14a462
@ -112,7 +112,7 @@ class TPUVariableMixin(object):
|
||||
def handle(self):
|
||||
# If we're in a tpu.rewrite(), return the replicated handle.
|
||||
tpu_context = enclosing_tpu_context()
|
||||
if tpu_context is None:
|
||||
if tpu_context is None or context.executing_eagerly():
|
||||
return self._get_closest().handle
|
||||
else:
|
||||
return tpu_context.get_replicated_var_handle(self._handle_id,
|
||||
|
@ -990,6 +990,39 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||
array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
|
||||
per_replica_results)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
mode=["eager"]))
|
||||
def testInitScope(self, distribution):
|
||||
|
||||
class C(object):
|
||||
pass
|
||||
|
||||
obj = C()
|
||||
obj.w = None
|
||||
obj.v = None
|
||||
|
||||
@def_function.function
|
||||
def assign():
|
||||
with ops.init_scope():
|
||||
if obj.w is None:
|
||||
obj.w = variables_lib.Variable(
|
||||
0, aggregation=variables_lib.VariableAggregation.MEAN)
|
||||
obj.v = variables_lib.Variable(
|
||||
obj.w.read_value(),
|
||||
aggregation=variables_lib.VariableAggregation.MEAN)
|
||||
|
||||
return obj.v.assign_add(2)
|
||||
|
||||
per_replica_results = self.evaluate(
|
||||
distribution.experimental_local_results(distribution.run(assign)))
|
||||
self.assertAllEqual([2, 2], per_replica_results)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
|
Loading…
Reference in New Issue
Block a user