Error out when saving IndexLookup layer.

PiperOrigin-RevId: 296093473
Change-Id: I96ce0319a1480399c47d10ecc048007483eca595
This commit is contained in:
Zhenyu Tan 2020-02-19 17:15:40 -08:00 committed by TensorFlower Gardener
parent 14d78c5450
commit 34a38afdee
2 changed files with 27 additions and 6 deletions

View File

@ -35,6 +35,7 @@ from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers.preprocessing import index_lookup
from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
@ -453,6 +454,19 @@ class IndexLookupSaveableTest(keras_parameterized.TestCase,
weights = model.get_weights()
model.set_weights(weights)
def test_layer_saving_with_h5(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_data = keras.Input(shape=(None,), dtype=dtypes.string)
layer = get_layer_class()(max_tokens=10)
layer.set_vocabulary(vocab_data)
int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data)
path = os.path.join(self.get_temp_dir(), "model")
with self.assertRaisesRegex(NotImplementedError,
"Save or restore weights that is not.*"):
save.save_model(model, path, save_format="h5")
@keras_parameterized.run_all_keras_modes
class IndexLookupErrorTest(keras_parameterized.TestCase,

View File

@ -31,6 +31,7 @@ from tensorflow.python.keras.saving import model_config as model_config_lib
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import serialization
@ -851,22 +852,28 @@ def load_attributes_from_hdf5_group(group, name):
return data
def _legacy_weights(model):
def _legacy_weights(layer):
"""DO NOT USE.
For legacy reason, the model.weights was in the order of
For legacy reason, the layer.weights was in the order of
[self.trainable_weights + self.non_trainable_weights], and this order was
used for preserving the weights in h5 format. The new order of model.weights
are the same as model.get_weights() which is more intuitive for user. To
used for preserving the weights in h5 format. The new order of layer.weights
are the same as layer.get_weights() which is more intuitive for user. To
keep supporting the existing saved h5 file, this method should be used to
save/load weights. In future version, we will delete this method and
introduce a breaking change for h5 and stay with the new order for weights.
Args:
model: a model or layer instance.
layer: a `tf.keras.Model` or `tf.keras.layers.Layer` instance.
Returns:
A list of variables with the order of trainable_weights, followed by
non_trainable_weights.
"""
return model.trainable_weights + model.non_trainable_weights
weights = layer.trainable_weights + layer.non_trainable_weights
if any([not isinstance(w, variables_module.Variable) for w in weights]):
raise NotImplementedError(
'Save or restore weights that is not an instance of `tf.Variable` is '
'not supported in h5, use `save_format=\'tf\'` instead. Got a model '
'or layer {} with weights {}'.format(layer.__class__.__name__, weights))
return weights