diff --git a/tensorflow/python/distribute/tf_function_test.py b/tensorflow/python/distribute/tf_function_test.py index 967abebdfb3..38d070c2788 100644 --- a/tensorflow/python/distribute/tf_function_test.py +++ b/tensorflow/python/distribute/tf_function_test.py @@ -145,7 +145,7 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase): @combinations.generate( combinations.combine( distribution=strategy_combinations.all_strategies, mode=["eager"])) - def testRetraceOnSaving(self, distribution): + def testRetraceOnSavingFirstTraceInScope(self, distribution): with distribution.scope(): v = variables.Variable(0.) @@ -167,6 +167,31 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase): func() self.assertEqual(prev_tracing_count, tracing_count[0]) + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.all_strategies, mode=["eager"])) + def testRetraceOnSavingFirstTraceOutsideScope(self, distribution): + with distribution.scope(): + v = variables.Variable(0.) + + tracing_count = [0] + + @def_function.function + def func(): + tracing_count[0] += 1 + return v + 1. + + func() + prev_tracing_count = tracing_count[0] + with save_context.save_context(save_options.SaveOptions()): + func() + self.assertEqual(prev_tracing_count + 1, tracing_count[0]) + + prev_tracing_count = tracing_count[0] + with save_context.save_context(save_options.SaveOptions()): + func() + self.assertEqual(prev_tracing_count, tracing_count[0]) + if __name__ == "__main__": v2_compat.enable_v2_behavior() diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index ff78ee4f603..334fde3f19d 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -581,10 +581,18 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertIs(func_a, func_b) - with save_context.save_context(save_options.SaveOptions()): + with save_context.save_context( + save_options.SaveOptions(experimental_variable_policy=save_options + .VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)): func_c = func.get_concrete_function(constant_op.constant(2.)) + with save_context.save_context( + save_options.SaveOptions( + experimental_variable_policy=save_options.VariablePolicy.NONE)): + func_d = func.get_concrete_function(constant_op.constant(2.)) + self.assertIs(func_a, func_c) + self.assertIsNot(func_a, func_d) def testInitializationInNestedCall(self): v_holder = [] diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index afd5cb31374..60dd3f17024 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2941,10 +2941,6 @@ class Function(object): self._experimental_compile = experimental_compile self._experimental_follow_type_hints = experimental_follow_type_hints - # A boolean indicating whether the function has been traced with - # distribution strategy. - self._traced_with_distribution_strategy = False - def __call__(self, *args, **kwargs): """Calls a graph function specialized to the inputs.""" with self._lock: @@ -3177,18 +3173,13 @@ class Function(object): except (AttributeError, IndexError): pass - # If the function has been traced with a distribution strategy, it might - # need to be retraced at saving time as DistributedVariable created under - # distribution strategy may want different tracing behavior at training and - # saving, e.g, it wants to resolve to the primary component at saving time, - # but wants resolve to the component residing in the current device at - # training time. We achieve this by adding variable_policy to the function - # cache key. - if save_context.in_save_context( - ) and self._traced_with_distribution_strategy: + if save_context.in_save_context(): variable_policy = ( save_context.get_save_options().experimental_variable_policy) else: + # With EXPAND_DISTRIBUTED_VARIABLES the variables have the same behavior + # in and out of saving. We use EXPAND_DISTRIBUTED_VARIABLES so that if the + # user saves with it, there's no need to retrace the functions. variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES return (parent_graph, device_functions, colocation_stack, @@ -3380,9 +3371,6 @@ class Function(object): graph_function = self._create_graph_function(args, kwargs) self._function_cache.primary[cache_key] = graph_function - if ops.get_default_graph()._distribution_strategy_stack: - self._traced_with_distribution_strategy = True - return graph_function, filtered_flat_args diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index 522720096af..7521a5d8593 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -745,7 +745,6 @@ class SavingOptionsTest(test.TestCase): root.f = def_function.function( lambda x: 2. * x, input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) - root.f(constant_op.constant(1.)) save_dir = os.path.join(self.get_temp_dir(), "saved_model") options = save_options.SaveOptions(function_aliases={ "my_func": root.f,