From 0225ad8ca1749219635dd780b9a2507957bfeac6 Mon Sep 17 00:00:00 2001 From: Katherine Wu Date: Mon, 24 Feb 2020 16:18:48 -0800 Subject: [PATCH] Save build input shape to SavedModel metadata, and use the shape information when loading layers. This CL also adds a custom JSON encoder/decoder to handle the build_input_shape, which can be a TensorShape or tuple. As with cl/293453611, this resolves loading built-in preprocessing layers but does not address custom preprocessing layers. PiperOrigin-RevId: 296997111 Change-Id: Ic5831471f0d823a9eed7a00b28f6a8a8f9b991b5 --- tensorflow/python/keras/engine/base_layer.py | 1 + .../python/keras/engine/base_layer_v1.py | 5 + tensorflow/python/keras/saving/BUILD | 14 ++ .../saving/saved_model/base_serialization.py | 8 +- .../keras/saving/saved_model/json_utils.py | 69 +++++++++ .../saving/saved_model/json_utils_test.py | 55 ++++++++ .../saving/saved_model/layer_serialization.py | 2 + .../python/keras/saving/saved_model/load.py | 133 ++++++++++++------ .../keras/saving/saved_model/revive_test.py | 13 +- .../python/keras/utils/generic_utils.py | 9 +- .../python/saved_model/revived_types.py | 10 ++ 11 files changed, 265 insertions(+), 54 deletions(-) create mode 100644 tensorflow/python/keras/saving/saved_model/json_utils.py create mode 100644 tensorflow/python/keras/saving/saved_model/json_utils_test.py diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index d6977da1382..0e09bb291c5 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -383,6 +383,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._auto_track_sub_layers = True @trackable.no_automatic_dependency_tracking + @base_layer_utils.default def build(self, input_shape): """Creates the variables of the layer (optional, for subclass implementers). diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 82692932731..60ee17d76d5 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -178,6 +178,7 @@ class Layer(base_layer.Layer): # Indicates whether `build` needs to be called upon layer call, to create # the layer's weights. self.built = False + self._build_input_shape = None # Provides information about which inputs are compatible with the layer. self._input_spec = None self.supports_masking = False @@ -252,6 +253,8 @@ class Layer(base_layer.Layer): # might want to turn it off, like Sequential model. self._auto_track_sub_layers = True + @trackable.no_automatic_dependency_tracking + @base_layer_utils.default def build(self, input_shape): """Creates the variables of the layer (optional, for subclass implementers). @@ -266,6 +269,8 @@ class Layer(base_layer.Layer): `TensorShape` if the layer expects a list of inputs (one instance per input). """ + if not hasattr(self.build, '_is_default'): + self._build_input_shape = input_shape self.built = True @doc_controls.for_subclass_implementers diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD index 3a4bca18e40..7ab6639d118 100644 --- a/tensorflow/python/keras/saving/BUILD +++ b/tensorflow/python/keras/saving/BUILD @@ -19,6 +19,7 @@ py_library( "save.py", "saved_model/base_serialization.py", "saved_model/constants.py", + "saved_model/json_utils.py", "saved_model/layer_serialization.py", "saved_model/load.py", "saved_model/model_serialization.py", @@ -164,6 +165,7 @@ tf_py_test( size = "medium", srcs = ["saved_model/revive_test.py"], python_version = "PY3", + shard_count = 4, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -171,3 +173,15 @@ tf_py_test( "@absl_py//absl/testing:parameterized", ], ) + +tf_py_test( + name = "json_utils_test", + size = "small", + srcs = ["saved_model/json_utils_test.py"], + python_version = "PY3", + deps = [ + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/python/keras/saving/saved_model/base_serialization.py b/tensorflow/python/keras/saving/saved_model/base_serialization.py index 565601a5242..0065e6d786e 100644 --- a/tensorflow/python/keras/saving/saved_model/base_serialization.py +++ b/tensorflow/python/keras/saving/saved_model/base_serialization.py @@ -19,12 +19,10 @@ from __future__ import division from __future__ import print_function import abc -import json - import six +from tensorflow.python.keras.saving.saved_model import json_utils from tensorflow.python.training.tracking import tracking -from tensorflow.python.util import serialization @six.add_metaclass(abc.ABCMeta) @@ -53,9 +51,7 @@ class SavedModelSaver(object): """ # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an # object is in the python property) - return json.dumps( - self.python_properties, - default=serialization.get_json_type) + return json_utils.Encoder().encode(self.python_properties) def list_extra_dependencies_for_serialization(self, serialization_cache): """Lists extra dependencies to serialize to SavedModel. diff --git a/tensorflow/python/keras/saving/saved_model/json_utils.py b/tensorflow/python/keras/saving/saved_model/json_utils.py new file mode 100644 index 00000000000..0ac86d4e692 --- /dev/null +++ b/tensorflow/python/keras/saving/saved_model/json_utils.py @@ -0,0 +1,69 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utils for creating and loading the Layer metadata for SavedModel. + +These are required to retain the original format of the build input shape, since +layers and models may have different build behaviors depending on if the shape +is a list, tuple, or TensorShape. For example, Network.build() will create +separate inputs if the given input_shape is a list, and will create a single +input if the given shape is a tuple. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import serialization + + +class Encoder(json.JSONEncoder): + """JSON encoder and decoder that handles TensorShapes and tuples.""" + + def default(self, obj): + if isinstance(obj, tensor_shape.TensorShape): + items = obj.as_list() if obj.rank is not None else None + return {'class_name': 'TensorShape', 'items': items} + return serialization.get_json_type(obj) + + def encode(self, obj): + return super(Encoder, self).encode(_encode_tuple(obj)) + + +def _encode_tuple(x): + if isinstance(x, tuple): + return {'class_name': '__tuple__', + 'items': tuple(_encode_tuple(i) for i in x)} + elif isinstance(x, list): + return [_encode_tuple(i) for i in x] + elif isinstance(x, dict): + return {key: _encode_tuple(value) for key, value in x.items()} + else: + return x + + +def decode(json_string): + return json.loads(json_string, object_hook=_decode_helper) + + +def _decode_helper(obj): + if isinstance(obj, dict) and 'class_name' in obj: + if obj['class_name'] == 'TensorShape': + return tensor_shape.TensorShape(obj['items']) + elif obj['class_name'] == '__tuple__': + return tuple(_decode_helper(i) for i in obj['items']) + return obj diff --git a/tensorflow/python/keras/saving/saved_model/json_utils_test.py b/tensorflow/python/keras/saving/saved_model/json_utils_test.py new file mode 100644 index 00000000000..f940279404f --- /dev/null +++ b/tensorflow/python/keras/saving/saved_model/json_utils_test.py @@ -0,0 +1,55 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Tests the JSON encoder and decoder.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.saving.saved_model import json_utils +from tensorflow.python.platform import test + + +class JsonUtilsTest(test.TestCase): + + def test_encode_decode_tensor_shape(self): + metadata = { + 'key1': tensor_shape.TensorShape(None), + 'key2': [tensor_shape.TensorShape([None]), + tensor_shape.TensorShape([3, None, 5])]} + string = json_utils.Encoder().encode(metadata) + loaded = json_utils.decode(string) + + self.assertEqual(set(loaded.keys()), {'key1', 'key2'}) + self.assertAllEqual(loaded['key1'].rank, None) + self.assertAllEqual(loaded['key2'][0].as_list(), [None]) + self.assertAllEqual(loaded['key2'][1].as_list(), [3, None, 5]) + + def test_encode_decode_tuple(self): + metadata = { + 'key1': (3, 5), + 'key2': [(1, (3, 4)), (1,)]} + string = json_utils.Encoder().encode(metadata) + loaded = json_utils.decode(string) + + self.assertEqual(set(loaded.keys()), {'key1', 'key2'}) + self.assertAllEqual(loaded['key1'], (3, 5)) + self.assertAllEqual(loaded['key2'], [(1, (3, 4)), (1,)]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py index ab1edaab585..6dffcc65c7e 100644 --- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py +++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py @@ -68,6 +68,8 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver): hasattr(self.obj.activity_regularizer, 'get_config')): metadata['activity_regularizer'] = generic_utils.serialize_keras_object( self.obj.activity_regularizer) + if self.obj._build_input_shape is not None: # pylint: disable=protected-access + metadata['build_input_shape'] = self.obj._build_input_shape # pylint: disable=protected-access return metadata def objects_to_serialize(self, serialization_cache): diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index d53530ec1d7..1d09ec7d150 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import json import re from tensorflow.python.eager import context @@ -29,6 +28,7 @@ from tensorflow.python.keras import regularizers from tensorflow.python.keras.engine import input_spec from tensorflow.python.keras.saving import saving_utils from tensorflow.python.keras.saving.saved_model import constants +from tensorflow.python.keras.saving.saved_model import json_utils from tensorflow.python.keras.saving.saved_model import utils from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints from tensorflow.python.keras.utils import generic_utils @@ -40,6 +40,7 @@ from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking.tracking import delete_tracking from tensorflow.python.util import compat from tensorflow.python.util import nest +from tensorflow.python.util import object_identity from tensorflow.python.util.lazy_loader import LazyLoader # To avoid circular dependencies between keras/engine and keras/saving, @@ -164,6 +165,8 @@ class KerasObjectLoader(tf_load.Loader): # records all nodes that were generated directly/indirectly from the config, # so that they do not get recreated multiple times. self._nodes_recreated_from_config = {} + self._all_nodes_recreated_from_config = ( + object_identity.ObjectIdentityWeakSet()) # Store all node ids that have already been traversed when tracking nodes # that were recreated from the config. self._traversed_nodes_from_config = [] @@ -227,30 +230,36 @@ class KerasObjectLoader(tf_load.Loader): return self._traversed_nodes_from_config.append(node_id) obj._maybe_initialize_trackable() + for reference in proto.children: obj_child = obj._lookup_dependency(reference.local_name) - setter = setattr + child_id = reference.node_id + child_proto = self._proto.nodes[child_id] + if not isinstance(obj_child, trackable.Trackable): continue - if obj_child._object_identifier in revived_types.registered_identifiers(): - setter = lambda *unused: None + if (child_proto.user_object.identifier in + revived_types.registered_identifiers()): + setter = revived_types.get_setter(child_proto.user_object) elif obj_child._object_identifier in KERAS_OBJECT_IDENTIFIERS: - metadata = self._proto.nodes[reference.node_id].user_object.metadata setter = _revive_setter - _add_serialized_attributes(obj_child, json.loads(metadata)) + else: + setter = setattr # pylint: enable=protected-access - if (reference.node_id in self._nodes_recreated_from_config and - self._nodes_recreated_from_config[reference.node_id][0] is not - obj_child): + + if (child_id in self._nodes_recreated_from_config and + self._nodes_recreated_from_config[child_id][0] is not obj_child): # This means that the same trackable object is referenced by two # different objects that were recreated from the config. logging.warn('Looks like there is an object (perhaps variable or layer)' ' that is shared between different layers/models. This ' - 'may cause issues when training the model. Object: {}' - .format(obj_child)) - self._nodes_recreated_from_config[reference.node_id] = obj_child, setter + 'may cause issues when restoring the variable values.' + 'Object: {}'.format(obj_child)) + self._nodes_recreated_from_config[child_id] = ( + obj_child, self._config_node_setter(setter)) + self._all_nodes_recreated_from_config.add(obj_child) self._add_children_recreated_from_config( - obj_child, self._proto.nodes[reference.node_id], reference.node_id) + obj_child, child_proto, child_id) def _load_layers(self): layers = {} @@ -262,19 +271,35 @@ class KerasObjectLoader(tf_load.Loader): def _load_layer(self, proto, node_id): """Load a single layer from a SavedUserObject proto.""" + metadata = json_utils.decode(proto.metadata) + + # If node was already created + if node_id in self._nodes_recreated_from_config: + node, setter = self._nodes_recreated_from_config[node_id] + + self._try_build_layer(node, node_id, metadata.get('build_input_shape')) + + # Revive setter requires the object to have a `_serialized_attributes` + # property. Add it here. + _maybe_add_serialized_attributes(node, metadata) + + config = metadata.get('config') + if _is_graph_network(node) and generic_utils.validate_config(config): + self.model_layer_dependencies[node_id] = ( + node, self._get_child_layer_node_ids(node_id, node.name)) + return node, setter + # Detect whether this object can be revived from the config. If not, then # revive from the SavedModel instead. - metadata = json.loads(proto.metadata) obj, setter = self._revive_from_config(metadata, node_id) if obj is None: obj, setter = revive_custom_object(proto.identifier, metadata) - if setter == _revive_setter: - # Add an attribute that stores the extra functions/objects saved in the - # SavedModel. Most of these functions/objects are ignored, but some are - # used later in the loading process (e.g. the list of regularization - # losses, or the training config of compiled models). - _add_serialized_attributes(obj, metadata) + # Add an attribute that stores the extra functions/objects saved in the + # SavedModel. Most of these functions/objects are ignored, but some are + # used later in the loading process (e.g. the list of regularization + # losses, or the training config of compiled models). + _maybe_add_serialized_attributes(obj, metadata) return obj, setter def _revive_from_config(self, metadata, node_id): @@ -284,8 +309,9 @@ class KerasObjectLoader(tf_load.Loader): if obj is None: return None, None - setter = _revive_setter + setter = self._config_node_setter(_revive_setter) self._nodes_recreated_from_config[node_id] = obj, setter + self._all_nodes_recreated_from_config.add(obj) self._add_children_recreated_from_config( obj, self._proto.nodes[node_id], node_id) return obj, setter @@ -300,9 +326,8 @@ class KerasObjectLoader(tf_load.Loader): model_is_functional_or_sequential = ( metadata.get('is_graph_network', False) or metadata['class_name'] == 'Sequential') - if (config is None or - generic_utils.LAYER_UNDEFINED_CONFIG_KEY in config or - not model_is_functional_or_sequential): + if not (generic_utils.validate_config(config) and + model_is_functional_or_sequential): return None # Revive as custom model. # Revive functional and sequential models as blank model objects for now ( @@ -329,7 +354,7 @@ class KerasObjectLoader(tf_load.Loader): # found. class_name = metadata.get('class_name') config = metadata.get('config') - if config is None or generic_utils.LAYER_UNDEFINED_CONFIG_KEY in config: + if not generic_utils.validate_config(config): return None try: @@ -348,16 +373,31 @@ class KerasObjectLoader(tf_load.Loader): obj._set_dtype_policy(metadata['dtype']) # pylint: enable=protected-access - input_shape = None - if not isinstance(obj, input_layer.InputLayer): - input_shape = self._infer_inputs(node_id, convert_to_shapes=True) - if input_shape is None: - return None - obj.build(input_shape) - obj.built = True + build_input_shape = metadata.get('build_input_shape') + built = self._try_build_layer(obj, node_id, build_input_shape) + + if not built: + # If the layer cannot be built, revive a custom layer instead. + return None return obj + def _try_build_layer(self, obj, node_id, build_input_shape): + """Attempts to build the layer.""" + if obj.built or hasattr(obj.build, '_is_default'): + obj.built = True + return True + + if build_input_shape is None: + build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True) + + if build_input_shape is not None: + obj.build(build_input_shape) + base_layer.Layer.build(obj, build_input_shape) + return True + + return False + def _load_edges(self): """Add edges for all nodes that are not waiting on initialization.""" for node_id, proto in enumerate(self._proto.nodes): @@ -432,8 +472,8 @@ class KerasObjectLoader(tf_load.Loader): .format(uninitialized_model_names)) def _reconstruct_model(self, model_id, model, layers): - config = ( - json.loads(self._proto.nodes[model_id].user_object.metadata)['config']) + config = json_utils.decode( + self._proto.nodes[model_id].user_object.metadata)['config'] if isinstance(model, models_lib.Sequential): if not isinstance(layers[0], input_layer.InputLayer): if 'batch_input_shape' in config['layers'][0]['config']: @@ -502,6 +542,14 @@ class KerasObjectLoader(tf_load.Loader): else: return inputs + def _config_node_setter(self, setter): + """Creates edges for nodes that are recreated from config.""" + def setattr_wrapper(obj, name, value): + # Avoid overwriting attributes of objects recreated from the config. + if obj._lookup_dependency(name) is None: # pylint: disable=protected-access + setter(obj, name, value) + return setattr_wrapper + def _finalize_saved_model_layers(layers): """Runs the final steps of loading Keras Layers from SavedModel.""" @@ -626,8 +674,9 @@ class RevivedLayer(object): with trackable.no_automatic_dependency_tracking_scope(revived_obj): # pylint:disable=protected-access revived_obj._expects_training_arg = metadata['expects_training_arg'] - if metadata.get('config') is not None: - revived_obj._config = metadata['config'] + config = metadata.get('config)') + if generic_utils.validate_config(config): + revived_obj._config = config if metadata.get('input_spec') is not None: revived_obj.input_spec = recursively_deserialize_keras_object( metadata['input_spec'], @@ -747,8 +796,9 @@ class RevivedNetwork(RevivedLayer): with trackable.no_automatic_dependency_tracking_scope(revived_obj): # pylint:disable=protected-access revived_obj._expects_training_arg = metadata['expects_training_arg'] - if metadata.get('config') is not None: - revived_obj._config = metadata['config'] + config = metadata.get('config') + if generic_utils.validate_config(config): + revived_obj._config = config if metadata.get('activity_regularizer') is not None: revived_obj.activity_regularizer = regularizers.deserialize( @@ -769,12 +819,13 @@ def _set_network_attributes_from_metadata(revived_obj): # pylint:enable=protected-access -def _add_serialized_attributes(layer, metadata): +def _maybe_add_serialized_attributes(layer, metadata): # Store attributes revived from SerializedAttributes in a un-tracked # dictionary. The attributes are the ones listed in CommonEndpoints or # "keras_api" for keras-specific attributes. - with trackable.no_automatic_dependency_tracking_scope(layer): - layer._serialized_attributes = {'metadata': metadata} # pylint: disable=protected-access + if not hasattr(layer, '_serialized_attributes'): + with trackable.no_automatic_dependency_tracking_scope(layer): + layer._serialized_attributes = {'metadata': metadata} # pylint: disable=protected-access def _get_keras_attr(layer): diff --git a/tensorflow/python/keras/saving/saved_model/revive_test.py b/tensorflow/python/keras/saving/saved_model/revive_test.py index 3e267340caa..ca3ecfc5a77 100644 --- a/tensorflow/python/keras/saving/saved_model/revive_test.py +++ b/tensorflow/python/keras/saving/saved_model/revive_test.py @@ -50,15 +50,18 @@ class SubclassedModelNoConfig(keras.Model): self.a = a self.b = b self.shared = CustomLayerNoConfig(a, b) - self.all_layers = [ + self.all_layers = [] + + def build(self, input_shape): + self.all_layers.extend([ self.shared, - CustomLayerWithConfig(a + 1, b + 2), - CustomLayerNoConfig(a + 3, b + 4), + CustomLayerWithConfig(self.a + 1, self.b + 2), + CustomLayerNoConfig(self.a + 3, self.b + 4), keras.Sequential([ # TODO(b/145029112): Bug with losses when there are shared layers. # self.shared, <-- Enable when bug is fixed. - CustomLayerNoConfig(a + 5, b + 6) - ])] + CustomLayerNoConfig(self.a + 5, self.b + 6)])]) + super(SubclassedModelNoConfig, self).build(input_shape) def call(self, inputs): x = inputs diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 9ee644bf8cd..bbb6155e30e 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -44,7 +44,7 @@ _GLOBAL_CUSTOM_NAMES = {} _SKIP_FAILED_SERIALIZATION = False # If a layer does not have a defined config, then the returned config will be a # dictionary with the below key. -LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config' +_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config' @keras_export('keras.utils.CustomObjectScope') @@ -271,7 +271,7 @@ def serialize_keras_object(instance): except NotImplementedError as e: if _SKIP_FAILED_SERIALIZATION: return serialize_keras_class_and_config( - name, {LAYER_UNDEFINED_CONFIG_KEY: True}) + name, {_LAYER_UNDEFINED_CONFIG_KEY: True}) raise e serialization_config = {} for key, item in config.items(): @@ -790,3 +790,8 @@ def validate_kwargs(kwargs, for kwarg in kwargs: if kwarg not in allowed_kwargs: raise TypeError(error_message, kwarg) + + +def validate_config(config): + """Determines whether config appears to be a valid layer config.""" + return isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config diff --git a/tensorflow/python/saved_model/revived_types.py b/tensorflow/python/saved_model/revived_types.py index a802cdbe3ec..32d0b8ae53e 100644 --- a/tensorflow/python/saved_model/revived_types.py +++ b/tensorflow/python/saved_model/revived_types.py @@ -169,3 +169,13 @@ def deserialize(proto): def registered_identifiers(): return _REVIVED_TYPE_REGISTRY.keys() + + +def get_setter(proto): + _, type_registrations = _REVIVED_TYPE_REGISTRY.get( + proto.identifier, (None, None)) + if type_registrations is not None: + for type_registration in type_registrations: + if type_registration.should_load(proto): + return type_registration.setter + return None