Error out when saving IndexLookup layer.
PiperOrigin-RevId: 296093473 Change-Id: I96ce0319a1480399c47d10ecc048007483eca595
This commit is contained in:
parent
14d78c5450
commit
34a38afdee
@ -35,6 +35,7 @@ from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.layers.preprocessing import index_lookup
|
||||
from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
|
||||
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
|
||||
from tensorflow.python.keras.saving import save
|
||||
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -453,6 +454,19 @@ class IndexLookupSaveableTest(keras_parameterized.TestCase,
|
||||
weights = model.get_weights()
|
||||
model.set_weights(weights)
|
||||
|
||||
def test_layer_saving_with_h5(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
|
||||
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
|
||||
layer = get_layer_class()(max_tokens=10)
|
||||
layer.set_vocabulary(vocab_data)
|
||||
int_data = layer(input_data)
|
||||
model = keras.Model(inputs=input_data, outputs=int_data)
|
||||
path = os.path.join(self.get_temp_dir(), "model")
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
"Save or restore weights that is not.*"):
|
||||
save.save_model(model, path, save_format="h5")
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class IndexLookupErrorTest(keras_parameterized.TestCase,
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.keras.saving import model_config as model_config_lib
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.utils import conv_utils
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
@ -851,22 +852,28 @@ def load_attributes_from_hdf5_group(group, name):
|
||||
return data
|
||||
|
||||
|
||||
def _legacy_weights(model):
|
||||
def _legacy_weights(layer):
|
||||
"""DO NOT USE.
|
||||
|
||||
For legacy reason, the model.weights was in the order of
|
||||
For legacy reason, the layer.weights was in the order of
|
||||
[self.trainable_weights + self.non_trainable_weights], and this order was
|
||||
used for preserving the weights in h5 format. The new order of model.weights
|
||||
are the same as model.get_weights() which is more intuitive for user. To
|
||||
used for preserving the weights in h5 format. The new order of layer.weights
|
||||
are the same as layer.get_weights() which is more intuitive for user. To
|
||||
keep supporting the existing saved h5 file, this method should be used to
|
||||
save/load weights. In future version, we will delete this method and
|
||||
introduce a breaking change for h5 and stay with the new order for weights.
|
||||
|
||||
Args:
|
||||
model: a model or layer instance.
|
||||
layer: a `tf.keras.Model` or `tf.keras.layers.Layer` instance.
|
||||
|
||||
Returns:
|
||||
A list of variables with the order of trainable_weights, followed by
|
||||
non_trainable_weights.
|
||||
"""
|
||||
return model.trainable_weights + model.non_trainable_weights
|
||||
weights = layer.trainable_weights + layer.non_trainable_weights
|
||||
if any([not isinstance(w, variables_module.Variable) for w in weights]):
|
||||
raise NotImplementedError(
|
||||
'Save or restore weights that is not an instance of `tf.Variable` is '
|
||||
'not supported in h5, use `save_format=\'tf\'` instead. Got a model '
|
||||
'or layer {} with weights {}'.format(layer.__class__.__name__, weights))
|
||||
return weights
|
||||
|
Loading…
Reference in New Issue
Block a user