Move tf.Keras object identifiers into constants.py.
PiperOrigin-RevId: 345257730 Change-Id: I0961d32f28300fe06cac6fa16928fec119baf779
This commit is contained in:
parent
753786571e
commit
a39726f0f4
@ -22,9 +22,9 @@ message SavedObject {
|
|||||||
string node_path = 3;
|
string node_path = 3;
|
||||||
|
|
||||||
// Identifier to determine loading function.
|
// Identifier to determine loading function.
|
||||||
// Currently supported identifiers:
|
// Must be one of:
|
||||||
// _tf_keras_layer, _tf_keras_input_layer, _tf_keras_rnn_layer,
|
// _tf_keras_input_layer, _tf_keras_layer, _tf_keras_metric,
|
||||||
// _tf_keras_metric, _tf_keras_network, _tf_keras_model,
|
// _tf_keras_model, _tf_keras_network, _tf_keras_rnn_layer,
|
||||||
// _tf_keras_sequential
|
// _tf_keras_sequential
|
||||||
string identifier = 4;
|
string identifier = 4;
|
||||||
// Metadata containing a JSON-serialized object with the non-TensorFlow
|
// Metadata containing a JSON-serialized object with the non-TensorFlow
|
||||||
|
@ -30,3 +30,22 @@ KERAS_CACHE_KEY = 'keras_serialized_attributes'
|
|||||||
|
|
||||||
# Name of Keras metadata file stored in the SavedModel.
|
# Name of Keras metadata file stored in the SavedModel.
|
||||||
SAVED_METADATA_PATH = 'keras_metadata.pb'
|
SAVED_METADATA_PATH = 'keras_metadata.pb'
|
||||||
|
|
||||||
|
# Names of SavedObject Keras identifiers.
|
||||||
|
INPUT_LAYER_IDENTIFIER = '_tf_keras_input_layer'
|
||||||
|
LAYER_IDENTIFIER = '_tf_keras_layer'
|
||||||
|
METRIC_IDENTIFIER = '_tf_keras_metric'
|
||||||
|
MODEL_IDENTIFIER = '_tf_keras_model'
|
||||||
|
NETWORK_IDENTIFIER = '_tf_keras_network'
|
||||||
|
RNN_LAYER_IDENTIFIER = '_tf_keras_rnn_layer'
|
||||||
|
SEQUENTIAL_IDENTIFIER = '_tf_keras_sequential'
|
||||||
|
|
||||||
|
KERAS_OBJECT_IDENTIFIERS = (
|
||||||
|
INPUT_LAYER_IDENTIFIER,
|
||||||
|
LAYER_IDENTIFIER,
|
||||||
|
METRIC_IDENTIFIER,
|
||||||
|
MODEL_IDENTIFIER,
|
||||||
|
NETWORK_IDENTIFIER,
|
||||||
|
RNN_LAYER_IDENTIFIER,
|
||||||
|
SEQUENTIAL_IDENTIFIER,
|
||||||
|
)
|
||||||
|
@ -33,7 +33,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def object_identifier(self):
|
def object_identifier(self):
|
||||||
return '_tf_keras_layer'
|
return constants.LAYER_IDENTIFIER
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def python_properties(self):
|
def python_properties(self):
|
||||||
@ -127,7 +127,7 @@ class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def object_identifier(self):
|
def object_identifier(self):
|
||||||
return '_tf_keras_input_layer'
|
return constants.INPUT_LAYER_IDENTIFIER
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def python_properties(self):
|
def python_properties(self):
|
||||||
@ -153,7 +153,7 @@ class RNNSavedModelSaver(LayerSavedModelSaver):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def object_identifier(self):
|
def object_identifier(self):
|
||||||
return '_tf_keras_rnn_layer'
|
return constants.RNN_LAYER_IDENTIFIER
|
||||||
|
|
||||||
def _get_serialized_attributes_internal(self, serialization_cache):
|
def _get_serialized_attributes_internal(self, serialization_cache):
|
||||||
objects, functions = (
|
objects, functions = (
|
||||||
|
@ -92,12 +92,6 @@ PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
|
|||||||
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)
|
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)
|
||||||
|
|
||||||
|
|
||||||
KERAS_OBJECT_IDENTIFIERS = (
|
|
||||||
'_tf_keras_layer', '_tf_keras_input_layer', '_tf_keras_network',
|
|
||||||
'_tf_keras_model', '_tf_keras_sequential', '_tf_keras_metric',
|
|
||||||
'_tf_keras_rnn_layer')
|
|
||||||
|
|
||||||
|
|
||||||
def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
|
def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
|
||||||
"""Loads Keras objects from a SavedModel.
|
"""Loads Keras objects from a SavedModel.
|
||||||
|
|
||||||
@ -196,7 +190,7 @@ def _read_legacy_metadata(object_graph_def, metadata):
|
|||||||
node_paths = _generate_object_paths(object_graph_def)
|
node_paths = _generate_object_paths(object_graph_def)
|
||||||
for node_id, proto in enumerate(object_graph_def.nodes):
|
for node_id, proto in enumerate(object_graph_def.nodes):
|
||||||
if (proto.WhichOneof('kind') == 'user_object' and
|
if (proto.WhichOneof('kind') == 'user_object' and
|
||||||
proto.user_object.identifier in KERAS_OBJECT_IDENTIFIERS):
|
proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
|
||||||
metadata.nodes.add(
|
metadata.nodes.add(
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_path=node_paths[node_id],
|
node_path=node_paths[node_id],
|
||||||
@ -347,7 +341,7 @@ class KerasObjectLoader(object):
|
|||||||
if (child_proto.user_object.identifier in
|
if (child_proto.user_object.identifier in
|
||||||
revived_types.registered_identifiers()):
|
revived_types.registered_identifiers()):
|
||||||
setter = revived_types.get_setter(child_proto.user_object)
|
setter = revived_types.get_setter(child_proto.user_object)
|
||||||
elif obj_child._object_identifier in KERAS_OBJECT_IDENTIFIERS:
|
elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
|
||||||
setter = _revive_setter
|
setter = _revive_setter
|
||||||
else:
|
else:
|
||||||
setter = setattr
|
setter = setattr
|
||||||
@ -384,7 +378,7 @@ class KerasObjectLoader(object):
|
|||||||
# time by creating objects multiple times).
|
# time by creating objects multiple times).
|
||||||
metric_list = []
|
metric_list = []
|
||||||
for node_metadata in self._metadata.nodes:
|
for node_metadata in self._metadata.nodes:
|
||||||
if node_metadata.identifier == '_tf_keras_metric':
|
if node_metadata.identifier == constants.METRIC_IDENTIFIER:
|
||||||
metric_list.append(node_metadata)
|
metric_list.append(node_metadata)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -432,7 +426,7 @@ class KerasObjectLoader(object):
|
|||||||
|
|
||||||
def _revive_from_config(self, identifier, metadata, node_id):
|
def _revive_from_config(self, identifier, metadata, node_id):
|
||||||
"""Revives a layer/model from config, or returns None."""
|
"""Revives a layer/model from config, or returns None."""
|
||||||
if identifier == '_tf_keras_metric':
|
if identifier == constants.METRIC_IDENTIFIER:
|
||||||
obj = self._revive_metric_from_config(metadata)
|
obj = self._revive_metric_from_config(metadata)
|
||||||
else:
|
else:
|
||||||
obj = (
|
obj = (
|
||||||
@ -921,11 +915,12 @@ def revive_custom_object(identifier, metadata):
|
|||||||
model_class = training_lib_v1.Model
|
model_class = training_lib_v1.Model
|
||||||
|
|
||||||
revived_classes = {
|
revived_classes = {
|
||||||
'_tf_keras_layer': (RevivedLayer, base_layer.Layer),
|
constants.INPUT_LAYER_IDENTIFIER: (
|
||||||
'_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
|
RevivedInputLayer, input_layer.InputLayer),
|
||||||
'_tf_keras_network': (RevivedNetwork, functional_lib.Functional),
|
constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
|
||||||
'_tf_keras_model': (RevivedNetwork, model_class),
|
constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
|
||||||
'_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential),
|
constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional),
|
||||||
|
constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential),
|
||||||
}
|
}
|
||||||
parent_classes = revived_classes.get(identifier, None)
|
parent_classes = revived_classes.get(identifier, None)
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.keras.saving.saved_model import constants
|
||||||
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.training.tracking import data_structures
|
from tensorflow.python.training.tracking import data_structures
|
||||||
@ -28,7 +29,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def object_identifier(self):
|
def object_identifier(self):
|
||||||
return '_tf_keras_metric'
|
return constants.METRIC_IDENTIFIER
|
||||||
|
|
||||||
def _python_properties_internal(self):
|
def _python_properties_internal(self):
|
||||||
metadata = dict(
|
metadata = dict(
|
||||||
|
@ -29,7 +29,7 @@ class ModelSavedModelSaver(layer_serialization.LayerSavedModelSaver):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def object_identifier(self):
|
def object_identifier(self):
|
||||||
return '_tf_keras_model'
|
return constants.MODEL_IDENTIFIER
|
||||||
|
|
||||||
def _python_properties_internal(self):
|
def _python_properties_internal(self):
|
||||||
metadata = super(ModelSavedModelSaver, self)._python_properties_internal()
|
metadata = super(ModelSavedModelSaver, self)._python_properties_internal()
|
||||||
@ -63,4 +63,4 @@ class SequentialSavedModelSaver(ModelSavedModelSaver):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def object_identifier(self):
|
def object_identifier(self):
|
||||||
return '_tf_keras_sequential'
|
return constants.SEQUENTIAL_IDENTIFIER
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.keras.saving.saved_model import constants
|
||||||
from tensorflow.python.keras.saving.saved_model import model_serialization
|
from tensorflow.python.keras.saving.saved_model import model_serialization
|
||||||
|
|
||||||
|
|
||||||
@ -27,4 +28,4 @@ class NetworkSavedModelSaver(model_serialization.ModelSavedModelSaver):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def object_identifier(self):
|
def object_identifier(self):
|
||||||
return '_tf_keras_network'
|
return constants.NETWORK_IDENTIFIER
|
||||||
|
Loading…
Reference in New Issue
Block a user