Only create TPU replicated variable handle in graph mode.

PiperOrigin-RevId: 302461289
Change-Id: I4923d3db3e59db45e95a7a52c0c60fb42b3ee911
This commit is contained in:
Ruoxin Sang 2020-03-23 10:17:25 -07:00 committed by TensorFlower Gardener
parent b71e31d545
commit b70b14a462
2 changed files with 34 additions and 1 deletions

View File

@@ -112,7 +112,7 @@ class TPUVariableMixin(object):
def handle(self):
# If we're in a tpu.rewrite(), return the replicated handle.
tpu_context = enclosing_tpu_context()
if tpu_context is None:
if tpu_context is None or context.executing_eagerly():
return self._get_closest().handle
else:
return tpu_context.get_replicated_var_handle(self._handle_id,

View File

@@ -990,6 +990,39 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
per_replica_results)
@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.central_storage_strategy_with_two_gpus,
        ],
        mode=["eager"]))
def testInitScope(self, distribution):
  """Variables created under `ops.init_scope` inside a `tf.function` work.

  Regression test: lazily creates two MEAN-aggregated distributed
  variables (the second initialized from the first's `read_value`) from
  within a traced function, then checks that `assign_add` produces the
  expected per-replica values.
  """

  class _Holder(object):
    """Plain mutable namespace for the lazily-created variables."""
    pass

  holder = _Holder()
  holder.w = None
  holder.v = None

  @def_function.function
  def assign():
    # init_scope lifts variable creation out of the function graph so it
    # happens once, eagerly, on first trace.
    with ops.init_scope():
      if holder.w is None:
        holder.w = variables_lib.Variable(
            0, aggregation=variables_lib.VariableAggregation.MEAN)
        holder.v = variables_lib.Variable(
            holder.w.read_value(),
            aggregation=variables_lib.VariableAggregation.MEAN)
    return holder.v.assign_add(2)

  local_results = self.evaluate(
      distribution.experimental_local_results(distribution.run(assign)))
  # v starts at w's value (0); one assign_add(2) per replica yields 2 each.
  self.assertAllEqual([2, 2], local_results)
@combinations.generate(
combinations.combine(
distribution=[