Always retrace in tf.saved_model.save
Previously we conditionally add variable policy to tf.function cache key. If the function wasn't traced under strategy.scope() nor strategy.run(), we assumed strategy wasn't involved. That avoid retracing when saving for non-distributed users. However, it turns out to have some bugs. Checking Graph._distribution_strategy_stack is not enough to tell if the function involves dist strategies, since the function can enter strategy.scope() inside the function body, or can use a distributed/sharded variable previously created under strategy. I believe it can be quite tricky to have it bug free while do not retrace for non dist strat users. Since the retracing cost for non dist strat users is usually reasonable, I'm inclined to just make it simple. One alternative is to set a "saving-sensitive" bit after the function is traced, but we also need to propagate this bits along function call stacks, which is complicated. PiperOrigin-RevId: 336471980 Change-Id: I4f2fbc6cbc2ea18f5c8afc1f8c1f985edf959047
This commit is contained in:
parent
66731f587f
commit
58febb4500
@ -145,7 +145,7 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies, mode=["eager"]))
|
||||
def testRetraceOnSaving(self, distribution):
|
||||
def testRetraceOnSavingFirstTraceInScope(self, distribution):
|
||||
with distribution.scope():
|
||||
v = variables.Variable(0.)
|
||||
|
||||
@ -167,6 +167,31 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
func()
|
||||
self.assertEqual(prev_tracing_count, tracing_count[0])
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies, mode=["eager"]))
|
||||
def testRetraceOnSavingFirstTraceOutsideScope(self, distribution):
|
||||
with distribution.scope():
|
||||
v = variables.Variable(0.)
|
||||
|
||||
tracing_count = [0]
|
||||
|
||||
@def_function.function
|
||||
def func():
|
||||
tracing_count[0] += 1
|
||||
return v + 1.
|
||||
|
||||
func()
|
||||
prev_tracing_count = tracing_count[0]
|
||||
with save_context.save_context(save_options.SaveOptions()):
|
||||
func()
|
||||
self.assertEqual(prev_tracing_count + 1, tracing_count[0])
|
||||
|
||||
prev_tracing_count = tracing_count[0]
|
||||
with save_context.save_context(save_options.SaveOptions()):
|
||||
func()
|
||||
self.assertEqual(prev_tracing_count, tracing_count[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
v2_compat.enable_v2_behavior()
|
||||
|
@ -581,10 +581,18 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertIs(func_a, func_b)
|
||||
|
||||
with save_context.save_context(save_options.SaveOptions()):
|
||||
with save_context.save_context(
|
||||
save_options.SaveOptions(experimental_variable_policy=save_options
|
||||
.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)):
|
||||
func_c = func.get_concrete_function(constant_op.constant(2.))
|
||||
|
||||
with save_context.save_context(
|
||||
save_options.SaveOptions(
|
||||
experimental_variable_policy=save_options.VariablePolicy.NONE)):
|
||||
func_d = func.get_concrete_function(constant_op.constant(2.))
|
||||
|
||||
self.assertIs(func_a, func_c)
|
||||
self.assertIsNot(func_a, func_d)
|
||||
|
||||
def testInitializationInNestedCall(self):
|
||||
v_holder = []
|
||||
|
@ -2941,10 +2941,6 @@ class Function(object):
|
||||
self._experimental_compile = experimental_compile
|
||||
self._experimental_follow_type_hints = experimental_follow_type_hints
|
||||
|
||||
# A boolean indicating whether the function has been traced with
|
||||
# distribution strategy.
|
||||
self._traced_with_distribution_strategy = False
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calls a graph function specialized to the inputs."""
|
||||
with self._lock:
|
||||
@ -3177,18 +3173,13 @@ class Function(object):
|
||||
except (AttributeError, IndexError):
|
||||
pass
|
||||
|
||||
# If the function has been traced with a distribution strategy, it might
|
||||
# need to be retraced at saving time as DistributedVariable created under
|
||||
# distribution strategy may want different tracing behavior at training and
|
||||
# saving, e.g, it wants to resolve to the primary component at saving time,
|
||||
# but wants resolve to the component residing in the current device at
|
||||
# training time. We achieve this by adding variable_policy to the function
|
||||
# cache key.
|
||||
if save_context.in_save_context(
|
||||
) and self._traced_with_distribution_strategy:
|
||||
if save_context.in_save_context():
|
||||
variable_policy = (
|
||||
save_context.get_save_options().experimental_variable_policy)
|
||||
else:
|
||||
# With EXPAND_DISTRIBUTED_VARIABLES the variables have the same behavior
|
||||
# in and out of saving. We use EXPAND_DISTRIBUTED_VARIABLES so that if the
|
||||
# user saves with it, there's no need to retrace the functions.
|
||||
variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
|
||||
|
||||
return (parent_graph, device_functions, colocation_stack,
|
||||
@ -3380,9 +3371,6 @@ class Function(object):
|
||||
graph_function = self._create_graph_function(args, kwargs)
|
||||
self._function_cache.primary[cache_key] = graph_function
|
||||
|
||||
if ops.get_default_graph()._distribution_strategy_stack:
|
||||
self._traced_with_distribution_strategy = True
|
||||
|
||||
return graph_function, filtered_flat_args
|
||||
|
||||
|
||||
|
@ -745,7 +745,6 @@ class SavingOptionsTest(test.TestCase):
|
||||
root.f = def_function.function(
|
||||
lambda x: 2. * x,
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||
root.f(constant_op.constant(1.))
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
options = save_options.SaveOptions(function_aliases={
|
||||
"my_func": root.f,
|
||||
|
Loading…
Reference in New Issue
Block a user