Support saving uninitialized variables.
With this change, uninitialized variables can be handled by `saved_model.save` in being saved to the graphs, but not to the checkpoint. Loading this saved model in Python will fail, and loading it in C++ will result in uninitialized variables, which the user must then explicitly initialize. The current behavior is to fail to save, so this is a backwards-compatible behavior change. Moreover, uninitialized variables are not available in the public API, so the behavior change isn't likely to be noticed anyway. PiperOrigin-RevId: 332529778 Change-Id: I6492112264344f7dccd9d7dff65501a4cc96a62b
This commit is contained in:
parent
285cb597ed
commit
d487b8c4ca
@ -225,7 +225,8 @@ Status FlattenSignature(const StructuredValue& signature,
|
||||
}
|
||||
|
||||
const SavedObject* FindNodeAtPath(StringPiece path,
|
||||
const SavedObjectGraph& object_graph) {
|
||||
const SavedObjectGraph& object_graph,
|
||||
int* node_id) {
|
||||
const auto& nodes = object_graph.nodes();
|
||||
if (nodes.empty()) {
|
||||
return nullptr;
|
||||
@ -245,6 +246,9 @@ const SavedObject* FindNodeAtPath(StringPiece path,
|
||||
if (child_node_iter == current_node->children().end()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (node_id) {
|
||||
*node_id = child_node_iter->node_id();
|
||||
}
|
||||
current_node = &nodes.Get(child_node_iter->node_id());
|
||||
}
|
||||
|
||||
|
@ -78,9 +78,11 @@ Status FlattenSignature(const StructuredValue& signature,
|
||||
// Find the SavedObject in `object_graph` at location `path`. `path` must be
|
||||
// a dot-delimited string of object names relative to the root object. If no
|
||||
// object is found, returns nullptr. Callers must ensure `object_graph`
|
||||
// outlives the returned pointer.
|
||||
// outlives the returned pointer. If not `nullptr`, `node_id` will contain the
|
||||
// index of the returned object in the `SavedObjectGraph.nodes` array.
|
||||
const SavedObject* FindNodeAtPath(StringPiece path,
|
||||
const SavedObjectGraph& object_graph);
|
||||
const SavedObjectGraph& object_graph,
|
||||
int* node_id = nullptr);
|
||||
|
||||
// Maps each node in `graphdef` to its corresponding Attribute Map.
|
||||
// Callers must ensure that `graphdef` outlives the returned map.
|
||||
|
@ -268,6 +268,12 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
|
||||
}
|
||||
|
||||
const std::string& checkpoint_key = attribute->checkpoint_key();
|
||||
if (!bundle->variable_reader()->Contains(checkpoint_key)) {
|
||||
LOG(WARNING) << "No checkpoint entry found for " << checkpoint_key
|
||||
<< ". Variable will be uninitialized.";
|
||||
return Status();
|
||||
}
|
||||
|
||||
std::string variables_path_prefix =
|
||||
io::JoinPath(directory, kSavedModelVariablesDirectory,
|
||||
kSavedModelVariablesFilename);
|
||||
@ -325,6 +331,31 @@ std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
||||
return result;
|
||||
}
|
||||
|
||||
Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
|
||||
Variable** variable) {
|
||||
int node_id;
|
||||
const SavedObject* object = internal::FindNodeAtPath(
|
||||
variable_path, bundle_.saved_object_graph(), &node_id);
|
||||
if (object == nullptr) {
|
||||
return errors::NotFound("No saved object found at path ", variable_path);
|
||||
}
|
||||
|
||||
if (object->kind_case() == SavedObject::kVariable) {
|
||||
auto iter = revived_objects_.find(node_id);
|
||||
if (iter == revived_objects_.end()) {
|
||||
return errors::Internal("Variable ", variable_path,
|
||||
" was not properly revived.");
|
||||
}
|
||||
*variable = static_cast<Variable*>(iter->second.get());
|
||||
return Status();
|
||||
}
|
||||
|
||||
*variable = nullptr;
|
||||
return errors::InvalidArgument(
|
||||
variable_path, " is not a path to a Variable (kind=", object->kind_case(),
|
||||
")");
|
||||
}
|
||||
|
||||
TFSavedModelAPI::TFSavedModelAPI(
|
||||
const std::string& directory, SavedModelV2Bundle bundle,
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
|
||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||
@ -68,6 +69,8 @@ class TFSavedModelAPI : public SavedModelAPI {
|
||||
|
||||
~TFSavedModelAPI() override = default;
|
||||
|
||||
Status GetVariable(const std::string& variable_path, Variable** variable);
|
||||
|
||||
private:
|
||||
TFSavedModelAPI(
|
||||
const std::string& directory, SavedModelV2Bundle bundle,
|
||||
|
@ -245,14 +245,17 @@ tf_cc_test(
|
||||
"saved_model_api_test.cc",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/c/experimental/saved_model/internal/testdata:saved_models",
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
deps = [
|
||||
":saved_model_api_type",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api_test_util",
|
||||
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -21,6 +21,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
@ -194,6 +196,49 @@ TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = tensorflow::io::JoinPath(
|
||||
tensorflow::testing::TensorFlowSrcRoot(),
|
||||
"c/experimental/saved_model/internal/testdata/UninitializedVariable");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
tensorflow::TFSavedModelAPI* model_api =
|
||||
tensorflow::down_cast<tensorflow::TFSavedModelAPI*>(
|
||||
tensorflow::unwrap(saved_model));
|
||||
tensorflow::Variable* uninitialized_variable;
|
||||
ASSERT_EQ(tensorflow::Status::OK(),
|
||||
model_api->GetVariable("uninitialized_variable",
|
||||
&uninitialized_variable));
|
||||
ASSERT_EQ(tensorflow::DT_FLOAT, uninitialized_variable->dtype());
|
||||
|
||||
ASSERT_EQ(tensorflow::Status::OK(),
|
||||
model_api->GetVariable("sub_module.uninitialized_variable",
|
||||
&uninitialized_variable));
|
||||
ASSERT_EQ(tensorflow::DT_INT64, uninitialized_variable->dtype());
|
||||
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
|
||||
::testing::Bool());
|
||||
|
||||
|
36
tensorflow/c/experimental/saved_model/internal/testdata/BUILD
vendored
Normal file
36
tensorflow/c/experimental/saved_model/internal/testdata/BUILD
vendored
Normal file
@ -0,0 +1,36 @@
|
||||
load("//tensorflow:tensorflow.bzl", "py_strict_binary")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Run this binary manually, with an argument pointing to the testdata/
|
||||
# directory, to generate the test files used by the filegroup rule below.
|
||||
py_strict_binary(
|
||||
name = "gen_saved_models",
|
||||
srcs = ["gen_saved_models.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/saved_model",
|
||||
"//tensorflow/python/saved_model:save_options",
|
||||
],
|
||||
)
|
||||
|
||||
# Files generated by the binary above.
|
||||
filegroup(
|
||||
name = "saved_models",
|
||||
srcs = glob([
|
||||
"UninitializedVariable/**",
|
||||
]),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/internal:__pkg__",
|
||||
],
|
||||
)
|
BIN
tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb
vendored
Normal file
BIN
tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
84
tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py
vendored
Normal file
84
tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py
vendored
Normal file
@ -0,0 +1,84 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Lint as: python3
|
||||
"""Creates saved models used for testing.
|
||||
|
||||
This executable should be run with an argument pointing to the testdata/ folder
|
||||
in this directory. It will re-generate the saved models that are used for
|
||||
testing.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import google_type_annotations
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.saved_model import saved_model
|
||||
|
||||
|
||||
def _gen_uninitialized_variable(base_dir):
|
||||
"""Generates a saved model with an uninitialized variable."""
|
||||
|
||||
class SubModule(module.Module):
|
||||
"""A module with an UninitializedVariable."""
|
||||
|
||||
def __init__(self):
|
||||
self.uninitialized_variable = resource_variable_ops.UninitializedVariable(
|
||||
name="uninitialized_variable", dtype=dtypes.int64)
|
||||
|
||||
class Module(module.Module):
|
||||
"""A module with an UninitializedVariable."""
|
||||
|
||||
def __init__(self):
|
||||
super(Module, self).__init__()
|
||||
self.sub_module = SubModule()
|
||||
self.initialized_variable = variables.Variable(
|
||||
1.0, name="initialized_variable")
|
||||
# An UninitializedVariable with the same name as the variable in the
|
||||
# SubModule, but with a different type.
|
||||
self.uninitialized_variable = resource_variable_ops.UninitializedVariable(
|
||||
name="uninitialized_variable", dtype=dtypes.float32)
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec((), dtypes.float32)])
|
||||
def compute(self, value):
|
||||
return self.initialized_variable + value
|
||||
|
||||
to_save = Module()
|
||||
saved_model.save(
|
||||
to_save, export_dir=os.path.join(base_dir, "UninitializedVariable"))
|
||||
|
||||
|
||||
def main(args):
|
||||
if len(args) != 2:
|
||||
raise app.UsageError("Expected one argument (base_dir).")
|
||||
_, base_dir = args
|
||||
_gen_uninitialized_variable(base_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
v2_compat.enable_v2_behavior()
|
||||
app.run(main)
|
@ -47,6 +47,7 @@ from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import loader
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import save
|
||||
@ -611,6 +612,31 @@ class SaveTest(test.TestCase, parameterized.TestCase):
|
||||
experimental_variable_policy=save_options.VariablePolicy
|
||||
.EXPAND_DISTRIBUTED_VARIABLES))
|
||||
|
||||
def test_save_uninitialized_variable(self):
|
||||
root = tracking.AutoTrackable()
|
||||
root.uninitialized_variable = resource_variable_ops.UninitializedVariable(
|
||||
name="uninitialized_variable", dtype=dtypes.float32)
|
||||
root.initialized_variable = variables.Variable(
|
||||
1.0, name="initialized_variable")
|
||||
|
||||
# TODO(b/149594077): Python loading does not work now partly because it
|
||||
# shouldn't, as the public API and semantics of uninitialized variables
|
||||
# are not properly defined, and officially supporting loading would end up
|
||||
# defining semantics "by usage." We should only allow loading once the API
|
||||
# is made official.
|
||||
export_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
save.save(root, export_dir)
|
||||
with self.assertRaisesRegex(FileNotFoundError,
|
||||
"Key uninitialized_variable"):
|
||||
load.load(export_dir)
|
||||
with ops.Graph().as_default(), session_lib.Session() as session:
|
||||
# The final ValueError here (with "no variables to save") is confusing,
|
||||
# but errors upstream give the user the correct information (a
|
||||
# NotFoundError stating that the uninitalized_variable was not found in
|
||||
# the checkpoint).
|
||||
with self.assertRaises(ValueError):
|
||||
loader.load(session, [tag_constants.SERVING], export_dir)
|
||||
|
||||
|
||||
class VariablePolicyEnumTest(test.TestCase):
|
||||
|
||||
@ -820,8 +846,7 @@ class AssetTests(test.TestCase):
|
||||
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||
value_dtype=dtypes.int64,
|
||||
value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
|
||||
table = lookup_ops.HashTable(
|
||||
initializer, default_value=-1)
|
||||
table = lookup_ops.HashTable(initializer, default_value=-1)
|
||||
root.table_user = def_function.function(
|
||||
table.lookup,
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
|
||||
|
@ -72,9 +72,14 @@ class _SingleDeviceSaver(object):
|
||||
tensor_slices = []
|
||||
for saveable in self._saveable_objects:
|
||||
for spec in saveable.specs:
|
||||
tensor_names.append(spec.name)
|
||||
tensors.append(spec.tensor)
|
||||
tensor_slices.append(spec.slice_spec)
|
||||
tensor = spec.tensor
|
||||
# A tensor value of `None` indicates that this SaveableObject gets
|
||||
# recorded in the object graph, but that no value is saved in the
|
||||
# checkpoint.
|
||||
if tensor is not None:
|
||||
tensor_names.append(spec.name)
|
||||
tensors.append(tensor)
|
||||
tensor_slices.append(spec.slice_spec)
|
||||
save_device = options.experimental_io_device or "cpu:0"
|
||||
with ops.device(save_device):
|
||||
return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
|
||||
|
@ -26,6 +26,7 @@ class SaveSpec(object):
|
||||
|
||||
Args:
|
||||
tensor: the tensor to save or callable that produces a tensor to save.
|
||||
If the value is `None`, the `SaveSpec` is ignored.
|
||||
slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
|
||||
name: the name to save the tensor under.
|
||||
dtype: The data type of the Tensor. Required if `tensor` is callable.
|
||||
|
@ -99,11 +99,16 @@ class ResourceVariableSaveable(saveable_object.SaveableObject):
|
||||
def _read_variable_closure(v):
|
||||
def f():
|
||||
with ops.device(v.device):
|
||||
if context.executing_eagerly() and not v.is_initialized():
|
||||
# A SaveSpec tensor value of `None` indicates that the variable is
|
||||
# uninitialized.
|
||||
return None
|
||||
x = v.read_value()
|
||||
# To allow variables placed on non-CPU devices to be checkpointed,
|
||||
# we copy them to CPU on the same machine first.
|
||||
with ops.device("/device:CPU:0"):
|
||||
return array_ops.identity(x)
|
||||
|
||||
return f
|
||||
|
||||
self.handle_op = var.handle
|
||||
@ -177,8 +182,8 @@ def saveable_objects_for_op(op, name):
|
||||
yield ReferenceVariableSaveable(
|
||||
variable, variable._save_slice_info.spec, name)
|
||||
else:
|
||||
yield ResourceVariableSaveable(
|
||||
variable, variable._save_slice_info.spec, name)
|
||||
yield ResourceVariableSaveable(variable, variable._save_slice_info.spec,
|
||||
name)
|
||||
# pylint: enable=protected-access
|
||||
elif isinstance(op, trackable.Trackable) and not isinstance(
|
||||
op, variables.Variable):
|
||||
@ -196,12 +201,10 @@ def saveable_objects_for_op(op, name):
|
||||
else:
|
||||
# A variable or tensor.
|
||||
if isinstance(op, resource_variable_ops.BaseResourceVariable):
|
||||
# pylint: disable=protected-access
|
||||
if op._in_graph_mode:
|
||||
variable = op._graph_element
|
||||
if op._in_graph_mode: # pylint: disable=protected-access
|
||||
variable = op._graph_element # pylint: disable=protected-access
|
||||
else:
|
||||
variable = op
|
||||
# pylint: enable=protected-access
|
||||
yield ResourceVariableSaveable(variable, "", name)
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
@ -217,8 +220,7 @@ def saveable_objects_for_op(op, name):
|
||||
"AutoReloadVariable"]:
|
||||
yield ReferenceVariableSaveable(variable, "", name)
|
||||
else:
|
||||
yield ResourceVariableSaveable(
|
||||
variable, "", name)
|
||||
yield ResourceVariableSaveable(variable, "", name)
|
||||
|
||||
|
||||
def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
||||
|
@ -1897,6 +1897,10 @@ register_extension_info(
|
||||
label_regex_for_dep = "{extension_name}",
|
||||
)
|
||||
|
||||
# Placeholder to use until bazel supports py_strict_binary.
|
||||
def py_strict_binary(name, **kwargs):
|
||||
native.py_binary(name = name, **kwargs)
|
||||
|
||||
# Placeholder to use until bazel supports py_strict_library.
|
||||
def py_strict_library(name, **kwargs):
|
||||
native.py_library(name = name, **kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user