Raise meaningful error message when loading a ShardedVariable.

PiperOrigin-RevId: 348539354
Change-Id: I2c4a8466c3d1355ec8e5984ed039194c18c4305c
This commit is contained in:
Chenkai Kuang 2020-12-21 15:38:56 -08:00 committed by TensorFlower Gardener
parent 63b8cdcb82
commit 12c67c0d47
5 changed files with 62 additions and 1 deletions

View File

@ -272,6 +272,7 @@ py_library(
":device_util", ":device_util",
":distribute_lib", ":distribute_lib",
":reduce_util", ":reduce_util",
":sharded_variable",
":shared_variable_creator", ":shared_variable_creator",
":tpu_values", ":tpu_values",
":values", ":values",
@ -1118,6 +1119,7 @@ tf_py_test(
"//tensorflow/python/compat:v2_compat", "//tensorflow/python/compat:v2_compat",
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:def_function",
"//tensorflow/python/module", "//tensorflow/python/module",
"//tensorflow/python/saved_model:load",
"//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:loader",
"//tensorflow/python/saved_model:save", "//tensorflow/python/saved_model:save",
"//tensorflow/python/saved_model:signature_constants", "//tensorflow/python/saved_model:signature_constants",

View File

@ -611,6 +611,26 @@ class PSStrategySaveAndLoadTest(test.TestCase):
# ShardedVariable loading only works in v1. # ShardedVariable loading only works in v1.
self.assertAllEqual(self.load_and_run_v1(model_dir, {"x": 1}), [6, 6, 6, 6]) self.assertAllEqual(self.load_and_run_v1(model_dir, {"x": 1}), [6, 6, 6, 6])
with self.assertRaisesWithLiteralMatch(
ValueError, "Loading `ShardedVariable` is not supported"):
with strategy.scope():
tf.saved_model.load(model_dir)
with self.assertRaisesWithLiteralMatch(
ValueError, "Loading `ShardedVariable` is not supported"):
tf.saved_model.load(model_dir)
def test_load_with_partitioner_raises_error(self):
  """Loading inside a partitioner-enabled strategy scope raises `ValueError`.

  The model is saved outside of any strategy scope; only the load, performed
  under a `ParameterServerStrategyV2` built with a fixed-size partitioner, is
  expected to fail with a message that mentions `variable_partitioner`.
  """
  saved_model = self.Model()
  export_dir = self.get_temp_dir()
  tf.saved_model.save(saved_model, export_dir)

  partitioner = tf1.fixed_size_partitioner(2)
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      self.cluster_resolver, partitioner)

  # The assertion context is entered first, then the strategy scope, which is
  # the same ordering as two nested `with` statements.
  with self.assertRaisesRegex(
      ValueError, "`variable_partitioner`"), strategy.scope():
    tf.saved_model.load(export_dir)
if __name__ == "__main__": if __name__ == "__main__":
# TODO(b/172304955): enable logical devices. # TODO(b/172304955): enable logical devices.

View File

@ -560,7 +560,13 @@ class ParameterServerStrategyV2Extended(
name = kwargs.get("name", None) name = kwargs.get("name", None)
initial_value = kwargs.get("initial_value", None) initial_value = kwargs.get("initial_value", None)
if initial_value is None: if initial_value is None:
raise ValueError("initial_value must be specified.") raise ValueError(
"It looks like you are using `ParameterServerStrategy` with a "
"`variable_partitioner`, and trying to create a variable without "
"specifying `initial_value`. This is not allowed. Please specify the "
"`initial_value`. This can also happen if you are trying to load a "
"saved_model within a `ParameterServerStrategy` scope. Loading a "
"saved_model with `variable_partitioner` is not supported.")
# Two cases where initial_value can be a callable: # Two cases where initial_value can be a callable:
# 1. initial_value is passed as a callable, e.g, an `initializer` class. # 1. initial_value is passed as a callable, e.g, an `initializer` class.

View File

@ -28,6 +28,7 @@ from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import save_context from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
@ -500,3 +501,21 @@ def embedding_lookup(params,
return embedding_ops.embedding_lookup(params.variables, ids, return embedding_ops.embedding_lookup(params.variables, ids,
partition_strategy, name, partition_strategy, name,
validate_indices, max_norm) validate_indices, max_norm)
def _raise_when_load(_):
# We don't have serialization and deserialization mechanisms for
# `ShardedVariable` in 2.x style save/load yet.
raise ValueError('Loading `ShardedVariable` is not supported')
# Register `ShardedVariable` under the identifier
# '_tf_distribute_sharded_variable' with a single version (0) whose object
# factory is `_raise_when_load`, a function that always raises `ValueError`.
# NOTE(review): presumably the predicate tags `ShardedVariable` objects when a
# SavedModel is written and the factory is invoked when one is loaded, so that
# loading fails with a clear error -- confirm against `revived_types`.
revived_types.register_revived_type(
'_tf_distribute_sharded_variable',
lambda obj: isinstance(obj, ShardedVariable),
versions=[
revived_types.VersionedTypeRegistration(
object_factory=_raise_when_load,
version=0,
min_producer_version=0,
min_consumer_version=0)
])

View File

@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import variables as variables_lib from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
@ -300,6 +301,19 @@ class ShardedVariableTest(test.TestCase):
# Continue using root.train for training # Continue using root.train for training
self.assertAllEqual([3., 2.], root.train([0, 1]).numpy()) self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())
def test_load_raises_error(self):
  """Saving a `ShardedVariable` works, but loading it back raises `ValueError`.

  A trackable root holding a two-shard `ShardedVariable` is exported; the
  export itself must not raise, while `load.load` on the result must fail
  with the exact "not supported" message.
  """
  root = tracking.AutoTrackable()
  shards = [variables_lib.Variable([3.]), variables_lib.Variable([2.])]
  root.v = sharded_variable.ShardedVariable(shards)

  export_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save.save(root, export_dir)

  expected_message = 'Loading `ShardedVariable` is not supported'
  with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
    load.load(export_dir)
def test_validation_errors(self): def test_validation_errors(self):
with self.assertRaisesRegex(ValueError, 'Expected a list of '): with self.assertRaisesRegex(ValueError, 'Expected a list of '):
sharded_variable.ShardedVariable( sharded_variable.ShardedVariable(