Disallow saving if the function cannot be used for inference

With distribution strategy, traced ConcreteFunctions may contain training-specific logic that assumes the variables are distributed variables. Such functions cannot be used for inference. Since we do not know whether such a ConcreteFunction will be saved for inference, we always mark it as unsaveable unless it is traced under a save context.

The user can pass a tf.function instead, which can be retraced at saving time.
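For illustration, a minimal sketch of the two paths (not part of this change; assumes a TF 2.x build that includes it, MirroredStrategy on a single host, and an arbitrary export path):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable([1., 1.])

@tf.function(input_signature=[tf.TensorSpec([2])])
def update(value):
  v.assign(value)  # updating a distributed variable is training-specific
  return {"value": tf.convert_to_tensor(v)}

m = tf.Module()
m.v = v

# Unsaveable: this ConcreteFunction was traced outside a save context.
#   tf.saved_model.save(m, "/tmp/m", signatures=update.get_concrete_function())

# Saveable: pass the tf.function itself so it is retraced under the save
# context during saving.
tf.saved_model.save(m, "/tmp/m", signatures=update)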

Impacted usages:
- MultiWorkerMirroredStrategy
  - Reading a synchronization=ON_READ variable, e.g. the moving statistics of a batch norm layer.
- MultiWorkerMirroredStrategy, MirroredStrategy, TPUStrategy
  - Updating a variable.
  - Reading a synchronization=ON_READ aggregation=SUM variable.
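As one concrete instance from the list above, a hedged sketch of reading a synchronization=ON_READ, aggregation=SUM variable under MirroredStrategy (illustrative names and values):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  total = tf.Variable(
      0.,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

@tf.function(input_signature=[])
def read_total():
  # A cross-replica read of an ON_READ/SUM variable inserts an aggregating
  # reduce across replicas, so the resulting trace is marked unsaveable.
  return {"total": tf.convert_to_tensor(total)}

concrete = read_total.get_concrete_function()  # traced outside a save context
# Passing `concrete` to tf.saved_model.save(..., signatures=...) raises a
# ValueError; passing `read_total` itself works, as in the sketch above.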

It's TBD whether we also need to mark functions that use packed handles as unsaveable. They do contain TPU:0 device annotations, but with soft placement that may not be a problem.

PiperOrigin-RevId: 337438256
Change-Id: Ie89d0d6beb3e71d3ebbb867d1f91f2953468840c
Ran Chen 2020-10-15 21:00:35 -07:00 committed by TensorFlower Gardener
parent 27cdf9fa33
commit 380478ff5f
8 changed files with 106 additions and 15 deletions


@ -68,10 +68,12 @@ py_library(
        ":collective_util",
        ":cross_device_utils",
        ":device_util",
        ":distribute_utils",
        ":ps_values",
        ":reduce_util",
        ":tpu_values",
        ":values",
        ":values_util",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:device_lib",
        "//tensorflow/python:framework_ops",
@ -82,8 +84,10 @@ py_library(
"//tensorflow/python:tensor_util",
"//tensorflow/python:tf_export",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:executor",
"//tensorflow/tools/docs:doc_controls",
"@enum34_archive//:enum",
"@six_archive//:six",
],
)
@ -1175,6 +1179,7 @@ distribute_py_test(
"noasan", # TODO(b/337374867) fails with -fsanitize=null
],
deps = [
":collective_all_reduce_strategy",
":combinations",
":distribute_lib",
":distribute_utils",
@ -1186,7 +1191,6 @@ distribute_py_test(
        ":tpu_strategy",
        ":tpu_values",
        ":values",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:constant_op",
@ -1196,7 +1200,6 @@ distribute_py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:indexed_slices",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:saver",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
@ -1207,14 +1210,12 @@ distribute_py_test(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/saved_model:save",
"//tensorflow/python/saved_model:save_context",
"//tensorflow/python/saved_model:save_options",
"//tensorflow/python/saved_model/model_utils:mode_keys",
"//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/types",
"@absl_py//absl/testing:parameterized",
],


@ -34,6 +34,7 @@ from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor as executor_lib
@ -1063,6 +1064,7 @@ class CollectiveAllReduce(CrossDeviceOps):
  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            experimental_hints):
    values_util.mark_as_unsaveable()
    all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
                                         experimental_hints)[0]
    devices = get_devices_from(destinations)
@ -1094,6 +1096,7 @@ class CollectiveAllReduce(CrossDeviceOps):
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  experimental_hints):
    values_util.mark_as_unsaveable()
    all_devices_match = _all_devices_match(value_destination_pairs)
    if all_devices_match:
      return self._batch_all_reduce(reduce_op,
@ -1223,6 +1226,7 @@ class CollectiveAllReduce(CrossDeviceOps):
  def _gather_implementation(self, per_replica_value, destinations, axis,
                             experimental_hints):
    values_util.mark_as_unsaveable()
    all_gathered = self._batch_all_gather([per_replica_value], axis,
                                          experimental_hints)[0]
    devices = get_devices_from(destinations)


@ -874,6 +874,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
    Returns:
      Updated variable or `tf.Operation`.
    """
    values_util.mark_as_unsaveable()
    return self.distribute_strategy.extended.update(
        self, update_fn, args=(value,), kwargs=kwargs, group=True)
@ -1155,6 +1156,7 @@ class SyncOnReadVariable(DistributedVariable):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_sub_cross_replica(
            self, value, read_value=read_value)
      else:
@ -1167,6 +1169,7 @@ class SyncOnReadVariable(DistributedVariable):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            self, value, read_value=read_value)
      else:
@ -1179,6 +1182,7 @@ class SyncOnReadVariable(DistributedVariable):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(
            self, value, read_value=read_value)
      else:
@ -1243,7 +1247,8 @@ class SyncOnReadVariable(DistributedVariable):
      # Consider returning a tensor value here to make the return value of
      # _get_cross_replica consistent.
      return self._get_replica(0)
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return self._distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
@ -1400,9 +1405,10 @@ class OnReadPolicy(VariablePolicy):
  def _get_cross_replica(self, var):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return var._get_replica(0)  # pylint: disable=protected-access
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      return var.distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
          var,
          axis=None)
@ -1421,6 +1427,7 @@ class OnReadPolicy(VariablePolicy):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_sub_cross_replica(
            var, value, read_value=read_value)
      else:
@ -1434,6 +1441,7 @@ class OnReadPolicy(VariablePolicy):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            var, value, read_value=read_value)
      else:
@ -1445,6 +1453,7 @@ class OnReadPolicy(VariablePolicy):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(var, value,
                                                         read_value=read_value)
      else:


@ -56,6 +56,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.training import saver as saver_lib
@ -825,6 +826,67 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
  # pylint: enable=g-long-lambda
  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("n/a: not applicable to AggregatingVariable")
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")
    with distribution.scope():
      v = variables_lib.Variable([1., 1.],
                                 synchronization=synchronization,
                                 aggregation=aggregation)
    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())
    export_dir = self.get_temp_dir()

    def _assert_unsaveable(f):
      # Ignore if it cannot be traced. Certain combinations are not supported
      # yet or not allowed.
      try:
        f = def_function.function(f).get_concrete_function()
      except (NotImplementedError, ValueError):
        return
      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
        save.save(v, export_dir, signatures=f)

    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))
    # Reading an ON_READ variable should be unsaveable if either:
    # 1) the strategy is CollectiveAllReduceStrategy and aggregation is
    #    MEAN/SUM, or
    # 2) aggregation is SUM.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        (aggregation == variables_lib.VariableAggregation.SUM or
         (isinstance(distribution.extended,
                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
          and aggregation == variables_lib.VariableAggregation.MEAN))):
      _assert_unsaveable(v.read_value)
      _assert_unsaveable(v.value)
      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
    else:
      # Otherwise reading a variable should be saveable.

      @def_function.function
      def f():
        v.read_value()
        v.value()
        return ops.convert_to_tensor(v)

      with self.cached_session():
        save.save(v, export_dir, signatures=f.get_concrete_function())

  @combinations.generate(
      combinations.combine(


@ -371,3 +371,23 @@ def is_saving_non_distributed():
  options = save_context.get_save_options()
  return (options.experimental_variable_policy !=
          save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)


def mark_as_unsaveable():
  """Marks the function as unsaveable if not inside save context."""
  if ops.inside_function() and not save_context.in_save_context():
    ops.get_default_graph().mark_as_unsaveable("""
ConcreteFunction that uses distributed variables in certain ways cannot be saved.
If you're saving with

tf.saved_model.save(..., signatures=f.get_concrete_function())

do

@tf.function(input_signature=...)
def f_with_input_signature():
  ...

tf.saved_model.save(..., signatures=f_with_input_signature)

instead.""")


@ -612,7 +612,6 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/ops/numpy_ops:numpy",
"//tensorflow/python/saved_model:save_context",
"//tensorflow/python/saved_model:save_options",
"//third_party/py/numpy",
"@six_archive//:six",
],


@ -589,7 +589,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
            experimental_variable_policy=save_options.VariablePolicy.NONE)):
      func_d = func.get_concrete_function(constant_op.constant(2.))
    self.assertIs(func_a, func_c)
    self.assertIsNot(func_a, func_c)
    self.assertIsNot(func_a, func_d)

  def testInitializationInNestedCall(self):


@ -67,7 +67,6 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import lazy_loader
@ -3177,10 +3176,7 @@ class Function(object):
      variable_policy = (
          save_context.get_save_options().experimental_variable_policy)
    else:
      # With EXPAND_DISTRIBUTED_VARIABLES the variables have the same behavior
      # in and out of saving. We use EXPAND_DISTRIBUTED_VARIABLES so that if the
      # user saves with it, there's no need to retrace the functions.
      variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
      variable_policy = None
    return (parent_graph, device_functions, colocation_stack,
            in_cross_replica_context, variable_policy, xla_context_id)
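The last two hunks make the variable policy in the tracing cache key resolve to None outside a save context, so a trace made for training is no longer reused when saving. A hedged sketch of the observable effect, using the same internal modules as the test above:

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options

@def_function.function
def f(x):
  return x + 1.

func_a = f.get_concrete_function(constant_op.constant(2.))  # no save context

options = save_options.SaveOptions(
    experimental_variable_policy=save_options.VariablePolicy
    .EXPAND_DISTRIBUTED_VARIABLES)
with save_context.save_context(options):
  func_b = f.get_concrete_function(constant_op.constant(2.))

# Before this change the two lookups shared a cache entry; now the differing
# variable policy forces a retrace under the save context.
assert func_a is not func_b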