From 380478ff5faf28bce987b3924677b214256adc3f Mon Sep 17 00:00:00 2001
From: Ran Chen
Date: Thu, 15 Oct 2020 21:00:35 -0700
Subject: [PATCH] Disallow saving if the function cannot be used for inference

With a distribution strategy, traced ConcreteFunctions may contain
training-specific logic that assumes the variable is a distributed
variable. Such functions cannot be used for inference. Since we do not
know whether such a ConcreteFunction will be saved for inference, we
always mark it as unsaveable unless it is traced under a save context.
The user can pass a tf.function instead, which can be retraced when
saving.

Impacted usages:
- MultiWorkerMirroredStrategy
  - Reading a synchronization=ON_READ variable, e.g. a batch norm layer.
- MultiWorkerMirroredStrategy, MirroredStrategy, TPUStrategy
  - Updating a variable.
  - Reading a synchronization=ON_READ aggregation=SUM variable.

It is TBD whether we also need to mark functions that use a packed
handle as unsaveable. They do contain TPU:0 device annotations, but
with soft placement that may not be a problem.
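As an illustration, the pattern the new error message steers users
toward looks like this (a minimal sketch; the tf.Module object, the
export path, and the variable shape are hypothetical, not part of this
change):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      v = tf.Variable([1., 1.])

    obj = tf.Module()
    obj.v = v

    # Unsaveable: the ConcreteFunction is traced here, outside a save
    # context, so the distributed assign marks its graph unsaveable and
    # tf.saved_model.save raises ValueError.
    #   f = tf.function(lambda: v.assign_add([1., 1.])).get_concrete_function()
    #   tf.saved_model.save(obj, '/tmp/m', signatures=f)

    # Saveable: pass a tf.function with an input signature instead; it is
    # retraced under the save context, where the variables behave like
    # ordinary non-distributed variables.
    @tf.function(input_signature=[])
    def f_with_input_signature():
      v.assign_add([1., 1.])
      return {'value': v.read_value()}

    tf.saved_model.save(obj, '/tmp/m', signatures=f_with_input_signature)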
"//tensorflow/python/saved_model:save_options", - "//tensorflow/python/saved_model/model_utils:mode_keys", - "//tensorflow/python/tpu:tpu_lib", "//tensorflow/python/types", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index 6d2a4e16f84..f4c9101ea2c 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -34,6 +34,7 @@ from tensorflow.python.distribute import ps_values from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import tpu_values from tensorflow.python.distribute import values as value_lib +from tensorflow.python.distribute import values_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import executor as executor_lib @@ -1063,6 +1064,7 @@ class CollectiveAllReduce(CrossDeviceOps): def reduce_implementation(self, reduce_op, per_replica_value, destinations, experimental_hints): + values_util.mark_as_unsaveable() all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value], experimental_hints)[0] devices = get_devices_from(destinations) @@ -1094,6 +1096,7 @@ class CollectiveAllReduce(CrossDeviceOps): def batch_reduce_implementation(self, reduce_op, value_destination_pairs, experimental_hints): + values_util.mark_as_unsaveable() all_devices_match = _all_devices_match(value_destination_pairs) if all_devices_match: return self._batch_all_reduce(reduce_op, @@ -1223,6 +1226,7 @@ class CollectiveAllReduce(CrossDeviceOps): def _gather_implementation(self, per_replica_value, destinations, axis, experimental_hints): + values_util.mark_as_unsaveable() all_gathered = self._batch_all_gather([per_replica_value], axis, experimental_hints)[0] devices = get_devices_from(destinations) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 1464941523b..b90fd24b6e0 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -874,6 +874,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, Returns: Updated variable or `tf.Operation`. 
""" + values_util.mark_as_unsaveable() return self.distribute_strategy.extended.update( self, update_fn, args=(value,), kwargs=kwargs, group=True) @@ -1155,6 +1156,7 @@ class SyncOnReadVariable(DistributedVariable): with ds_context.enter_or_assert_strategy(self._distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_sub_cross_replica( self, value, read_value=read_value) else: @@ -1167,6 +1169,7 @@ class SyncOnReadVariable(DistributedVariable): with ds_context.enter_or_assert_strategy(self._distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_add_cross_replica( self, value, read_value=read_value) else: @@ -1179,6 +1182,7 @@ class SyncOnReadVariable(DistributedVariable): with ds_context.enter_or_assert_strategy(self._distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_cross_replica( self, value, read_value=read_value) else: @@ -1243,7 +1247,8 @@ class SyncOnReadVariable(DistributedVariable): # Consider returning a tensor value here to make the return value of # _get_cross_replica consistent. return self._get_replica(0) - + if self._aggregation == vs.VariableAggregation.SUM: + values_util.mark_as_unsaveable() with ds_context.enter_or_assert_strategy(self._distribute_strategy): return self._distribute_strategy.reduce( reduce_util.ReduceOp.from_variable_aggregation(self._aggregation), @@ -1400,9 +1405,10 @@ class OnReadPolicy(VariablePolicy): def _get_cross_replica(self, var): if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: return var._get_replica(0) # pylint: disable=protected-access - + if self._aggregation == vs.VariableAggregation.SUM: + values_util.mark_as_unsaveable() with ds_context.enter_or_assert_strategy(var.distribute_strategy): - return var.distribute_strategy.reduce( + return var.distribute_strategy.reduce( reduce_util.ReduceOp.from_variable_aggregation(self._aggregation), var, axis=None) @@ -1421,6 +1427,7 @@ class OnReadPolicy(VariablePolicy): with ds_context.enter_or_assert_strategy(var.distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_sub_cross_replica( var, value, read_value=read_value) else: @@ -1434,6 +1441,7 @@ class OnReadPolicy(VariablePolicy): with ds_context.enter_or_assert_strategy(var.distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_add_cross_replica( var, value, read_value=read_value) else: @@ -1445,6 +1453,7 @@ class OnReadPolicy(VariablePolicy): with ds_context.enter_or_assert_strategy(var.distribute_strategy): if (ds_context.in_cross_replica_context() and not values_util.in_replica_update_context()): + values_util.mark_as_unsaveable() return values_util.on_read_assign_cross_replica(var, value, read_value=read_value) else: diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index 8a9f0acbd75..1f9bef137d5 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -56,6 +56,7 @@ from tensorflow.python.ops 
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index 8a9f0acbd75..1f9bef137d5 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -56,6 +56,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.saved_model import save
 from tensorflow.python.saved_model import save_context
 from tensorflow.python.saved_model import save_options
 from tensorflow.python.training import saver as saver_lib
@@ -825,6 +826,67 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
 
   # pylint: enable=g-long-lambda
 
+  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
+    if isinstance(distribution.extended,
+                  parameter_server_strategy.ParameterServerStrategyExtended):
+      self.skipTest("n/a: not applicable to AggregatingVariable")
+    if (isinstance(distribution,
+                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
+        and mode == "graph"):
+      self.skipTest("MWMS combinations tests do not work well in graph mode.")
+    with distribution.scope():
+      v = variables_lib.Variable([1., 1.],
+                                 synchronization=synchronization,
+                                 aggregation=aggregation)
+
+    with self.cached_session():
+      self.evaluate(variables_lib.global_variables_initializer())
+
+    export_dir = self.get_temp_dir()
+
+    def _assert_unsaveable(f):
+      # Ignore if it cannot be traced. Certain combinations are not supported
+      # yet or not allowed.
+      try:
+        f = def_function.function(f).get_concrete_function()
+      except (NotImplementedError, ValueError):
+        return
+      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
+        save.save(v, export_dir, signatures=f)
+
+    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
+    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
+    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
+    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
+    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))
+    # Reading an ON_READ variable should be unsaveable if either:
+    # 1) CollectiveAllReduceStrategy, and aggregation is MEAN/SUM.
+    # 2) aggregation is SUM.
+    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
+        (aggregation == variables_lib.VariableAggregation.SUM or
+         (isinstance(distribution.extended,
+                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
+          and aggregation == variables_lib.VariableAggregation.MEAN))):
+      _assert_unsaveable(v.read_value)
+      _assert_unsaveable(v.value)
+      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
+    else:
+      # Otherwise reading a variable should be saveable.
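Outside the test harness, the failure surfaces through the public
SavedModel API like this (a sketch; the export path is illustrative):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      v = tf.Variable([1., 1.])

    # Traced eagerly, i.e. outside a save context, so the distributed
    # update marks the graph unsaveable.
    update = tf.function(lambda: v.assign_add([1., 1.])).get_concrete_function()

    try:
      tf.saved_model.save(v, '/tmp/unsaveable', signatures=update)
    except ValueError as e:
      print(e)  # suggests the f_with_input_signature pattern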
+
+      @def_function.function
+      def f():
+        v.read_value()
+        v.value()
+        return ops.convert_to_tensor(v)
+
+      with self.cached_session():
+        save.save(v, export_dir, signatures=f.get_concrete_function())
+
 
   @combinations.generate(
       combinations.combine(
diff --git a/tensorflow/python/distribute/values_util.py b/tensorflow/python/distribute/values_util.py
index 0071ee67b67..369e2435d9b 100644
--- a/tensorflow/python/distribute/values_util.py
+++ b/tensorflow/python/distribute/values_util.py
@@ -371,3 +371,23 @@ def is_saving_non_distributed():
   options = save_context.get_save_options()
   return (options.experimental_variable_policy !=
           save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)
+
+
+def mark_as_unsaveable():
+  """Marks the function as unsaveable if not inside save context."""
+  if ops.inside_function() and not save_context.in_save_context():
+    ops.get_default_graph().mark_as_unsaveable("""
+ConcreteFunction that uses distributed variables in certain ways cannot be saved.
+If you're saving with
+
+tf.saved_model.save(..., signatures=f.get_concrete_function())
+
+do
+
+@tf.function(input_signature=...)
+def f_with_input_signature():
+  ...
+
+tf.saved_model.save(..., signatures=f_with_input_signature)
+
+instead.""")
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 5918940a38a..bb56146227c 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -612,7 +612,6 @@ py_library(
         "//tensorflow/python:util",
         "//tensorflow/python/ops/numpy_ops:numpy",
         "//tensorflow/python/saved_model:save_context",
-        "//tensorflow/python/saved_model:save_options",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index 59bb0adc488..42af94c6cb1 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -589,7 +589,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
             experimental_variable_policy=save_options.VariablePolicy.NONE)):
       func_d = func.get_concrete_function(constant_op.constant(2.))
 
-    self.assertIs(func_a, func_c)
+    self.assertIsNot(func_a, func_c)
     self.assertIsNot(func_a, func_d)
 
   def testInitializationInNestedCall(self):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 60dd3f17024..ab32c8370af 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -67,7 +67,6 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.profiler import trace
 from tensorflow.python.saved_model import save_context
-from tensorflow.python.saved_model import save_options
 from tensorflow.python.util import compat
 from tensorflow.python.util import function_utils
 from tensorflow.python.util import lazy_loader
@@ -3177,10 +3176,7 @@ class Function(object):
       variable_policy = (
           save_context.get_save_options().experimental_variable_policy)
     else:
-      # With EXPAND_DISTRIBUTED_VARIABLES the variables have the same behavior
-      # in and out of saving. We use EXPAND_DISTRIBUTED_VARIABLES so that if the
-      # user saves with it, there's no need to retrace the functions.
-      variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
+      variable_policy = None
 
     return (parent_graph, device_functions, colocation_stack,
             in_cross_replica_context, variable_policy, xla_context_id)
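The function.py change is what makes the retrace possible: the tracing
cache key records the variable policy only inside a save context (None
otherwise), so a ConcreteFunction traced for training no longer aliases
the one traced for saving, hence the flipped assertIs in
def_function_test.py. The retrace can be observed with a Python-side
counter (a sketch; the export path is illustrative, and the count of 2
assumes the save-context retrace introduced here):

    import tensorflow as tf

    traces = []

    @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
    def f(x):
      traces.append(1)  # runs while tracing, not on every call
      return {'y': x + 1.}

    f(tf.constant(1.))  # first trace
    tf.saved_model.save(tf.Module(), '/tmp/retrace_demo', signatures=f)
    # Saving enters a save context with a different cache key, so f is
    # traced a second time:
    print(len(traces))  # 2 under this change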