Only create TPU replicated variable handle in graph mode.
PiperOrigin-RevId: 302461289 Change-Id: I4923d3db3e59db45e95a7a52c0c60fb42b3ee911
This commit is contained in:
parent
b71e31d545
commit
b70b14a462
@ -112,7 +112,7 @@ class TPUVariableMixin(object):
|
|||||||
def handle(self):
|
def handle(self):
|
||||||
# If we're in a tpu.rewrite(), return the replicated handle.
|
# If we're in a tpu.rewrite(), return the replicated handle.
|
||||||
tpu_context = enclosing_tpu_context()
|
tpu_context = enclosing_tpu_context()
|
||||||
if tpu_context is None:
|
if tpu_context is None or context.executing_eagerly():
|
||||||
return self._get_closest().handle
|
return self._get_closest().handle
|
||||||
else:
|
else:
|
||||||
return tpu_context.get_replicated_var_handle(self._handle_id,
|
return tpu_context.get_replicated_var_handle(self._handle_id,
|
||||||
|
@ -990,6 +990,39 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
|
array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
|
||||||
per_replica_results)
|
per_replica_results)
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=[
|
||||||
|
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
strategy_combinations.tpu_strategy,
|
||||||
|
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||||
|
],
|
||||||
|
mode=["eager"]))
|
||||||
|
def testInitScope(self, distribution):
|
||||||
|
|
||||||
|
class C(object):
|
||||||
|
pass
|
||||||
|
|
||||||
|
obj = C()
|
||||||
|
obj.w = None
|
||||||
|
obj.v = None
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def assign():
|
||||||
|
with ops.init_scope():
|
||||||
|
if obj.w is None:
|
||||||
|
obj.w = variables_lib.Variable(
|
||||||
|
0, aggregation=variables_lib.VariableAggregation.MEAN)
|
||||||
|
obj.v = variables_lib.Variable(
|
||||||
|
obj.w.read_value(),
|
||||||
|
aggregation=variables_lib.VariableAggregation.MEAN)
|
||||||
|
|
||||||
|
return obj.v.assign_add(2)
|
||||||
|
|
||||||
|
per_replica_results = self.evaluate(
|
||||||
|
distribution.experimental_local_results(distribution.run(assign)))
|
||||||
|
self.assertAllEqual([2, 2], per_replica_results)
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.combine(
|
combinations.combine(
|
||||||
distribution=[
|
distribution=[
|
||||||
|
Loading…
Reference in New Issue
Block a user