Raise meaningful error message when loading a ShardedVariable.
PiperOrigin-RevId: 348539354 Change-Id: I2c4a8466c3d1355ec8e5984ed039194c18c4305c
This commit is contained in:
parent
63b8cdcb82
commit
12c67c0d47
@ -272,6 +272,7 @@ py_library(
|
|||||||
":device_util",
|
":device_util",
|
||||||
":distribute_lib",
|
":distribute_lib",
|
||||||
":reduce_util",
|
":reduce_util",
|
||||||
|
":sharded_variable",
|
||||||
":shared_variable_creator",
|
":shared_variable_creator",
|
||||||
":tpu_values",
|
":tpu_values",
|
||||||
":values",
|
":values",
|
||||||
@ -1118,6 +1119,7 @@ tf_py_test(
|
|||||||
"//tensorflow/python/compat:v2_compat",
|
"//tensorflow/python/compat:v2_compat",
|
||||||
"//tensorflow/python/eager:def_function",
|
"//tensorflow/python/eager:def_function",
|
||||||
"//tensorflow/python/module",
|
"//tensorflow/python/module",
|
||||||
|
"//tensorflow/python/saved_model:load",
|
||||||
"//tensorflow/python/saved_model:loader",
|
"//tensorflow/python/saved_model:loader",
|
||||||
"//tensorflow/python/saved_model:save",
|
"//tensorflow/python/saved_model:save",
|
||||||
"//tensorflow/python/saved_model:signature_constants",
|
"//tensorflow/python/saved_model:signature_constants",
|
||||||
|
@ -611,6 +611,26 @@ class PSStrategySaveAndLoadTest(test.TestCase):
|
|||||||
# ShardedVariable loading only works in v1.
|
# ShardedVariable loading only works in v1.
|
||||||
self.assertAllEqual(self.load_and_run_v1(model_dir, {"x": 1}), [6, 6, 6, 6])
|
self.assertAllEqual(self.load_and_run_v1(model_dir, {"x": 1}), [6, 6, 6, 6])
|
||||||
|
|
||||||
|
with self.assertRaisesWithLiteralMatch(
|
||||||
|
ValueError, "Loading `ShardedVariable` is not supported"):
|
||||||
|
with strategy.scope():
|
||||||
|
tf.saved_model.load(model_dir)
|
||||||
|
|
||||||
|
with self.assertRaisesWithLiteralMatch(
|
||||||
|
ValueError, "Loading `ShardedVariable` is not supported"):
|
||||||
|
tf.saved_model.load(model_dir)
|
||||||
|
|
||||||
|
def test_load_with_partitioner_raises_error(self):
|
||||||
|
model = self.Model()
|
||||||
|
model_dir = self.get_temp_dir()
|
||||||
|
tf.saved_model.save(model, model_dir)
|
||||||
|
|
||||||
|
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||||
|
self.cluster_resolver, tf1.fixed_size_partitioner(2))
|
||||||
|
with self.assertRaisesRegex(ValueError, "`variable_partitioner`"):
|
||||||
|
with strategy.scope():
|
||||||
|
tf.saved_model.load(model_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# TODO(b/172304955): enable logical devices.
|
# TODO(b/172304955): enable logical devices.
|
||||||
|
@ -560,7 +560,13 @@ class ParameterServerStrategyV2Extended(
|
|||||||
name = kwargs.get("name", None)
|
name = kwargs.get("name", None)
|
||||||
initial_value = kwargs.get("initial_value", None)
|
initial_value = kwargs.get("initial_value", None)
|
||||||
if initial_value is None:
|
if initial_value is None:
|
||||||
raise ValueError("initial_value must be specified.")
|
raise ValueError(
|
||||||
|
"It looks like you are using `ParameterServerStrategy` with a "
|
||||||
|
"`variable_partitioner`, and trying to create a variable without "
|
||||||
|
"specifying `initial_value`. This is not allowed. Please specify the "
|
||||||
|
"`initial_value`. This can also happen if you are trying to load a "
|
||||||
|
"saved_model within a `ParameterServerStrategy` scope. Loading a "
|
||||||
|
"saved_model with `variable_partitioner` is not supported.")
|
||||||
|
|
||||||
# Two cases where initial_value can be a callable:
|
# Two cases where initial_value can be a callable:
|
||||||
# 1. initial_value is passed as a callable, e.g., an `initializer` class.
|
# 1. initial_value is passed as a callable, e.g., an `initializer` class.
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.python.ops import embedding_ops
|
|||||||
from tensorflow.python.ops import partitioned_variables
|
from tensorflow.python.ops import partitioned_variables
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import variables as variables_lib
|
from tensorflow.python.ops import variables as variables_lib
|
||||||
|
from tensorflow.python.saved_model import revived_types
|
||||||
from tensorflow.python.saved_model import save_context
|
from tensorflow.python.saved_model import save_context
|
||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
@ -500,3 +501,21 @@ def embedding_lookup(params,
|
|||||||
return embedding_ops.embedding_lookup(params.variables, ids,
|
return embedding_ops.embedding_lookup(params.variables, ids,
|
||||||
partition_strategy, name,
|
partition_strategy, name,
|
||||||
validate_indices, max_norm)
|
validate_indices, max_norm)
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_when_load(_):
|
||||||
|
# We don't have serialization and deserialization mechanisms for
|
||||||
|
# `ShardedVariable` in 2.x style save/load yet.
|
||||||
|
raise ValueError('Loading `ShardedVariable` is not supported')
|
||||||
|
|
||||||
|
|
||||||
|
revived_types.register_revived_type(
|
||||||
|
'_tf_distribute_sharded_variable',
|
||||||
|
lambda obj: isinstance(obj, ShardedVariable),
|
||||||
|
versions=[
|
||||||
|
revived_types.VersionedTypeRegistration(
|
||||||
|
object_factory=_raise_when_load,
|
||||||
|
version=0,
|
||||||
|
min_producer_version=0,
|
||||||
|
min_consumer_version=0)
|
||||||
|
])
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import embedding_ops
|
from tensorflow.python.ops import embedding_ops
|
||||||
from tensorflow.python.ops import variables as variables_lib
|
from tensorflow.python.ops import variables as variables_lib
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.saved_model import load
|
||||||
from tensorflow.python.saved_model import loader
|
from tensorflow.python.saved_model import loader
|
||||||
from tensorflow.python.saved_model import save
|
from tensorflow.python.saved_model import save
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
@ -300,6 +301,19 @@ class ShardedVariableTest(test.TestCase):
|
|||||||
# Continue using root.train for training
|
# Continue using root.train for training
|
||||||
self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())
|
self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())
|
||||||
|
|
||||||
|
def test_load_raises_error(self):
|
||||||
|
root = tracking.AutoTrackable()
|
||||||
|
v1 = variables_lib.Variable([3.])
|
||||||
|
v2 = variables_lib.Variable([2.])
|
||||||
|
root.v = sharded_variable.ShardedVariable([v1, v2])
|
||||||
|
|
||||||
|
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||||
|
save.save(root, save_dir)
|
||||||
|
|
||||||
|
with self.assertRaisesWithLiteralMatch(
|
||||||
|
ValueError, 'Loading `ShardedVariable` is not supported'):
|
||||||
|
load.load(save_dir)
|
||||||
|
|
||||||
def test_validation_errors(self):
|
def test_validation_errors(self):
|
||||||
with self.assertRaisesRegex(ValueError, 'Expected a list of '):
|
with self.assertRaisesRegex(ValueError, 'Expected a list of '):
|
||||||
sharded_variable.ShardedVariable(
|
sharded_variable.ShardedVariable(
|
||||||
|
Loading…
Reference in New Issue
Block a user