diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 8d1e0966ff7..5a63e66eda6 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -225,7 +225,8 @@ Status FlattenSignature(const StructuredValue& signature, } const SavedObject* FindNodeAtPath(StringPiece path, - const SavedObjectGraph& object_graph) { + const SavedObjectGraph& object_graph, + int* node_id) { const auto& nodes = object_graph.nodes(); if (nodes.empty()) { return nullptr; @@ -245,6 +246,9 @@ const SavedObject* FindNodeAtPath(StringPiece path, if (child_node_iter == current_node->children().end()) { return nullptr; } + if (node_id) { + *node_id = child_node_iter->node_id(); + } current_node = &nodes.Get(child_node_iter->node_id()); } diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index e82ec1bd104..dbcd6f2ac6d 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -78,9 +78,11 @@ Status FlattenSignature(const StructuredValue& signature, // Find the SavedObject in `object_graph` at location `path`. `path` must be // a dot-delimited string of object names relative to the root object. If no // object is found, returns nullptr. Callers must ensure `object_graph` -// outlives the returned pointer. +// outlives the returned pointer. If not `nullptr`, `node_id` will contain the +// index of the returned object in the `SavedObjectGraph.nodes` array. const SavedObject* FindNodeAtPath(StringPiece path, - const SavedObjectGraph& object_graph); + const SavedObjectGraph& object_graph, + int* node_id = nullptr); // Maps each node in `graphdef` to its corresponding Attribute Map. // Callers must ensure that `graphdef` outlives the returned map. diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index 143257b01d5..86490625d43 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -268,6 +268,12 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle, } const std::string& checkpoint_key = attribute->checkpoint_key(); + if (!bundle->variable_reader()->Contains(checkpoint_key)) { + LOG(WARNING) << "No checkpoint entry found for " << checkpoint_key + << ". Variable will be uninitialized."; + return Status(); + } + std::string variables_path_prefix = io::JoinPath(directory, kSavedModelVariablesDirectory, kSavedModelVariablesFilename); @@ -325,6 +331,31 @@ std::vector TFSavedModelAPI::ListFunctions() { return result; } +Status TFSavedModelAPI::GetVariable(const std::string& variable_path, + Variable** variable) { + int node_id; + const SavedObject* object = internal::FindNodeAtPath( + variable_path, bundle_.saved_object_graph(), &node_id); + if (object == nullptr) { + return errors::NotFound("No saved object found at path ", variable_path); + } + + if (object->kind_case() == SavedObject::kVariable) { + auto iter = revived_objects_.find(node_id); + if (iter == revived_objects_.end()) { + return errors::Internal("Variable ", variable_path, + " was not properly revived."); + } + *variable = static_cast(iter->second.get()); + return Status(); + } + + *variable = nullptr; + return errors::InvalidArgument( + variable_path, " is not a path to a Variable (kind=", object->kind_case(), + ")"); +} + TFSavedModelAPI::TFSavedModelAPI( const std::string& directory, SavedModelV2Bundle bundle, std::unordered_map> diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h index fd07c09474b..d108b4071e9 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/cc/saved_model/bundle_v2.h" @@ -68,6 +69,8 @@ class TFSavedModelAPI : public SavedModelAPI { ~TFSavedModelAPI() override = default; + Status GetVariable(const std::string& variable_path, Variable** variable); + private: TFSavedModelAPI( const std::string& directory, SavedModelV2Bundle bundle, diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index c0d121a4aee..5d6d23b403f 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -245,14 +245,17 @@ tf_cc_test( "saved_model_api_test.cc", ], data = [ + "//tensorflow/c/experimental/saved_model/internal/testdata:saved_models", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], deps = [ + ":saved_model_api_type", "//tensorflow/c:tf_status", "//tensorflow/c:tf_tensor", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/eager:c_api_test_util", + "//tensorflow/c/experimental/saved_model/core:tf_saved_model_api", "//tensorflow/c/experimental/saved_model/public:concrete_function", "//tensorflow/c/experimental/saved_model/public:saved_model_api", "//tensorflow/core:lib", diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index 86754b32c0c..a55f232795b 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" @@ -194,6 +196,49 @@ TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) { TFE_DeleteContext(ctx); } +TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = tensorflow::io::JoinPath( + tensorflow::testing::TensorFlowSrcRoot(), + "c/experimental/saved_model/internal/testdata/UninitializedVariable"); + + TF_SavedModel* saved_model = + TF_LoadSavedModel(model_dir.c_str(), ctx, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + tensorflow::TFSavedModelAPI* model_api = + tensorflow::down_cast( + tensorflow::unwrap(saved_model)); + tensorflow::Variable* uninitialized_variable; + ASSERT_EQ(tensorflow::Status::OK(), + model_api->GetVariable("uninitialized_variable", + &uninitialized_variable)); + ASSERT_EQ(tensorflow::DT_FLOAT, uninitialized_variable->dtype()); + + ASSERT_EQ(tensorflow::Status::OK(), + model_api->GetVariable("sub_module.uninitialized_variable", + &uninitialized_variable)); + ASSERT_EQ(tensorflow::DT_INT64, uninitialized_variable->dtype()); + + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest, ::testing::Bool()); diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD new file mode 100644 index 00000000000..f7b4b1de677 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD @@ -0,0 +1,36 @@ +load("//tensorflow:tensorflow.bzl", "py_strict_binary") + +package( + licenses = ["notice"], # Apache 2.0 +) + +# Run this binary manually, with an argument pointing to the testdata/ +# directory, to generate the test files used by the filegroup rule below. +py_strict_binary( + name = "gen_saved_models", + srcs = ["gen_saved_models.py"], + python_version = "PY3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:variables", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/module", + "//tensorflow/python/saved_model", + "//tensorflow/python/saved_model:save_options", + ], +) + +# Files generated by the binary above. +filegroup( + name = "saved_models", + srcs = glob([ + "UninitializedVariable/**", + ]), + visibility = [ + "//tensorflow/c/experimental/saved_model/internal:__pkg__", + ], +) diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb new file mode 100644 index 00000000000..81ce8fe662b Binary files /dev/null and b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb differ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.data-00000-of-00001 b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..b68ed0f5a6e Binary files /dev/null and b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.index b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.index new file mode 100644 index 00000000000..ed07d0514c7 Binary files /dev/null and b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/variables/variables.index differ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py b/tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py new file mode 100644 index 00000000000..f2a8bd5a9a4 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py @@ -0,0 +1,84 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Lint as: python3 +"""Creates saved models used for testing. + +This executable should be run with an argument pointing to the testdata/ folder +in this directory. It will re-generate the saved models that are used for +testing. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import google_type_annotations +from __future__ import print_function + +import os + +from tensorflow.python.compat import v2_compat + +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_spec +from tensorflow.python.module import module +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import app +from tensorflow.python.saved_model import saved_model + + +def _gen_uninitialized_variable(base_dir): + """Generates a saved model with an uninitialized variable.""" + + class SubModule(module.Module): + """A module with an UninitializedVariable.""" + + def __init__(self): + self.uninitialized_variable = resource_variable_ops.UninitializedVariable( + name="uninitialized_variable", dtype=dtypes.int64) + + class Module(module.Module): + """A module with an UninitializedVariable.""" + + def __init__(self): + super(Module, self).__init__() + self.sub_module = SubModule() + self.initialized_variable = variables.Variable( + 1.0, name="initialized_variable") + # An UninitializedVariable with the same name as the variable in the + # SubModule, but with a different type. + self.uninitialized_variable = resource_variable_ops.UninitializedVariable( + name="uninitialized_variable", dtype=dtypes.float32) + + @def_function.function( + input_signature=[tensor_spec.TensorSpec((), dtypes.float32)]) + def compute(self, value): + return self.initialized_variable + value + + to_save = Module() + saved_model.save( + to_save, export_dir=os.path.join(base_dir, "UninitializedVariable")) + + +def main(args): + if len(args) != 2: + raise app.UsageError("Expected one argument (base_dir).") + _, base_dir = args + _gen_uninitialized_variable(base_dir) + + +if __name__ == "__main__": + v2_compat.enable_v2_behavior() + app.run(main) diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index d74d190f37e..f3d78881429 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -47,6 +47,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables +from tensorflow.python.saved_model import load from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import save @@ -611,6 +612,31 @@ class SaveTest(test.TestCase, parameterized.TestCase): experimental_variable_policy=save_options.VariablePolicy .EXPAND_DISTRIBUTED_VARIABLES)) + def test_save_uninitialized_variable(self): + root = tracking.AutoTrackable() + root.uninitialized_variable = resource_variable_ops.UninitializedVariable( + name="uninitialized_variable", dtype=dtypes.float32) + root.initialized_variable = variables.Variable( + 1.0, name="initialized_variable") + + # TODO(b/149594077): Python loading does not work now partly because it + # shouldn't, as the public API and semantics of uninitialized variables + # are not properly defined, and officially supporting loading would end up + # defining semantics "by usage." We should only allow loading once the API + # is made official. + export_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(root, export_dir) + with self.assertRaisesRegex(FileNotFoundError, + "Key uninitialized_variable"): + load.load(export_dir) + with ops.Graph().as_default(), session_lib.Session() as session: + # The final ValueError here (with "no variables to save") is confusing, + # but errors upstream give the user the correct information (a + # NotFoundError stating that the uninitalized_variable was not found in + # the checkpoint). + with self.assertRaises(ValueError): + loader.load(session, [tag_constants.SERVING], export_dir) + class VariablePolicyEnumTest(test.TestCase): @@ -820,8 +846,7 @@ class AssetTests(test.TestCase): key_index=lookup_ops.TextFileIndex.WHOLE_LINE, value_dtype=dtypes.int64, value_index=lookup_ops.TextFileIndex.LINE_NUMBER) - table = lookup_ops.HashTable( - initializer, default_value=-1) + table = lookup_ops.HashTable(initializer, default_value=-1) root.table_user = def_function.function( table.lookup, input_signature=[tensor_spec.TensorSpec(None, dtypes.string)]) diff --git a/tensorflow/python/training/saving/functional_saver.py b/tensorflow/python/training/saving/functional_saver.py index c973c43009c..1fa0da0f33c 100644 --- a/tensorflow/python/training/saving/functional_saver.py +++ b/tensorflow/python/training/saving/functional_saver.py @@ -72,9 +72,14 @@ class _SingleDeviceSaver(object): tensor_slices = [] for saveable in self._saveable_objects: for spec in saveable.specs: - tensor_names.append(spec.name) - tensors.append(spec.tensor) - tensor_slices.append(spec.slice_spec) + tensor = spec.tensor + # A tensor value of `None` indicates that this SaveableObject gets + # recorded in the object graph, but that no value is saved in the + # checkpoint. + if tensor is not None: + tensor_names.append(spec.name) + tensors.append(tensor) + tensor_slices.append(spec.slice_spec) save_device = options.experimental_io_device or "cpu:0" with ops.device(save_device): return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors) diff --git a/tensorflow/python/training/saving/saveable_object.py b/tensorflow/python/training/saving/saveable_object.py index 54f1d1fb237..4c67fc4feb1 100644 --- a/tensorflow/python/training/saving/saveable_object.py +++ b/tensorflow/python/training/saving/saveable_object.py @@ -26,6 +26,7 @@ class SaveSpec(object): Args: tensor: the tensor to save or callable that produces a tensor to save. + If the value is `None`, the `SaveSpec` is ignored. slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`. name: the name to save the tensor under. dtype: The data type of the Tensor. Required if `tensor` is callable. diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py index d4af3fb7956..2b235e70117 100644 --- a/tensorflow/python/training/saving/saveable_object_util.py +++ b/tensorflow/python/training/saving/saveable_object_util.py @@ -99,11 +99,16 @@ class ResourceVariableSaveable(saveable_object.SaveableObject): def _read_variable_closure(v): def f(): with ops.device(v.device): + if context.executing_eagerly() and not v.is_initialized(): + # A SaveSpec tensor value of `None` indicates that the variable is + # uninitialized. + return None x = v.read_value() # To allow variables placed on non-CPU devices to be checkpointed, # we copy them to CPU on the same machine first. with ops.device("/device:CPU:0"): return array_ops.identity(x) + return f self.handle_op = var.handle @@ -177,8 +182,8 @@ def saveable_objects_for_op(op, name): yield ReferenceVariableSaveable( variable, variable._save_slice_info.spec, name) else: - yield ResourceVariableSaveable( - variable, variable._save_slice_info.spec, name) + yield ResourceVariableSaveable(variable, variable._save_slice_info.spec, + name) # pylint: enable=protected-access elif isinstance(op, trackable.Trackable) and not isinstance( op, variables.Variable): @@ -196,12 +201,10 @@ def saveable_objects_for_op(op, name): else: # A variable or tensor. if isinstance(op, resource_variable_ops.BaseResourceVariable): - # pylint: disable=protected-access - if op._in_graph_mode: - variable = op._graph_element + if op._in_graph_mode: # pylint: disable=protected-access + variable = op._graph_element # pylint: disable=protected-access else: variable = op - # pylint: enable=protected-access yield ResourceVariableSaveable(variable, "", name) else: if context.executing_eagerly(): @@ -217,8 +220,7 @@ def saveable_objects_for_op(op, name): "AutoReloadVariable"]: yield ReferenceVariableSaveable(variable, "", name) else: - yield ResourceVariableSaveable( - variable, "", name) + yield ResourceVariableSaveable(variable, "", name) def op_list_to_dict(op_list, convert_variable_to_tensor=True): diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 132ba3b6caa..a695db365aa 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1897,6 +1897,10 @@ register_extension_info( label_regex_for_dep = "{extension_name}", ) +# Placeholder to use until bazel supports py_strict_binary. +def py_strict_binary(name, **kwargs): + native.py_binary(name = name, **kwargs) + # Placeholder to use until bazel supports py_strict_library. def py_strict_library(name, **kwargs): native.py_library(name = name, **kwargs)