Add support for Tensorflow Saveables as Keras weights.

PiperOrigin-RevId: 272063211
This commit is contained in:
A. Unique TensorFlower 2019-09-30 14:15:10 -07:00 committed by TensorFlower Gardener
parent 0eaefe9fdc
commit 0175008120
4 changed files with 218 additions and 18 deletions

View File

@ -1668,6 +1668,20 @@ tf_py_test(
],
)
tf_py_test(
name = "base_layer_utils_test",
srcs = ["engine/base_layer_utils_test.py"],
additional_deps = [
":keras",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:client_testlib",
],
tags = [
"no_rocm",
"nomac", # TODO(mihaimaruseac): b/127695564
],
)
tf_py_test(
name = "control_flow_test",
size = "medium",

View File

@ -394,6 +394,27 @@ class Layer(module.Module):
"""
return inputs
@doc_controls.for_subclass_implementers
def _add_trackable(self, trackable_object, trainable):
"""Adds a Trackable object to this layer's state.
Arguments:
trackable_object: The tf.tracking.Trackable object to add.
trainable: Boolean, whether the variable should be part of the layer's
"trainable_variables" (e.g. variables, biases) or
"non_trainable_variables" (e.g. BatchNorm mean and variance).
Returns:
The tf.tracking.Trackable object.
"""
if trainable:
self._trainable_weights.append(
base_layer_utils.TrackableWeightHandler(trackable_object))
else:
self._non_trainable_weights.append(
base_layer_utils.TrackableWeightHandler(trackable_object))
return trackable_object
@doc_controls.for_subclass_implementers
def add_weight(self,
name=None,
@ -1333,23 +1354,40 @@ class Layer(module.Module):
ValueError: If the provided weights list does not match the
layer's specifications.
"""
params = self.weights
if len(params) != len(weights):
raise ValueError('You called `set_weights(weights)` on layer "' +
self.name + '" with a weight list of length ' +
str(len(weights)) + ', but the layer was expecting ' +
str(len(params)) + ' weights. Provided weights: ' +
str(weights)[:50] + '...')
if not params:
return
params = self.trainable_weights + self.non_trainable_weights
expected_num_weights = 0
for param in params:
if isinstance(param, base_layer_utils.TrackableWeightHandler):
expected_num_weights += param.num_tensors
else:
expected_num_weights += 1
if expected_num_weights != len(weights):
raise ValueError(
'You called `set_weights(weights)` on layer "%s" '
'with a weight list of length %s, but the layer was '
'expecting %s weights. Provided weights: %s...' %
(self.name, len(weights), expected_num_weights, str(weights)[:50]))
weight_index = 0
weight_value_tuples = []
for p, w in zip(params, weights):
ref_shape = p.shape
if not ref_shape.is_compatible_with(w.shape):
raise ValueError('Layer weight shape ' + str(ref_shape) +
' not compatible with '
'provided weight shape ' + str(w.shape))
weight_value_tuples.append((p, w))
for param in params:
if isinstance(param, base_layer_utils.TrackableWeightHandler):
num_tensors = param.num_tensors
tensors = weights[weight_index:weight_index + num_tensors]
param.set_weights(tensors)
weight_index += num_tensors
else:
weight = weights[weight_index]
ref_shape = param.shape
if not ref_shape.is_compatible_with(weight.shape):
raise ValueError(
'Layer weight shape %s not compatible with provided weight '
'shape %s' % (ref_shape, weight.shape))
weight_value_tuples.append((param, weight))
weight_index += 1
backend.batch_set_value(weight_value_tuples)
def get_weights(self):
@ -1358,8 +1396,14 @@ class Layer(module.Module):
Returns:
Weights values as a list of numpy arrays.
"""
params = self.weights
return backend.batch_get_value(params)
weights = self.trainable_weights + self.non_trainable_weights
output_weights = []
for weight in weights:
if isinstance(weight, base_layer_utils.TrackableWeightHandler):
output_weights.extend(weight.get_tensors())
else:
output_weights.append(weight)
return backend.batch_get_value(output_weights)
def get_updates_for(self, inputs):
"""Retrieves updates relevant to a specific set of inputs.

View File

@ -32,6 +32,7 @@ from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.tracking import base as tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
@ -692,3 +693,74 @@ def v2_dtype_behavior_enabled():
if V2_DTYPE_BEHAVIOR is None:
return tf2.enabled()
return V2_DTYPE_BEHAVIOR
class TrackableWeightHandler(object):
"""Keras wrapper for handling tracking.Trackable object saving and restoring.
This class handles Trackables in both V1 and V2 modes, ensuring that they can
be saved and restored with the correct data and without adding additional ops
on every save.
Attributes:
trackable: The trackable to wrap.
num_tensors: The number of tensors that this trackable requires for saving.
"""
def __init__(self, trackable):
if not isinstance(trackable, tracking.Trackable):
raise ValueError('%s is not a Trackable object.' % (trackable,))
self._trackable = trackable
# TODO(b/141682913): Figure out why this is private and fix it.
saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access
if len(saveables) != 1:
raise ValueError('Only Trackables with one Saveable are supported.')
saveable = list(saveables)[0]
if ops.executing_eagerly_outside_functions():
# If we're in eager mode, we need to defer calling the Trackable's
# saveable() callable until data export time.
# However, it is safe to call the saveable as many times as we want, so
# we will call it now to figure out how many tensors this Trackable will
# produce.
self._saveable = saveable
self._num_tensors = len(self._saveable().specs)
self._setter = lambda weights: self._saveable().restore(weights, None)
self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
else:
# If we're in Graph mode, we need to evaluate the Saveable only once and
# cache the resulting restore graph. Failing to do this will result in
# new assignment ops being added to the graph each time set_weights() is
# called.
self._placeholder_tensors = []
self._saveable = saveable()
self._num_tensors = len(self._saveable.specs)
for spec in self._saveable.specs:
tensor = spec.tensor
self._placeholder_tensors.append(
array_ops.placeholder(tensor.dtype, tensor.shape))
self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
self._setter = self._set_weights_v1
self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
@property
def num_tensors(self):
return self._num_tensors
def set_weights(self, weights):
if len(weights) != self._num_tensors:
raise ValueError(
('Weight handler for trackable %s received the wrong number of ' +
'weights: expected %s, got %s.') %
(self._trackable, self._num_tensors, len(weights)))
self._setter(weights)
def get_tensors(self):
return self._getter()
def _set_weights_v1(self, weights):
feed_dict = {}
for idx, tensor in enumerate(weights):
feed_dict[self._placeholder_tensors[idx]] = tensor
backend.get_session().run(self._assign_op, feed_dict)

View File

@ -0,0 +1,70 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import lookup_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class TrackableWeightHandlerTest(keras_parameterized.TestCase):
def get_table_handler(self):
# Note: There is some repetition in these tests' setup. However, Tensorflow
# does not play nicely with a separate setUp() call (causing errors related
# to graph building), so we have to use a called setup instead of a setUp()
# call.
table = lookup_ops.MutableHashTable(
key_dtype=dtypes.string, value_dtype=dtypes.int32, default_value=0)
return base_layer_utils.TrackableWeightHandler(table)
def test_get_num_tensors(self):
table_handler = self.get_table_handler()
self.assertEqual(2, table_handler.num_tensors)
def test_get_and_set_weights(self):
table_handler = self.get_table_handler()
table_data = {b"a": 1, b"b": 2, b"c": 3}
table_handler.set_weights(
[list(table_data.keys()),
list(table_data.values())])
weights = backend.batch_get_value(table_handler.get_tensors())
weight_data = {key: value for key, value in zip(weights[0], weights[1])}
self.assertDictEqual(table_data, weight_data)
def test_get_and_set_weights_does_not_add_ops(self):
table_handler = self.get_table_handler()
table_data = {b"a": 1, b"b": 2, b"c": 3}
table_handler.set_weights(
[list(table_data.keys()),
list(table_data.values())])
_ = backend.batch_get_value(table_handler.get_tensors())
backend.get_session().graph.finalize()
table_handler.set_weights(
[list(table_data.keys()),
list(table_data.values())])
_ = backend.batch_get_value(table_handler.get_tensors())
if __name__ == "__main__":
test.main()