diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py
index 5ac2a11f82b..9d0719f34b4 100644
--- a/tensorflow/python/distribute/tpu_values.py
+++ b/tensorflow/python/distribute/tpu_values.py
@@ -112,7 +112,7 @@ class TPUVariableMixin(object):
   def handle(self):
     # If we're in a tpu.rewrite(), return the replicated handle.
     tpu_context = enclosing_tpu_context()
-    if tpu_context is None:
+    if tpu_context is None or context.executing_eagerly():
       return self._get_closest().handle
     else:
       return tpu_context.get_replicated_var_handle(self._handle_id,
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index 0c7b3dffd2b..685dbaf4d40 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -990,6 +990,39 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
         array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
         per_replica_results)
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.central_storage_strategy_with_two_gpus,
+          ],
+          mode=["eager"]))
+  def testInitScope(self, distribution):
+
+    class C(object):
+      pass
+
+    obj = C()
+    obj.w = None
+    obj.v = None
+
+    @def_function.function
+    def assign():
+      with ops.init_scope():
+        if obj.w is None:
+          obj.w = variables_lib.Variable(
+              0, aggregation=variables_lib.VariableAggregation.MEAN)
+          obj.v = variables_lib.Variable(
+              obj.w.read_value(),
+              aggregation=variables_lib.VariableAggregation.MEAN)
+
+      return obj.v.assign_add(2)
+
+    per_replica_results = self.evaluate(
+        distribution.experimental_local_results(distribution.run(assign)))
+    self.assertAllEqual([2, 2], per_replica_results)
+
   @combinations.generate(
       combinations.combine(
           distribution=[
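
For context, a minimal standalone sketch of the pattern the new `testInitScope` test exercises: lazily creating a strategy-aware variable inside `ops.init_scope()` from within a `tf.function`, then updating it per replica. This sketch is not part of the diff; it uses public TF2 APIs (`tf.init_scope`, `tf.distribute.MirroredStrategy`, `strategy.run`) in place of the internal test modules, and the `Holder` class and replica count are illustrative.

```python
# Sketch only: mirrors the new testInitScope test using public TF2 APIs.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # illustrative; the test also covers TPU and central storage

class Holder(object):
  """Hypothetical container standing in for the test's class C."""
  pass

obj = Holder()
obj.v = None

@tf.function
def assign():
  # tf.init_scope lifts variable creation out of the traced graph, so the
  # variable is created once, eagerly, even though we are inside tf.function.
  with tf.init_scope():
    if obj.v is None:
      obj.v = tf.Variable(
          0, aggregation=tf.VariableAggregation.MEAN)
  return obj.v.assign_add(2)

# Each replica runs assign(); MEAN aggregation reconciles the per-replica
# updates, so every local result is 2 (e.g. (2, 2) with two replicas).
print(strategy.experimental_local_results(strategy.run(assign)))
```

The `tpu_values.py` change is what makes this work on TPU: inside `init_scope` the code is executing eagerly even though an enclosing TPU context object may still be present, so `handle` must fall back to the closest local handle rather than asking for a replicated handle.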