Disallow saving if the function cannot be used for inference
With distribution strategy, traced ConcreteFunctions may contain training-specific logic that assumes the variables are distributed variables. Such functions cannot be used for inference. Since we do not know whether a ConcreteFunction will be saved for inference, we always mark it as unsaveable unless it is traced under a save context. The user can pass a tf.function instead, which can be retraced during saving.

Impacted usages:
- MultiWorkerMirroredStrategy
  - Reading a synchronization=ON_READ variable, e.g. a batch norm layer.
- MultiWorkerMirroredStrategy, MirroredStrategy, TPUStrategy
  - Updating a variable.
  - Reading a synchronization=ON_READ aggregation=SUM variable.

It is TBD whether we also need to mark functions that use packed handles as unsaveable. They do contain TPU:0 device annotations, but with soft placement that may not be a problem.

PiperOrigin-RevId: 337438256
Change-Id: Ie89d0d6beb3e71d3ebbb867d1f91f2953468840c
This commit is contained in:
parent
27cdf9fa33
commit
380478ff5f
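
To make the behavior concrete, here is a minimal sketch of the failure and the workaround under a single-machine MirroredStrategy. The export path, variable shapes, and exact error text are illustrative, not taken from this commit:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.)

@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
def update(x):
  v.assign_add(x)  # variable update: the traced graph is training-specific
  return v.read_value()

# A ConcreteFunction traced outside a save context is marked unsaveable, so
# using it as a signature now raises ValueError.
try:
  tf.saved_model.save(v, "/tmp/model",
                      signatures=update.get_concrete_function())
except ValueError as e:
  print(e)  # suggests passing a tf.function with an input_signature instead

# Passing the tf.function itself works: it is retraced under the save
# context, where an inference-safe graph can be generated.
tf.saved_model.save(v, "/tmp/model", signatures=update)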
@@ -68,10 +68,12 @@ py_library(
         ":collective_util",
         ":cross_device_utils",
         ":device_util",
         ":distribute_utils",
         ":ps_values",
         ":reduce_util",
         ":tpu_values",
         ":values",
         ":values_util",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:device_lib",
         "//tensorflow/python:framework_ops",
@@ -82,8 +84,10 @@ py_library(
         "//tensorflow/python:tensor_util",
         "//tensorflow/python:tf_export",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:executor",
         "//tensorflow/tools/docs:doc_controls",
         "@enum34_archive//:enum",
         "@six_archive//:six",
     ],
 )
@@ -1175,6 +1179,7 @@ distribute_py_test(
         "noasan",  # TODO(b/337374867) fails with -fsanitize=null
     ],
     deps = [
         ":collective_all_reduce_strategy",
         ":combinations",
         ":distribute_lib",
         ":distribute_utils",
@@ -1186,7 +1191,6 @@ distribute_py_test(
         ":tpu_strategy",
         ":tpu_values",
         ":values",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:check_ops",
         "//tensorflow/python:constant_op",
@@ -1196,7 +1200,6 @@ distribute_py_test(
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:indexed_slices",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:random_ops",
         "//tensorflow/python:saver",
         "//tensorflow/python:sparse_ops",
         "//tensorflow/python:sparse_tensor",
@@ -1207,14 +1210,12 @@ distribute_py_test(
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/saved_model:save",
         "//tensorflow/python/saved_model:save_context",
         "//tensorflow/python/saved_model:save_options",
         "//tensorflow/python/saved_model/model_utils:mode_keys",
         "//tensorflow/python/tpu:tpu_lib",
         "//tensorflow/python/types",
         "@absl_py//absl/testing:parameterized",
     ],
@@ -34,6 +34,7 @@ from tensorflow.python.distribute import ps_values
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values as value_lib
+from tensorflow.python.distribute import values_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import executor as executor_lib
@@ -1063,6 +1064,7 @@ class CollectiveAllReduce(CrossDeviceOps):
 
   def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                             experimental_hints):
+    values_util.mark_as_unsaveable()
     all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
                                          experimental_hints)[0]
     devices = get_devices_from(destinations)
@@ -1094,6 +1096,7 @@ class CollectiveAllReduce(CrossDeviceOps):
 
   def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                   experimental_hints):
+    values_util.mark_as_unsaveable()
     all_devices_match = _all_devices_match(value_destination_pairs)
     if all_devices_match:
       return self._batch_all_reduce(reduce_op,
@@ -1223,6 +1226,7 @@ class CollectiveAllReduce(CrossDeviceOps):
 
   def _gather_implementation(self, per_replica_value, destinations, axis,
                              experimental_hints):
+    values_util.mark_as_unsaveable()
    all_gathered = self._batch_all_gather([per_replica_value], axis,
                                          experimental_hints)[0]
    devices = get_devices_from(destinations)
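
These implementations back Strategy.reduce and related cross-replica communication under MultiWorkerMirroredStrategy, so any tf.function traced through them outside a save context is now poisoned. A hedged sketch of code that reaches this path, assuming the default single-worker cluster resolver (shown only for shape):

import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()  # default resolver

@tf.function
def replica_sum(x):
  per_replica = strategy.run(lambda t: t * 2., args=(x,))
  # The cross-replica reduce traces collective ops into this graph via
  # CollectiveAllReduce.reduce_implementation, marking it unsaveable.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)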
@@ -874,6 +874,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
     Returns:
       Updated variable or `tf.Operation`.
     """
+    values_util.mark_as_unsaveable()
     return self.distribute_strategy.extended.update(
         self, update_fn, args=(value,), kwargs=kwargs, group=True)
 
@@ -1155,6 +1156,7 @@ class SyncOnReadVariable(DistributedVariable):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       if (ds_context.in_cross_replica_context() and
           not values_util.in_replica_update_context()):
+        values_util.mark_as_unsaveable()
         return values_util.on_read_assign_sub_cross_replica(
             self, value, read_value=read_value)
       else:
@@ -1167,6 +1169,7 @@ class SyncOnReadVariable(DistributedVariable):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       if (ds_context.in_cross_replica_context() and
           not values_util.in_replica_update_context()):
+        values_util.mark_as_unsaveable()
         return values_util.on_read_assign_add_cross_replica(
             self, value, read_value=read_value)
       else:
@@ -1179,6 +1182,7 @@ class SyncOnReadVariable(DistributedVariable):
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       if (ds_context.in_cross_replica_context() and
           not values_util.in_replica_update_context()):
+        values_util.mark_as_unsaveable()
         return values_util.on_read_assign_cross_replica(
             self, value, read_value=read_value)
       else:
@@ -1243,7 +1247,8 @@ class SyncOnReadVariable(DistributedVariable):
       # Consider returning a tensor value here to make the return value of
       # _get_cross_replica consistent.
       return self._get_replica(0)
-
+    if self._aggregation == vs.VariableAggregation.SUM:
+      values_util.mark_as_unsaveable()
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
       return self._distribute_strategy.reduce(
           reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
@@ -1400,9 +1405,10 @@ class OnReadPolicy(VariablePolicy):
   def _get_cross_replica(self, var):
     if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
       return var._get_replica(0)  # pylint: disable=protected-access
-
+    if self._aggregation == vs.VariableAggregation.SUM:
+      values_util.mark_as_unsaveable()
     with ds_context.enter_or_assert_strategy(var.distribute_strategy):
       return var.distribute_strategy.reduce(
           reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
           var,
           axis=None)
@@ -1421,6 +1427,7 @@ class OnReadPolicy(VariablePolicy):
     with ds_context.enter_or_assert_strategy(var.distribute_strategy):
       if (ds_context.in_cross_replica_context() and
           not values_util.in_replica_update_context()):
+        values_util.mark_as_unsaveable()
         return values_util.on_read_assign_sub_cross_replica(
             var, value, read_value=read_value)
       else:
@@ -1434,6 +1441,7 @@ class OnReadPolicy(VariablePolicy):
     with ds_context.enter_or_assert_strategy(var.distribute_strategy):
       if (ds_context.in_cross_replica_context() and
           not values_util.in_replica_update_context()):
+        values_util.mark_as_unsaveable()
         return values_util.on_read_assign_add_cross_replica(
             var, value, read_value=read_value)
       else:
@@ -1445,6 +1453,7 @@ class OnReadPolicy(VariablePolicy):
     with ds_context.enter_or_assert_strategy(var.distribute_strategy):
       if (ds_context.in_cross_replica_context() and
           not values_util.in_replica_update_context()):
+        values_util.mark_as_unsaveable()
         return values_util.on_read_assign_cross_replica(var, value,
                                                         read_value=read_value)
       else:
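
The SyncOnReadVariable and OnReadPolicy changes above cover synchronization=ON_READ variables, the kind a BatchNormalization layer uses for its moving statistics. A hedged sketch of a read that lands on the newly marked SUM path (variable name and shape are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # An ON_READ, SUM-aggregated variable; reading it in cross-replica context
  # triggers the reduce in _get_cross_replica, which is now marked unsaveable.
  counter = tf.Variable(
      0.,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

@tf.function
def read_counter():
  return counter.read_value()  # cross-replica read -> aggregation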
@@ -56,6 +56,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.saved_model import save
 from tensorflow.python.saved_model import save_context
 from tensorflow.python.saved_model import save_options
 from tensorflow.python.training import saver as saver_lib
@@ -825,6 +826,67 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
 
   # pylint: enable=g-long-lambda
 
+  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
+    if isinstance(distribution.extended,
+                  parameter_server_strategy.ParameterServerStrategyExtended):
+      self.skipTest("n/a: not applicable to AggregatingVariable")
+    if (isinstance(distribution,
+                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
+        and mode == "graph"):
+      self.skipTest("MWMS combinations tests do not work well in graph mode.")
+    with distribution.scope():
+      v = variables_lib.Variable([1., 1.],
+                                 synchronization=synchronization,
+                                 aggregation=aggregation)
+
+    with self.cached_session():
+      self.evaluate(variables_lib.global_variables_initializer())
+
+    export_dir = self.get_temp_dir()
+
+    def _assert_unsaveable(f):
+      # Ignore if it cannot be traced. Certain combinations are not supported
+      # yet or not allowed.
+      try:
+        f = def_function.function(f).get_concrete_function()
+      except (NotImplementedError, ValueError):
+        return
+      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
+        save.save(v, export_dir, signatures=f)
+
+    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
+    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
+    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
+    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))
+    # Reading an ON_READ variable should be unsaveable if either:
+    # 1) CollectiveAllReduceStrategy, and aggregation is MEAN/SUM.
+    # 2) aggregation is SUM.
+    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
+        (aggregation == variables_lib.VariableAggregation.SUM or
+         (isinstance(distribution.extended,
+                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
+          and aggregation == variables_lib.VariableAggregation.MEAN))):
+      _assert_unsaveable(v.read_value)
+      _assert_unsaveable(v.value)
+      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
+    else:
+      # Otherwise reading a variable should be saveable.
+
+      @def_function.function
+      def f():
+        v.read_value()
+        v.value()
+        return ops.convert_to_tensor(v)
+
+      with self.cached_session():
+        save.save(v, export_dir, signatures=f.get_concrete_function())
 
 
   @combinations.generate(
       combinations.combine(
@@ -371,3 +371,23 @@ def is_saving_non_distributed():
   options = save_context.get_save_options()
   return (options.experimental_variable_policy !=
           save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)
+
+
+def mark_as_unsaveable():
+  """Marks the function as unsaveable if not inside save context."""
+  if ops.inside_function() and not save_context.in_save_context():
+    ops.get_default_graph().mark_as_unsaveable("""
+ConcreteFunction that uses distributed variables in a certain way cannot be
+saved. If you're saving with
+
+tf.saved_model.save(..., signatures=f.get_concrete_function())
+
+do
+
+@tf.function(input_signature=...)
+def f_with_input_signature():
+  ...
+
+tf.saved_model.save(..., signatures=f_with_input_signature)
+
+instead.""")
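
For reference, the graph-level call used here already exists on FuncGraph: the reason string is recorded on the traced graph and is reported in the ValueError raised at save time. A sketch using the same internal API as the commit (internal, shown only to illustrate the mechanism; the reason string is made up):

import tensorflow as tf
from tensorflow.python.framework import ops

@tf.function
def g():
  # During tracing the default graph is a FuncGraph; marking it records a
  # saving error that tf.saved_model.save surfaces as a ValueError.
  ops.get_default_graph().mark_as_unsaveable("demo: not inference safe")
  return tf.constant(1.)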
@@ -612,7 +612,6 @@ py_library(
         "//tensorflow/python:util",
         "//tensorflow/python/ops/numpy_ops:numpy",
         "//tensorflow/python/saved_model:save_context",
-        "//tensorflow/python/saved_model:save_options",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
@@ -589,7 +589,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
             experimental_variable_policy=save_options.VariablePolicy.NONE)):
       func_d = func.get_concrete_function(constant_op.constant(2.))
 
-    self.assertIs(func_a, func_c)
+    self.assertIsNot(func_a, func_c)
     self.assertIsNot(func_a, func_d)
 
   def testInitializationInNestedCall(self):
@@ -67,7 +67,6 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.profiler import trace
 from tensorflow.python.saved_model import save_context
-from tensorflow.python.saved_model import save_options
 from tensorflow.python.util import compat
 from tensorflow.python.util import function_utils
 from tensorflow.python.util import lazy_loader
@@ -3177,10 +3176,7 @@ class Function(object):
       variable_policy = (
           save_context.get_save_options().experimental_variable_policy)
     else:
-      # With EXPAND_DISTRIBUTED_VARIABLES the variables have the same behavior
-      # in and out of saving. We use EXPAND_DISTRIBUTED_VARIABLES so that if the
-      # user saves with it, there's no need to retrace the functions.
-      variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
+      variable_policy = None
 
     return (parent_graph, device_functions, colocation_stack,
             in_cross_replica_context, variable_policy, xla_context_id)
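
Because variable_policy is part of this tracing cache key, and is now None outside a save context, the same tf.function traces to distinct ConcreteFunctions inside and outside saving, which is exactly what the def_function test change above pins down. A test-style sketch using the same internal helpers that appear in this commit:

import tensorflow as tf
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options

@tf.function
def func(x):
  return x * 2.

func_a = func.get_concrete_function(tf.constant(2.))  # variable_policy=None

with save_context.save_context(
    save_options.SaveOptions(
        experimental_variable_policy=save_options.VariablePolicy.NONE)):
  # A different cache key forces a retrace under the save context.
  func_c = func.get_concrete_function(tf.constant(2.))

assert func_a is not func_c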