Always retrace in tf.saved_model.save

Previously, we conditionally added the variable policy to the tf.function cache key: if the function wasn't traced under strategy.scope() or strategy.run(), we assumed no distribution strategy was involved. That avoided retracing at saving time for non-distributed users.
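As a rough sketch of that old rule (illustrative names only, not the actual function.py internals), the policy only affected the cache key when both conditions held:

    # Hedged sketch of the previous conditional keying; `in_save_context`,
    # `variable_policy` and `traced_with_strategy` stand in for the real
    # internals touched in the diff below.
    def old_cache_key_suffix(in_save_context, variable_policy,
                             traced_with_strategy):
      if in_save_context and traced_with_strategy:
        # Saving a strategy-traced function: key on the policy, which forces
        # a separate trace per variable policy.
        return variable_policy
      # Otherwise the key is unchanged, so saving reuses the existing trace.
      return None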

However, that approach turns out to have bugs. Checking Graph._distribution_strategy_stack is not enough to tell whether a function involves a distribution strategy, since the function can enter strategy.scope() inside its body, or can use a distributed/sharded variable previously created under a strategy.
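For example (a hedged sketch; MirroredStrategy is just a stand-in for any strategy), both functions below involve a distributed variable even though neither is traced under an enclosing strategy.scope():

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    with strategy.scope():
      v = tf.Variable(1.0)  # distributed variable created ahead of time

    @tf.function
    def uses_previously_created_variable(x):
      # Traced outside strategy.scope(), but captures the distributed `v`.
      return x + v

    @tf.function
    def enters_scope_inside_body(x):
      # The scope is only entered inside the function body.
      with strategy.scope():
        return x * 2.0

    uses_previously_created_variable(tf.constant(1.0))
    enters_scope_inside_body(tf.constant(1.0))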

I believe it is quite tricky to make this bug-free while still avoiding retracing for users who don't use a distribution strategy. Since the retracing cost for those users is usually reasonable, I'm inclined to keep it simple and always retrace when saving.
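Concretely, the simplification matches the function.py change below: inside a save context the variable policy always becomes part of the cache key, so a non-distributed function pays at most one extra trace per distinct policy used while saving. A rough sketch, again with illustrative names:

    def new_cache_key_suffix(in_save_context, variable_policy):
      if in_save_context:
        # Always key on the policy while saving, strategy or not.
        return variable_policy
      # Outside saving, default to EXPAND_DISTRIBUTED_VARIABLES so that a
      # save using that policy can reuse the existing trace.
      return "EXPAND_DISTRIBUTED_VARIABLES"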

One alternative is to set a "saving-sensitive" bit after the function is traced, but we would also need to propagate that bit along function call stacks, which is complicated.

PiperOrigin-RevId: 336471980
Change-Id: I4f2fbc6cbc2ea18f5c8afc1f8c1f985edf959047
Ran Chen 2020-10-10 12:12:04 -07:00 committed by TensorFlower Gardener
parent 66731f587f
commit 58febb4500
4 changed files with 39 additions and 19 deletions


@@ -145,7 +145,7 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
           distribution=strategy_combinations.all_strategies, mode=["eager"]))
-  def testRetraceOnSaving(self, distribution):
+  def testRetraceOnSavingFirstTraceInScope(self, distribution):
     with distribution.scope():
       v = variables.Variable(0.)
@@ -167,6 +167,31 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase):
       func()
     self.assertEqual(prev_tracing_count, tracing_count[0])

+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies, mode=["eager"]))
+  def testRetraceOnSavingFirstTraceOutsideScope(self, distribution):
+    with distribution.scope():
+      v = variables.Variable(0.)
+
+    tracing_count = [0]
+
+    @def_function.function
+    def func():
+      tracing_count[0] += 1
+      return v + 1.
+
+    func()
+    prev_tracing_count = tracing_count[0]
+    with save_context.save_context(save_options.SaveOptions()):
+      func()
+    self.assertEqual(prev_tracing_count + 1, tracing_count[0])
+
+    prev_tracing_count = tracing_count[0]
+    with save_context.save_context(save_options.SaveOptions()):
+      func()
+    self.assertEqual(prev_tracing_count, tracing_count[0])
+

 if __name__ == "__main__":
   v2_compat.enable_v2_behavior()


@@ -581,10 +581,18 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
     self.assertIs(func_a, func_b)

-    with save_context.save_context(save_options.SaveOptions()):
+    with save_context.save_context(
+        save_options.SaveOptions(experimental_variable_policy=save_options
+                                 .VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)):
       func_c = func.get_concrete_function(constant_op.constant(2.))
+    with save_context.save_context(
+        save_options.SaveOptions(
+            experimental_variable_policy=save_options.VariablePolicy.NONE)):
+      func_d = func.get_concrete_function(constant_op.constant(2.))
     self.assertIs(func_a, func_c)
+    self.assertIsNot(func_a, func_d)

   def testInitializationInNestedCall(self):
     v_holder = []


@@ -2941,10 +2941,6 @@ class Function(object):
     self._experimental_compile = experimental_compile
     self._experimental_follow_type_hints = experimental_follow_type_hints

-    # A boolean indicating whether the function has been traced with
-    # distribution strategy.
-    self._traced_with_distribution_strategy = False
-
   def __call__(self, *args, **kwargs):
     """Calls a graph function specialized to the inputs."""
     with self._lock:
@@ -3177,18 +3173,13 @@ class Function(object):
     except (AttributeError, IndexError):
       pass

-    # If the function has been traced with a distribution strategy, it might
-    # need to be retraced at saving time as DistributedVariable created under
-    # distribution strategy may want different tracing behavior at training and
-    # saving, e.g, it wants to resolve to the primary component at saving time,
-    # but wants resolve to the component residing in the current device at
-    # training time. We achieve this by adding variable_policy to the function
-    # cache key.
-    if save_context.in_save_context(
-    ) and self._traced_with_distribution_strategy:
+    if save_context.in_save_context():
       variable_policy = (
           save_context.get_save_options().experimental_variable_policy)
     else:
+      # With EXPAND_DISTRIBUTED_VARIABLES the variables have the same behavior
+      # in and out of saving. We use EXPAND_DISTRIBUTED_VARIABLES so that if the
+      # user saves with it, there's no need to retrace the functions.
       variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES

     return (parent_graph, device_functions, colocation_stack,
@@ -3380,9 +3371,6 @@ class Function(object):
         graph_function = self._create_graph_function(args, kwargs)
         self._function_cache.primary[cache_key] = graph_function

-        if ops.get_default_graph()._distribution_strategy_stack:
-          self._traced_with_distribution_strategy = True
-
       return graph_function, filtered_flat_args


@@ -745,7 +745,6 @@ class SavingOptionsTest(test.TestCase):
     root.f = def_function.function(
         lambda x: 2. * x,
         input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
-    root.f(constant_op.constant(1.))
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     options = save_options.SaveOptions(function_aliases={
         "my_func": root.f,