Support saving uninitialized variables.

With this change, `saved_model.save` can handle uninitialized
variables: they are written to the graphs, but not to the checkpoint.
Loading such a saved model in Python will fail, while loading it in
C++ yields uninitialized variables that the user must then explicitly
initialize.

Previously, saving such a model simply failed, so this is a
backwards-compatible behavior change. Moreover, uninitialized
variables are not exposed in the public API, so the change is
unlikely to be noticed in practice.
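
As a rough illustration of the new behavior, here is a minimal Python sketch (non-authoritative: `UninitializedVariable` is not public API, so it is imported from TensorFlow-internal modules, mirroring the tests in this change):

```python
# Minimal sketch, assuming TF at roughly this revision. UninitializedVariable
# is internal-only; the import path mirrors the test code below.
import os
import tempfile

import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

root = tf.Module()
# Recorded in the saved object graph, but no value is written to the
# variables checkpoint.
root.uninitialized_variable = resource_variable_ops.UninitializedVariable(
    name="uninitialized_variable", dtype=tf.float32)
root.initialized_variable = tf.Variable(1.0, name="initialized_variable")

export_dir = os.path.join(tempfile.mkdtemp(), "saved_model")
tf.saved_model.save(root, export_dir)  # Used to fail; now succeeds.

try:
  tf.saved_model.load(export_dir)  # Python loading still fails.
except Exception as e:  # e.g. a "Key uninitialized_variable not found" error.
  print("Python load failed as expected:", type(e).__name__)
```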

PiperOrigin-RevId: 332529778
Change-Id: I6492112264344f7dccd9d7dff65501a4cc96a62b
Cesar Crusius 2020-09-18 14:38:09 -07:00 committed by TensorFlower Gardener
parent 285cb597ed
commit d487b8c4ca
16 changed files with 261 additions and 16 deletions

View File

@@ -225,7 +225,8 @@ Status FlattenSignature(const StructuredValue& signature,
}
const SavedObject* FindNodeAtPath(StringPiece path,
const SavedObjectGraph& object_graph) {
const SavedObjectGraph& object_graph,
int* node_id) {
const auto& nodes = object_graph.nodes();
if (nodes.empty()) {
return nullptr;
@@ -245,6 +246,9 @@ const SavedObject* FindNodeAtPath(StringPiece path,
if (child_node_iter == current_node->children().end()) {
return nullptr;
}
if (node_id) {
*node_id = child_node_iter->node_id();
}
current_node = &nodes.Get(child_node_iter->node_id());
}

View File

@@ -78,9 +78,11 @@ Status FlattenSignature(const StructuredValue& signature,
// Find the SavedObject in `object_graph` at location `path`. `path` must be
// a dot-delimited string of object names relative to the root object. If no
// object is found, returns nullptr. Callers must ensure `object_graph`
// outlives the returned pointer.
// outlives the returned pointer. If not `nullptr`, `node_id` will contain the
// index of the returned object in the `SavedObjectGraph.nodes` array.
const SavedObject* FindNodeAtPath(StringPiece path,
const SavedObjectGraph& object_graph);
const SavedObjectGraph& object_graph,
int* node_id = nullptr);
// Maps each node in `graphdef` to its corresponding Attribute Map.
// Callers must ensure that `graphdef` outlives the returned map.

View File

@@ -268,6 +268,12 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
}
const std::string& checkpoint_key = attribute->checkpoint_key();
if (!bundle->variable_reader()->Contains(checkpoint_key)) {
LOG(WARNING) << "No checkpoint entry found for " << checkpoint_key
<< ". Variable will be uninitialized.";
return Status();
}
std::string variables_path_prefix =
io::JoinPath(directory, kSavedModelVariablesDirectory,
kSavedModelVariablesFilename);
@@ -325,6 +331,31 @@ std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
return result;
}
Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
Variable** variable) {
int node_id;
const SavedObject* object = internal::FindNodeAtPath(
variable_path, bundle_.saved_object_graph(), &node_id);
if (object == nullptr) {
return errors::NotFound("No saved object found at path ", variable_path);
}
if (object->kind_case() == SavedObject::kVariable) {
auto iter = revived_objects_.find(node_id);
if (iter == revived_objects_.end()) {
return errors::Internal("Variable ", variable_path,
" was not properly revived.");
}
*variable = static_cast<Variable*>(iter->second.get());
return Status();
}
*variable = nullptr;
return errors::InvalidArgument(
variable_path, " is not a path to a Variable (kind=", object->kind_case(),
")");
}
TFSavedModelAPI::TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>

View File

@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
@@ -68,6 +69,8 @@ class TFSavedModelAPI : public SavedModelAPI {
~TFSavedModelAPI() override = default;
Status GetVariable(const std::string& variable_path, Variable** variable);
private:
TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,

View File

@@ -245,14 +245,17 @@ tf_cc_test(
"saved_model_api_test.cc",
],
data = [
"//tensorflow/c/experimental/saved_model/internal/testdata:saved_models",
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
deps = [
":saved_model_api_type",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
"//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/core:lib",

View File

@@ -21,6 +21,8 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
@@ -194,6 +196,49 @@ TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
TFE_DeleteContext(ctx);
}
TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
bool use_tfrt = GetParam();
if (use_tfrt) {
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::string model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(),
"c/experimental/saved_model/internal/testdata/UninitializedVariable");
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
tensorflow::TFSavedModelAPI* model_api =
tensorflow::down_cast<tensorflow::TFSavedModelAPI*>(
tensorflow::unwrap(saved_model));
tensorflow::Variable* uninitialized_variable;
ASSERT_EQ(tensorflow::Status::OK(),
model_api->GetVariable("uninitialized_variable",
&uninitialized_variable));
ASSERT_EQ(tensorflow::DT_FLOAT, uninitialized_variable->dtype());
ASSERT_EQ(tensorflow::Status::OK(),
model_api->GetVariable("sub_module.uninitialized_variable",
&uninitialized_variable));
ASSERT_EQ(tensorflow::DT_INT64, uninitialized_variable->dtype());
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
::testing::Bool());

View File

@@ -0,0 +1,36 @@
load("//tensorflow:tensorflow.bzl", "py_strict_binary")
package(
licenses = ["notice"], # Apache 2.0
)
# Run this binary manually, with an argument pointing to the testdata/
# directory, to generate the test files used by the filegroup rule below.
py_strict_binary(
name = "gen_saved_models",
srcs = ["gen_saved_models.py"],
python_version = "PY3",
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:variables",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/module",
"//tensorflow/python/saved_model",
"//tensorflow/python/saved_model:save_options",
],
)
# Files generated by the binary above.
filegroup(
name = "saved_models",
srcs = glob([
"UninitializedVariable/**",
]),
visibility = [
"//tensorflow/c/experimental/saved_model/internal:__pkg__",
],
)

View File

@@ -0,0 +1,84 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Creates saved models used for testing.
This executable should be run with an argument pointing to the testdata/ folder
in this directory. It will re-generate the saved models that are used for
testing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import os
from tensorflow.python.compat import v2_compat
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.saved_model import saved_model
def _gen_uninitialized_variable(base_dir):
"""Generates a saved model with an uninitialized variable."""
class SubModule(module.Module):
"""A module with an UninitializedVariable."""
def __init__(self):
self.uninitialized_variable = resource_variable_ops.UninitializedVariable(
name="uninitialized_variable", dtype=dtypes.int64)
class Module(module.Module):
"""A module with an UninitializedVariable."""
def __init__(self):
super(Module, self).__init__()
self.sub_module = SubModule()
self.initialized_variable = variables.Variable(
1.0, name="initialized_variable")
# An UninitializedVariable with the same name as the variable in the
# SubModule, but with a different type.
self.uninitialized_variable = resource_variable_ops.UninitializedVariable(
name="uninitialized_variable", dtype=dtypes.float32)
@def_function.function(
input_signature=[tensor_spec.TensorSpec((), dtypes.float32)])
def compute(self, value):
return self.initialized_variable + value
to_save = Module()
saved_model.save(
to_save, export_dir=os.path.join(base_dir, "UninitializedVariable"))
def main(args):
if len(args) != 2:
raise app.UsageError("Expected one argument (base_dir).")
_, base_dir = args
_gen_uninitialized_variable(base_dir)
if __name__ == "__main__":
v2_compat.enable_v2_behavior()
app.run(main)

View File

@@ -47,6 +47,7 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save
@@ -611,6 +612,31 @@ class SaveTest(test.TestCase, parameterized.TestCase):
experimental_variable_policy=save_options.VariablePolicy
.EXPAND_DISTRIBUTED_VARIABLES))
def test_save_uninitialized_variable(self):
root = tracking.AutoTrackable()
root.uninitialized_variable = resource_variable_ops.UninitializedVariable(
name="uninitialized_variable", dtype=dtypes.float32)
root.initialized_variable = variables.Variable(
1.0, name="initialized_variable")
# TODO(b/149594077): Python loading does not work now partly because it
# shouldn't, as the public API and semantics of uninitialized variables
# are not properly defined, and officially supporting loading would end up
# defining semantics "by usage." We should only allow loading once the API
# is made official.
export_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(root, export_dir)
with self.assertRaisesRegex(FileNotFoundError,
"Key uninitialized_variable"):
load.load(export_dir)
with ops.Graph().as_default(), session_lib.Session() as session:
# The final ValueError here (with "no variables to save") is confusing,
# but errors upstream give the user the correct information (a
# NotFoundError stating that the uninitialized_variable was not found in
# the checkpoint).
with self.assertRaises(ValueError):
loader.load(session, [tag_constants.SERVING], export_dir)
class VariablePolicyEnumTest(test.TestCase):
@@ -820,8 +846,7 @@ class AssetTests(test.TestCase):
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
value_dtype=dtypes.int64,
value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
table = lookup_ops.HashTable(
initializer, default_value=-1)
table = lookup_ops.HashTable(initializer, default_value=-1)
root.table_user = def_function.function(
table.lookup,
input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])

View File

@@ -72,9 +72,14 @@ class _SingleDeviceSaver(object):
tensor_slices = []
for saveable in self._saveable_objects:
for spec in saveable.specs:
tensor_names.append(spec.name)
tensors.append(spec.tensor)
tensor_slices.append(spec.slice_spec)
tensor = spec.tensor
# A tensor value of `None` indicates that this SaveableObject gets
# recorded in the object graph, but that no value is saved in the
# checkpoint.
if tensor is not None:
tensor_names.append(spec.name)
tensors.append(tensor)
tensor_slices.append(spec.slice_spec)
save_device = options.experimental_io_device or "cpu:0"
with ops.device(save_device):
return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)

View File

@@ -26,6 +26,7 @@ class SaveSpec(object):
Args:
tensor: the tensor to save or callable that produces a tensor to save.
If the value is `None`, the `SaveSpec` is ignored.
slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
name: the name to save the tensor under.
dtype: The data type of the Tensor. Required if `tensor` is callable.
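
To make the `None` convention concrete, a small hedged sketch follows (the `saveable_object` module path and constructor signature are taken from this revision's internals and may change):

```python
# Hedged sketch: SaveSpec is an internal class; argument names match the
# docstring in the hunk above.
from tensorflow.python.framework import dtypes
from tensorflow.python.training.saving import saveable_object

# A spec whose tensor is None is still tracked in the object graph, but the
# saver (see the _SingleDeviceSaver hunk above) skips it, so nothing is
# written to the checkpoint under this name.
spec = saveable_object.SaveSpec(
    tensor=None,
    slice_spec="",
    name="uninitialized_variable",
    dtype=dtypes.float32)
print(spec.tensor)  # None -> dropped by the saver.
```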

View File

@@ -99,11 +99,16 @@ class ResourceVariableSaveable(saveable_object.SaveableObject):
def _read_variable_closure(v):
def f():
with ops.device(v.device):
if context.executing_eagerly() and not v.is_initialized():
# A SaveSpec tensor value of `None` indicates that the variable is
# uninitialized.
return None
x = v.read_value()
# To allow variables placed on non-CPU devices to be checkpointed,
# we copy them to CPU on the same machine first.
with ops.device("/device:CPU:0"):
return array_ops.identity(x)
return f
self.handle_op = var.handle
@@ -177,8 +182,8 @@ def saveable_objects_for_op(op, name):
yield ReferenceVariableSaveable(
variable, variable._save_slice_info.spec, name)
else:
yield ResourceVariableSaveable(
variable, variable._save_slice_info.spec, name)
yield ResourceVariableSaveable(variable, variable._save_slice_info.spec,
name)
# pylint: enable=protected-access
elif isinstance(op, trackable.Trackable) and not isinstance(
op, variables.Variable):
@@ -196,12 +201,10 @@ def saveable_objects_for_op(op, name):
else:
# A variable or tensor.
if isinstance(op, resource_variable_ops.BaseResourceVariable):
# pylint: disable=protected-access
if op._in_graph_mode:
variable = op._graph_element
if op._in_graph_mode: # pylint: disable=protected-access
variable = op._graph_element # pylint: disable=protected-access
else:
variable = op
# pylint: enable=protected-access
yield ResourceVariableSaveable(variable, "", name)
else:
if context.executing_eagerly():
@@ -217,8 +220,7 @@ def saveable_objects_for_op(op, name):
"AutoReloadVariable"]:
yield ReferenceVariableSaveable(variable, "", name)
else:
yield ResourceVariableSaveable(
variable, "", name)
yield ResourceVariableSaveable(variable, "", name)
def op_list_to_dict(op_list, convert_variable_to_tensor=True):
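
Putting the pieces together, a short hedged sketch of what happens at save time for an uninitialized variable in eager mode (internal module paths are assumptions based on the hunks in this change):

```python
# Hedged sketch: the read closure above returns None for an uninitialized
# variable in eager mode, so its SaveSpec never reaches the save_v2 op.
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.saving import saveable_object_util

v = resource_variable_ops.UninitializedVariable(
    name="uninitialized_variable", dtype=tf.float32)
for saveable in saveable_object_util.saveable_objects_for_op(v, "v"):
  for spec in saveable.specs:
    # spec.tensor invokes the closure; None means "tracked but not saved".
    print(spec.name, spec.tensor)
```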

View File

@@ -1897,6 +1897,10 @@ register_extension_info(
label_regex_for_dep = "{extension_name}",
)
# Placeholder to use until bazel supports py_strict_binary.
def py_strict_binary(name, **kwargs):
native.py_binary(name = name, **kwargs)
# Placeholder to use until bazel supports py_strict_library.
def py_strict_library(name, **kwargs):
native.py_library(name = name, **kwargs)