From b70b14a4624f6a06d394001f8d1fcd5d6d25d531 Mon Sep 17 00:00:00 2001
From: Ruoxin Sang
Date: Mon, 23 Mar 2020 10:17:25 -0700
Subject: [PATCH] Only create TPU replicated variable handle in graph mode.

PiperOrigin-RevId: 302461289
Change-Id: I4923d3db3e59db45e95a7a52c0c60fb42b3ee911
---
 tensorflow/python/distribute/tpu_values.py  |  2 +-
 tensorflow/python/distribute/values_test.py | 33 +++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py
index 5ac2a11f82b..9d0719f34b4 100644
--- a/tensorflow/python/distribute/tpu_values.py
+++ b/tensorflow/python/distribute/tpu_values.py
@@ -112,7 +112,7 @@ class TPUVariableMixin(object):
   def handle(self):
     # If we're in a tpu.rewrite(), return the replicated handle.
     tpu_context = enclosing_tpu_context()
-    if tpu_context is None:
+    if tpu_context is None or context.executing_eagerly():
       return self._get_closest().handle
     else:
       return tpu_context.get_replicated_var_handle(self._handle_id,
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index 0c7b3dffd2b..685dbaf4d40 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -990,6 +990,39 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
         array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
         per_replica_results)
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.central_storage_strategy_with_two_gpus,
+          ],
+          mode=["eager"]))
+  def testInitScope(self, distribution):
+
+    class C(object):
+      pass
+
+    obj = C()
+    obj.w = None
+    obj.v = None
+
+    @def_function.function
+    def assign():
+      with ops.init_scope():
+        if obj.w is None:
+          obj.w = variables_lib.Variable(
+              0, aggregation=variables_lib.VariableAggregation.MEAN)
+          obj.v = variables_lib.Variable(
+              obj.w.read_value(),
+              aggregation=variables_lib.VariableAggregation.MEAN)
+
+      return obj.v.assign_add(2)
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(distribution.run(assign)))
+    self.assertAllEqual([2, 2], per_replica_results)
+
   @combinations.generate(
       combinations.combine(
           distribution=[
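
For illustration, a minimal sketch of the handle-selection logic the tpu_values.py
change introduces. The stand-in classes below are hypothetical and deliberately avoid
TensorFlow internals (enclosing_tpu_context(), context.executing_eagerly(), the real
TPUVariableMixin); only the control flow of the patched condition is shown.

# Sketch only: hypothetical stand-ins, runnable without TensorFlow.
class _FakeTpuContext(object):
  """Plays the role of the object returned by enclosing_tpu_context()."""

  def get_replicated_var_handle(self, handle_id, vars_):
    return "replicated-handle:" + handle_id


class _FakeVariable(object):
  handle = "per-replica-handle"


class HandleSketch(object):
  """Mirrors the patched TPUVariableMixin.handle property."""

  def __init__(self, tpu_context, executing_eagerly):
    self._tpu_context = tpu_context          # None outside tpu.rewrite()
    self._executing_eagerly = executing_eagerly
    self._handle_id = "var0"
    self._vars = [_FakeVariable()]

  def _get_closest(self):
    return self._vars[0]

  @property
  def handle(self):
    # After this patch: even when a TPU replication context exists, fall back
    # to the per-replica handle while executing eagerly; the replicated handle
    # is only built in graph mode, i.e. inside tpu.rewrite().
    if self._tpu_context is None or self._executing_eagerly:
      return self._get_closest().handle
    return self._tpu_context.get_replicated_var_handle(
        self._handle_id, self._vars)


print(HandleSketch(_FakeTpuContext(), executing_eagerly=True).handle)
# -> per-replica-handle
print(HandleSketch(_FakeTpuContext(), executing_eagerly=False).handle)
# -> replicated-handle:var0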