From 12c67c0d4797d2293de54207fa6f23492e2efeb3 Mon Sep 17 00:00:00 2001 From: Chenkai Kuang Date: Mon, 21 Dec 2020 15:38:56 -0800 Subject: [PATCH] Raise meaningful error message when loading a ShardedVariable. PiperOrigin-RevId: 348539354 Change-Id: I2c4a8466c3d1355ec8e5984ed039194c18c4305c --- tensorflow/python/distribute/BUILD | 2 ++ .../integration_test/saved_model_test.py | 20 +++++++++++++++++++ .../parameter_server_strategy_v2.py | 8 +++++++- .../python/distribute/sharded_variable.py | 19 ++++++++++++++++++ .../distribute/sharded_variable_test.py | 14 +++++++++++++ 5 files changed, 62 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 168a58f6b01..0df1dce8f9f 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -272,6 +272,7 @@ py_library( ":device_util", ":distribute_lib", ":reduce_util", + ":sharded_variable", ":shared_variable_creator", ":tpu_values", ":values", @@ -1118,6 +1119,7 @@ tf_py_test( "//tensorflow/python/compat:v2_compat", "//tensorflow/python/eager:def_function", "//tensorflow/python/module", + "//tensorflow/python/saved_model:load", "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:save", "//tensorflow/python/saved_model:signature_constants", diff --git a/tensorflow/python/distribute/integration_test/saved_model_test.py b/tensorflow/python/distribute/integration_test/saved_model_test.py index 8496f3c90bb..147fc81726b 100644 --- a/tensorflow/python/distribute/integration_test/saved_model_test.py +++ b/tensorflow/python/distribute/integration_test/saved_model_test.py @@ -611,6 +611,26 @@ class PSStrategySaveAndLoadTest(test.TestCase): # ShardedVariable loading only works in v1. self.assertAllEqual(self.load_and_run_v1(model_dir, {"x": 1}), [6, 6, 6, 6]) + with self.assertRaisesWithLiteralMatch( + ValueError, "Loading `ShardedVariable` is not supported"): + with strategy.scope(): + tf.saved_model.load(model_dir) + + with self.assertRaisesWithLiteralMatch( + ValueError, "Loading `ShardedVariable` is not supported"): + tf.saved_model.load(model_dir) + + def test_load_with_partitioner_raises_error(self): + model = self.Model() + model_dir = self.get_temp_dir() + tf.saved_model.save(model, model_dir) + + strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( + self.cluster_resolver, tf1.fixed_size_partitioner(2)) + with self.assertRaisesRegex(ValueError, "`variable_partitioner`"): + with strategy.scope(): + tf.saved_model.load(model_dir) + if __name__ == "__main__": # TODO(b/172304955): enable logical devices. diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py index 01a7c307cdd..c3e1d3ff8b1 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py @@ -560,7 +560,13 @@ class ParameterServerStrategyV2Extended( name = kwargs.get("name", None) initial_value = kwargs.get("initial_value", None) if initial_value is None: - raise ValueError("initial_value must be specified.") + raise ValueError( + "It looks like you are using `ParameterServerStrategy` with a " + "`variable_partitioner`, and trying to create a variable without " + "specifying `initial_value`. This is not allowed. Please specify the " + "`initial_value`. This can also happen if you are trying to load a " + "saved_model within a `ParameterServerStrategy` scope. Loading a " + "saved_model with `variable_partitioner` is not supported.") # Two cases where initial_value can be a callable: # 1. initial_value is passed as a callable, e.g, an `initializer` class. diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py index 553d82e4a26..5b56af79d92 100644 --- a/tensorflow/python/distribute/sharded_variable.py +++ b/tensorflow/python/distribute/sharded_variable.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.saved_model import revived_types from tensorflow.python.saved_model import save_context from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.training.tracking import base as trackable @@ -500,3 +501,21 @@ def embedding_lookup(params, return embedding_ops.embedding_lookup(params.variables, ids, partition_strategy, name, validate_indices, max_norm) + + +def _raise_when_load(_): + # We don't have serialization and deserialization mechanisms for + # `ShardedVariable` in 2.x style save/load yet. + raise ValueError('Loading `ShardedVariable` is not supported') + + +revived_types.register_revived_type( + '_tf_distribute_sharded_variable', + lambda obj: isinstance(obj, ShardedVariable), + versions=[ + revived_types.VersionedTypeRegistration( + object_factory=_raise_when_load, + version=0, + min_producer_version=0, + min_consumer_version=0) + ]) diff --git a/tensorflow/python/distribute/sharded_variable_test.py b/tensorflow/python/distribute/sharded_variable_test.py index a020a85de2d..822012bdc69 100644 --- a/tensorflow/python/distribute/sharded_variable_test.py +++ b/tensorflow/python/distribute/sharded_variable_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test +from tensorflow.python.saved_model import load from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import save from tensorflow.python.saved_model import signature_constants @@ -300,6 +301,19 @@ class ShardedVariableTest(test.TestCase): # Continue using root.train for training self.assertAllEqual([3., 2.], root.train([0, 1]).numpy()) + def test_load_raises_error(self): + root = tracking.AutoTrackable() + v1 = variables_lib.Variable([3.]) + v2 = variables_lib.Variable([2.]) + root.v = sharded_variable.ShardedVariable([v1, v2]) + + save_dir = os.path.join(self.get_temp_dir(), 'saved_model') + save.save(root, save_dir) + + with self.assertRaisesWithLiteralMatch( + ValueError, 'Loading `ShardedVariable` is not supported'): + load.load(save_dir) + def test_validation_errors(self): with self.assertRaisesRegex(ValueError, 'Expected a list of '): sharded_variable.ShardedVariable(