Raise meaningful error message when loading a ShardedVariable.
PiperOrigin-RevId: 348539354 Change-Id: I2c4a8466c3d1355ec8e5984ed039194c18c4305c
This commit is contained in:
parent
63b8cdcb82
commit
12c67c0d47
@ -272,6 +272,7 @@ py_library(
|
||||
":device_util",
|
||||
":distribute_lib",
|
||||
":reduce_util",
|
||||
":sharded_variable",
|
||||
":shared_variable_creator",
|
||||
":tpu_values",
|
||||
":values",
|
||||
@ -1118,6 +1119,7 @@ tf_py_test(
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/saved_model:load",
|
||||
"//tensorflow/python/saved_model:loader",
|
||||
"//tensorflow/python/saved_model:save",
|
||||
"//tensorflow/python/saved_model:signature_constants",
|
||||
|
@ -611,6 +611,26 @@ class PSStrategySaveAndLoadTest(test.TestCase):
|
||||
# ShardedVariable loading only works in v1.
|
||||
self.assertAllEqual(self.load_and_run_v1(model_dir, {"x": 1}), [6, 6, 6, 6])
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "Loading `ShardedVariable` is not supported"):
|
||||
with strategy.scope():
|
||||
tf.saved_model.load(model_dir)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "Loading `ShardedVariable` is not supported"):
|
||||
tf.saved_model.load(model_dir)
|
||||
|
||||
def test_load_with_partitioner_raises_error(self):
|
||||
model = self.Model()
|
||||
model_dir = self.get_temp_dir()
|
||||
tf.saved_model.save(model, model_dir)
|
||||
|
||||
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||
self.cluster_resolver, tf1.fixed_size_partitioner(2))
|
||||
with self.assertRaisesRegex(ValueError, "`variable_partitioner`"):
|
||||
with strategy.scope():
|
||||
tf.saved_model.load(model_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO(b/172304955): enable logical devices.
|
||||
|
@ -560,7 +560,13 @@ class ParameterServerStrategyV2Extended(
|
||||
name = kwargs.get("name", None)
|
||||
initial_value = kwargs.get("initial_value", None)
|
||||
if initial_value is None:
|
||||
raise ValueError("initial_value must be specified.")
|
||||
raise ValueError(
|
||||
"It looks like you are using `ParameterServerStrategy` with a "
|
||||
"`variable_partitioner`, and trying to create a variable without "
|
||||
"specifying `initial_value`. This is not allowed. Please specify the "
|
||||
"`initial_value`. This can also happen if you are trying to load a "
|
||||
"saved_model within a `ParameterServerStrategy` scope. Loading a "
|
||||
"saved_model with `variable_partitioner` is not supported.")
|
||||
|
||||
# Two cases where initial_value can be a callable:
|
||||
# 1. initial_value is passed as a callable, e.g, an `initializer` class.
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import partitioned_variables
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.saved_model import revived_types
|
||||
from tensorflow.python.saved_model import save_context
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
@ -500,3 +501,21 @@ def embedding_lookup(params,
|
||||
return embedding_ops.embedding_lookup(params.variables, ids,
|
||||
partition_strategy, name,
|
||||
validate_indices, max_norm)
|
||||
|
||||
|
||||
def _raise_when_load(_):
|
||||
# We don't have serialization and deserialization mechanisms for
|
||||
# `ShardedVariable` in 2.x style save/load yet.
|
||||
raise ValueError('Loading `ShardedVariable` is not supported')
|
||||
|
||||
|
||||
revived_types.register_revived_type(
|
||||
'_tf_distribute_sharded_variable',
|
||||
lambda obj: isinstance(obj, ShardedVariable),
|
||||
versions=[
|
||||
revived_types.VersionedTypeRegistration(
|
||||
object_factory=_raise_when_load,
|
||||
version=0,
|
||||
min_producer_version=0,
|
||||
min_consumer_version=0)
|
||||
])
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import loader
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
@ -300,6 +301,19 @@ class ShardedVariableTest(test.TestCase):
|
||||
# Continue using root.train for training
|
||||
self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())
|
||||
|
||||
def test_load_raises_error(self):
|
||||
root = tracking.AutoTrackable()
|
||||
v1 = variables_lib.Variable([3.])
|
||||
v2 = variables_lib.Variable([2.])
|
||||
root.v = sharded_variable.ShardedVariable([v1, v2])
|
||||
|
||||
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
save.save(root, save_dir)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, 'Loading `ShardedVariable` is not supported'):
|
||||
load.load(save_dir)
|
||||
|
||||
def test_validation_errors(self):
|
||||
with self.assertRaisesRegex(ValueError, 'Expected a list of '):
|
||||
sharded_variable.ShardedVariable(
|
||||
|
Loading…
Reference in New Issue
Block a user