Add support for TensorFlow Saveables as Keras weights.
PiperOrigin-RevId: 272063211
This commit is contained in:
parent
0eaefe9fdc
commit
0175008120
@ -1668,6 +1668,20 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "base_layer_utils_test",
|
||||
srcs = ["engine/base_layer_utils_test.py"],
|
||||
additional_deps = [
|
||||
":keras",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"nomac", # TODO(mihaimaruseac): b/127695564
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "control_flow_test",
|
||||
size = "medium",
|
||||
|
@ -394,6 +394,27 @@ class Layer(module.Module):
|
||||
"""
|
||||
return inputs
|
||||
|
||||
@doc_controls.for_subclass_implementers
def _add_trackable(self, trackable_object, trainable):
  """Registers a Trackable object as part of this layer's weights.

  The object is wrapped in a `base_layer_utils.TrackableWeightHandler` and the
  wrapper (not the raw object) is appended to one of the layer's weight lists.

  Arguments:
    trackable_object: The tf.tracking.Trackable object to register.
    trainable: Boolean. If truthy, the wrapper is appended to
      `self._trainable_weights`; otherwise it is appended to
      `self._non_trainable_weights`.

  Returns:
    The same `trackable_object` that was passed in.
  """
  # Resolve the destination list first, then build the wrapper, so that the
  # handler is constructed exactly once regardless of the branch taken.
  destination = (
      self._trainable_weights if trainable else self._non_trainable_weights)
  destination.append(base_layer_utils.TrackableWeightHandler(trackable_object))
  return trackable_object
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
def add_weight(self,
|
||||
name=None,
|
||||
@ -1333,23 +1354,40 @@ class Layer(module.Module):
|
||||
ValueError: If the provided weights list does not match the
|
||||
layer's specifications.
|
||||
"""
|
||||
params = self.weights
|
||||
if len(params) != len(weights):
|
||||
raise ValueError('You called `set_weights(weights)` on layer "' +
|
||||
self.name + '" with a weight list of length ' +
|
||||
str(len(weights)) + ', but the layer was expecting ' +
|
||||
str(len(params)) + ' weights. Provided weights: ' +
|
||||
str(weights)[:50] + '...')
|
||||
if not params:
|
||||
return
|
||||
params = self.trainable_weights + self.non_trainable_weights
|
||||
|
||||
expected_num_weights = 0
|
||||
for param in params:
|
||||
if isinstance(param, base_layer_utils.TrackableWeightHandler):
|
||||
expected_num_weights += param.num_tensors
|
||||
else:
|
||||
expected_num_weights += 1
|
||||
|
||||
if expected_num_weights != len(weights):
|
||||
raise ValueError(
|
||||
'You called `set_weights(weights)` on layer "%s" '
|
||||
'with a weight list of length %s, but the layer was '
|
||||
'expecting %s weights. Provided weights: %s...' %
|
||||
(self.name, len(weights), expected_num_weights, str(weights)[:50]))
|
||||
|
||||
weight_index = 0
|
||||
weight_value_tuples = []
|
||||
for p, w in zip(params, weights):
|
||||
ref_shape = p.shape
|
||||
if not ref_shape.is_compatible_with(w.shape):
|
||||
raise ValueError('Layer weight shape ' + str(ref_shape) +
|
||||
' not compatible with '
|
||||
'provided weight shape ' + str(w.shape))
|
||||
weight_value_tuples.append((p, w))
|
||||
for param in params:
|
||||
if isinstance(param, base_layer_utils.TrackableWeightHandler):
|
||||
num_tensors = param.num_tensors
|
||||
tensors = weights[weight_index:weight_index + num_tensors]
|
||||
param.set_weights(tensors)
|
||||
weight_index += num_tensors
|
||||
else:
|
||||
weight = weights[weight_index]
|
||||
ref_shape = param.shape
|
||||
if not ref_shape.is_compatible_with(weight.shape):
|
||||
raise ValueError(
|
||||
'Layer weight shape %s not compatible with provided weight '
|
||||
'shape %s' % (ref_shape, weight.shape))
|
||||
weight_value_tuples.append((param, weight))
|
||||
weight_index += 1
|
||||
|
||||
backend.batch_set_value(weight_value_tuples)
|
||||
|
||||
def get_weights(self):
|
||||
@ -1358,8 +1396,14 @@ class Layer(module.Module):
|
||||
Returns:
|
||||
Weights values as a list of numpy arrays.
|
||||
"""
|
||||
params = self.weights
|
||||
return backend.batch_get_value(params)
|
||||
weights = self.trainable_weights + self.non_trainable_weights
|
||||
output_weights = []
|
||||
for weight in weights:
|
||||
if isinstance(weight, base_layer_utils.TrackableWeightHandler):
|
||||
output_weights.extend(weight.get_tensors())
|
||||
else:
|
||||
output_weights.append(weight)
|
||||
return backend.batch_get_value(output_weights)
|
||||
|
||||
def get_updates_for(self, inputs):
|
||||
"""Retrieves updates relevant to a specific set of inputs.
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.ops import control_flow_v2_func_graphs
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.training.tracking import base as tracking
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
@ -692,3 +693,74 @@ def v2_dtype_behavior_enabled():
|
||||
if V2_DTYPE_BEHAVIOR is None:
|
||||
return tf2.enabled()
|
||||
return V2_DTYPE_BEHAVIOR
|
||||
|
||||
|
||||
class TrackableWeightHandler(object):
  """Adapter exposing a tracking.Trackable's saved state as Keras weights.

  The handler looks up the single Saveable factory that the wrapped Trackable
  reports and builds a getter/setter pair around it. Two strategies are used:

  * Eager (V2): the Saveable factory is stored and re-invoked on every read
    and write, so the data is resolved at export time.
  * Graph (V1): the Saveable is built once, one placeholder per spec is
    created, and the resulting restore op is cached so that repeated
    `set_weights()` calls do not add new assignment ops to the graph.

  The wrapped Trackable itself is kept privately and is only used for error
  reporting.

  Attributes:
    num_tensors: The number of tensors that this trackable requires for saving.
  """

  def __init__(self, trackable):
    """Builds the getter/setter pair for `trackable`.

    Arguments:
      trackable: The tracking.Trackable object to wrap.

    Raises:
      ValueError: If `trackable` is not a tracking.Trackable, or if it does not
        report exactly one Saveable.
    """
    if not isinstance(trackable, tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveable_factories = list(
        trackable._gather_saveables_for_checkpoint().values())  # pylint: disable=protected-access
    if len(saveable_factories) != 1:
      raise ValueError('Only Trackables with one Saveable are supported.')
    saveable_factory = saveable_factories[0]

    if ops.executing_eagerly_outside_functions():
      # Eager mode: keep the factory itself and defer calling it until data
      # export time. Calling it is safe to repeat, so it is invoked once here
      # purely to learn how many tensors the Trackable produces.
      self._saveable = saveable_factory
      self._num_tensors = len(self._saveable().specs)
      self._setter = self._set_weights_eager
      self._getter = self._get_tensors_eager
    else:
      # Graph mode: evaluate the Saveable exactly once and cache the restore
      # graph. Otherwise every set_weights() call would add fresh assignment
      # ops to the graph.
      self._placeholder_tensors = []
      self._saveable = saveable_factory()
      self._num_tensors = len(self._saveable.specs)
      for spec in self._saveable.specs:
        # Read the spec's tensor once and mirror its dtype/shape.
        spec_tensor = spec.tensor
        placeholder = array_ops.placeholder(spec_tensor.dtype,
                                            spec_tensor.shape)
        self._placeholder_tensors.append(placeholder)
      self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
      self._setter = self._set_weights_v1
      self._getter = self._get_tensors_v1

  @property
  def num_tensors(self):
    """Number of tensors the wrapped Trackable reads and writes."""
    return self._num_tensors

  def set_weights(self, weights):
    """Writes `weights` into the wrapped Trackable.

    Arguments:
      weights: A sequence holding exactly `num_tensors` values.

    Raises:
      ValueError: If `len(weights)` differs from `num_tensors`.
    """
    received = len(weights)
    if received != self._num_tensors:
      raise ValueError(
          'Weight handler for trackable %s received the wrong number of '
          'weights: expected %s, got %s.' %
          (self._trackable, self._num_tensors, received))
    self._setter(weights)

  def get_tensors(self):
    """Returns the list of tensors taken from the Saveable's specs."""
    return self._getter()

  def _set_weights_eager(self, weights):
    """Eager setter: rebuilds the Saveable and restores `weights` into it."""
    return self._saveable().restore(weights, None)

  def _get_tensors_eager(self):
    """Eager getter: rebuilds the Saveable and collects each spec's tensor."""
    return [spec.tensor for spec in self._saveable().specs]

  def _get_tensors_v1(self):
    """Graph getter: collects each spec's tensor from the cached Saveable."""
    return [spec.tensor for spec in self._saveable.specs]

  def _set_weights_v1(self, weights):
    """Graph setter: feeds `weights` to the cached placeholders and assigns."""
    feed_dict = {
        self._placeholder_tensors[position]: value
        for position, value in enumerate(weights)
    }
    backend.get_session().run(self._assign_op, feed_dict)
|
||||
|
70
tensorflow/python/keras/engine/base_layer_utils_test.py
Normal file
70
tensorflow/python/keras/engine/base_layer_utils_test.py
Normal file
@ -0,0 +1,70 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
class TrackableWeightHandlerTest(keras_parameterized.TestCase):
  """Checks TrackableWeightHandler against a MutableHashTable."""

  def get_table_handler(self):
    """Returns a handler wrapping a fresh string->int32 MutableHashTable."""
    # Note: every test builds its own table through this helper rather than
    # through setUp(). TensorFlow does not play nicely with a separate setUp()
    # call (it causes errors related to graph building), so the setup has to
    # be invoked explicitly from inside each test.
    hash_table = lookup_ops.MutableHashTable(
        key_dtype=dtypes.string, value_dtype=dtypes.int32, default_value=0)
    return base_layer_utils.TrackableWeightHandler(hash_table)

  def test_get_num_tensors(self):
    """The table handler reports two tensors."""
    handler = self.get_table_handler()
    self.assertEqual(2, handler.num_tensors)

  def test_get_and_set_weights(self):
    """Values written with set_weights() are read back by get_tensors()."""
    handler = self.get_table_handler()

    expected = {b"a": 1, b"b": 2, b"c": 3}
    keys = list(expected)
    values = list(expected.values())
    handler.set_weights([keys, values])

    exported = backend.batch_get_value(handler.get_tensors())
    # Rebuild a dict from the exported arrays so the comparison does not
    # depend on the order in which entries come back.
    actual = dict(zip(exported[0], exported[1]))
    self.assertDictEqual(expected, actual)

  def test_get_and_set_weights_does_not_add_ops(self):
    """A second set/get round trip works after the graph is finalized."""
    handler = self.get_table_handler()
    contents = {b"a": 1, b"b": 2, b"c": 3}
    keys = list(contents)
    values = list(contents.values())

    handler.set_weights([keys, values])
    _ = backend.batch_get_value(handler.get_tensors())

    # Freeze the graph, then run the same round trip again.
    backend.get_session().graph.finalize()

    handler.set_weights([keys, values])
    _ = backend.batch_get_value(handler.get_tensors())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user