From e143ada4a6c1a42ea90873dc4582160d1a50f69f Mon Sep 17 00:00:00 2001 From: Brian Zhao <bmzhao@google.com> Date: Tue, 22 Sep 2020 21:45:45 -0700 Subject: [PATCH] Refactoring the guts of SavedModel loading to experimentally support loading resources that do not override _gather_saveables_for_checkpoint. This allows us to support simple resources like StaticHashTable. Note that additional logic related to checkpoints will need to be added to support https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c. PiperOrigin-RevId: 333221978 Change-Id: Ib724b6cee9a57bf1b3ee98d2a8eaf9f394a8d64d --- .../c/experimental/saved_model/core/BUILD | 9 +- .../saved_model/core/revived_types/BUILD | 107 +++++ .../partially_revived_objects.cc | 388 ++++++++++++++++++ .../revived_types/partially_revived_objects.h | 54 +++ .../core/revived_types/restored_resource.cc | 76 ++++ .../core/revived_types/restored_resource.h | 87 ++++ .../restored_resource_revival_state.h | 38 ++ .../core/revived_types/revived_objects.h | 51 +++ .../tf_concrete_function_revival_state.h | 61 +++ .../tf_signature_def_function_revival_state.h | 55 +++ .../saved_model/core/saved_model_utils.cc | 261 +++++++++++- .../saved_model/core/saved_model_utils.h | 17 +- .../saved_model/core/tf_saved_model_api.cc | 234 +++-------- .../saved_model/core/tf_saved_model_api.h | 14 +- .../internal/saved_model_api_test.cc | 137 +++++++ tensorflow/cc/saved_model/BUILD | 10 + .../assets/static_hashtable_asset.txt | 4 + .../StaticHashTableModule/saved_model.pb | Bin 0 -> 14204 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 88 bytes .../variables/variables.index | Bin 0 -> 144 bytes .../testdata/generate_saved_models.py | 22 + .../testdata/static_hashtable_asset.txt | 4 + 22 files changed, 1418 insertions(+), 211 deletions(-) create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h create mode 100644 tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt create mode 100644 tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb create mode 100644 tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001 create mode 100644 tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index create mode 100644 tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index d6c3613d194..bc532506a46 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -66,8 +66,13 @@ cc_library( "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/experimental/saved_model/core/revived_types:asset", "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects", + "//tensorflow/c/experimental/saved_model/core/revived_types:restored_resource_revival_state", "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function_revival_state", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_signature_def_function_revival_state", "//tensorflow/c/experimental/saved_model/core/revived_types:variable", + "//tensorflow/cc/saved_model:loader_util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -147,12 +152,14 @@ cc_library( "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/experimental/saved_model/core/ops:restore_ops", "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:flat_tensor_function", + "//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects", + "//tensorflow/c/experimental/saved_model/core/revived_types:revived_objects", "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible", "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", "//tensorflow/c/experimental/saved_model/core/revived_types:variable", "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/cc/saved_model:constants", - "//tensorflow/cc/saved_model:loader_util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 1205f12b948..eaccc2fac69 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -69,6 +69,84 @@ cc_library( ], ) +cc_library( + name = "partially_revived_objects", + srcs = [ + "partially_revived_objects.cc", + ], + hdrs = [ + "partially_revived_objects.h", + ], + deps = [ + ":asset", + ":constant", + ":restored_resource", + ":restored_resource_revival_state", + ":revived_objects", + ":tf_concrete_function", + ":tf_concrete_function_revival_state", + ":tf_signature_def_function", + ":tf_signature_def_function_revival_state", + ":variable", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "restored_resource", + srcs = [ + "restored_resource.cc", + ], + hdrs = [ + "restored_resource.h", + ], + deps = [ + ":tensorhandle_convertible", + ":tf_concrete_function", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "restored_resource_revival_state", + hdrs = [ + "restored_resource_revival_state.h", + ], + deps = [ + ":tf_concrete_function_revival_state", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + ], +) + +cc_library( + name = "revived_objects", + hdrs = [ + "revived_objects.h", + ], + deps = [ + ":asset", + ":constant", + ":restored_resource", + ":tf_concrete_function", + ":tf_signature_def_function", + ":variable", + "//tensorflow/core:lib", + ], +) + cc_library( name = "variable", srcs = [ @@ -123,6 +201,21 @@ cc_library( ], ) +cc_library( + name = "tf_concrete_function_revival_state", + hdrs = [ + "tf_concrete_function_revival_state.h", + ], + deps = [ + ":tf_concrete_function", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "tf_signature_def_function", srcs = [ @@ -145,3 +238,17 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "tf_signature_def_function_revival_state", + hdrs = [ + "tf_signature_def_function_revival_state.h", + ], + deps = [ + ":tf_signature_def_function", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc new file mode 100644 index 00000000000..5cc06e6c54f --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc @@ -0,0 +1,388 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" + +#include <memory> +#include <utility> + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +namespace { + +Status AssertAllCreateResourceFunctionsHaveNoCaptures( + const PartiallyRevivedObjects& objects) { + for (const auto& id_and_resource : objects.restored_resources) { + int node_id = id_and_resource.first; + const RestoredResourceRevivalState& resource = id_and_resource.second; + const TFConcreteFunctionRevivalState* create_resource_fn = + resource.create_resource; + if (create_resource_fn == nullptr) { + return errors::FailedPrecondition( + "Resource at node ", node_id, + " did not have a create_resource() function"); + } + const SavedConcreteFunction* saved_create_resource_fn = + create_resource_fn->saved_concrete_func; + if (!saved_create_resource_fn->bound_inputs().empty()) { + // TODO(b/124045874): Support loading resource functions via a top sort + return errors::Unimplemented( + "Create Resource functions with captures are currently unsupported."); + } + } + return Status(); +} + +// Retrieves the TensorHandle associated with `node_id` from `obj_graph`, and +// set `*handle` to point to it. +Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + ImmediateExecutionTensorHandle** handle) { + const SavedObject& node = obj_graph.nodes(node_id); + SavedObject::KindCase kind = node.kind_case(); + switch (kind) { + case SavedObject::kVariable: { + const auto& variables_iter = objects.variables.find(node_id); + if (variables_iter == objects.variables.end()) { + return errors::FailedPrecondition( + "Tried to convert node id ", node_id, + " of type variable to tensor but the variable wasn't initialized"); + } + *handle = variables_iter->second->handle(); + return Status(); + } + case SavedObject::kConstant: { + const auto& constants_iter = objects.constants.find(node_id); + if (constants_iter == objects.constants.end()) { + return errors::FailedPrecondition("Tried to convert node id ", node_id, + " of type constant to tensor but the " + "constant wasn't initialized"); + } + *handle = constants_iter->second->handle(); + return Status(); + } + case SavedObject::kAsset: { + const auto& assets_iter = objects.assets.find(node_id); + if (assets_iter == objects.assets.end()) { + return errors::FailedPrecondition( + "Tried to convert node id ", node_id, + " of type asset to tensor but the asset wasn't initialized"); + } + *handle = assets_iter->second->handle(); + return Status(); + } + case SavedObject::kResource: { + const auto& resource_iter = objects.restored_resources.find(node_id); + if (resource_iter == objects.restored_resources.end()) { + return errors::FailedPrecondition( + "Tried to convert node id ", node_id, + " of type Resource to tensor but the Resource wasn't initialized"); + } + const RestoredResourceRevivalState& resource = resource_iter->second; + if (resource.resource_handle == nullptr) { + return errors::FailedPrecondition( + "Resource with node id ", node_id, + " should have its resource_handle created, but was nullptr."); + } + *handle = resource.resource_handle.get(); + return Status(); + } + default: { + return errors::FailedPrecondition( + "Only objects of type variable, constant, asset, and resources have " + "capturable tensorhandles. Encountered object of kind ", + node.kind_case(), " at node id: ", node_id); + } + } +} + +// This function finds the necessary captures, then forwards to the builder +// method +Status CreateConcreteFunction(ImmediateExecutionContext* ctx, + const TFConcreteFunctionRevivalState& builder, + const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + std::unique_ptr<TFConcreteFunction>* out) { + const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs(); + std::vector<ImmediateExecutionTensorHandle*> captures; + captures.reserve(capture_node_ids.size()); + for (int capture_node_id : capture_node_ids) { + ImmediateExecutionTensorHandle* capture_handle; + TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects, + &capture_handle)); + captures.push_back(capture_handle); + } + // TODO(bmzhao): Create Metadata here + return TFConcreteFunction::Create(/*function_def=*/builder.fdef, + /*captures=*/std::move(captures), + /*metadata=*/{}, + /*ctx=*/ctx, + /*out=*/out); +} + +Status CreateSignatureDefFunction( + ImmediateExecutionContext* ctx, + const TFSignatureDefFunctionRevivalState& builder, + const SavedObjectGraph& obj_graph, const PartiallyRevivedObjects& objects, + std::unique_ptr<TFSignatureDefFunction>* out) { + const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs(); + std::vector<ImmediateExecutionTensorHandle*> captures; + captures.reserve(capture_node_ids.size()); + for (int capture_node_id : capture_node_ids) { + ImmediateExecutionTensorHandle* capture_handle; + TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects, + &capture_handle)); + captures.push_back(capture_handle); + } + // TODO(bmzhao): Create Metadata here + return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef, + /*captures=*/std::move(captures), + /*metadata=*/{}, + /*ctx=*/ctx, + /*out=*/out); +} + +Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + RevivedObjects* revived) { + for (const auto& id_and_resource : objects.restored_resources) { + const RestoredResourceRevivalState& resource = id_and_resource.second; + const TFConcreteFunctionRevivalState* create_resource_fn = + resource.create_resource; + + const SavedConcreteFunction* saved_create_resource_fn = + create_resource_fn->saved_concrete_func; + if (!saved_create_resource_fn->bound_inputs().empty()) { + // TODO(b/124045874): Load resource functions via a topological sort + return errors::Unimplemented( + "Create Resource functions with captures are currently unsupported."); + } + std::unique_ptr<TFConcreteFunction> out; + TF_RETURN_IF_ERROR(CreateConcreteFunction(ctx, *create_resource_fn, + obj_graph, objects, &out)); + revived->concrete_functions[create_resource_fn->node_id] = std::move(out); + } + return Status(); +} + +Status InitializeAllFunctions(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + RevivedObjects* revived) { + gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>* destination_func_map = + &revived->concrete_functions; + gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>* + destination_sig_map = &revived->signature_def_functions; + + for (const auto& id_and_func : objects.concrete_functions) { + int node_id = id_and_func.first; + const TFConcreteFunctionRevivalState& func = id_and_func.second; + + if (destination_func_map->find(node_id) != destination_func_map->end()) { + // The function has already been initialized in the destination_map, + // so we can skip this node. This can occur because we initialize + // CreateResource functions before calling this function. + continue; + } + + std::unique_ptr<TFConcreteFunction> out; + TF_RETURN_IF_ERROR( + CreateConcreteFunction(ctx, func, obj_graph, objects, &out)); + (*destination_func_map)[node_id] = std::move(out); + } + + for (const auto& id_and_func : objects.signature_def_functions) { + int node_id = id_and_func.first; + const TFSignatureDefFunctionRevivalState& func = id_and_func.second; + + if (destination_sig_map->find(node_id) != destination_sig_map->end()) { + continue; + } + + std::unique_ptr<TFSignatureDefFunction> out; + TF_RETURN_IF_ERROR( + CreateSignatureDefFunction(ctx, func, obj_graph, objects, &out)); + (*destination_sig_map)[node_id] = std::move(out); + } + + return Status(); +} + +Status CreateAllResourceHandles(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + PartiallyRevivedObjects* objects, + RevivedObjects* revived) { + for (auto& id_and_resource : objects->restored_resources) { + RestoredResourceRevivalState& resource = id_and_resource.second; + int create_resource_fn_node = resource.create_resource->node_id; + const gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>& + revived_functions = revived->concrete_functions; + + const auto& revived_functions_iter = + revived_functions.find(create_resource_fn_node); + if (revived_functions_iter == revived_functions.end()) { + return errors::FailedPrecondition( + "ConcreteFunction at node ", create_resource_fn_node, + " should have been initialized prior to being called."); + } + const TFConcreteFunction& create_resource_fn = + *revived_functions_iter->second; + ImmediateOpPtr function_op; + TF_RETURN_IF_ERROR(create_resource_fn.MakeCallOp({}, &function_op)); + TF_RETURN_IF_ERROR(function_op->SetDeviceName(resource.device.c_str())); + + AbstractTensorHandle* resource_handle = nullptr; + int num_retvals = 1; + TF_RETURN_IF_ERROR(function_op->Execute( + absl::MakeSpan(&resource_handle, num_retvals), &num_retvals)); + AbstractTensorHandlePtr owned_resource_handle(resource_handle); + if (!tensorflow::isa<ImmediateExecutionTensorHandle>( + owned_resource_handle.get())) { + return errors::Internal("Unexpected tensor handle kind."); + } + ImmediateTensorHandlePtr result( + reinterpret_cast<ImmediateExecutionTensorHandle*>( + owned_resource_handle.release())); + resource.resource_handle = std::move(result); + } + return Status(); +} + +// Finds a ConcreteFunction with node id `node` in `objects`, and sets *out to +// point to it. If node doesn't exist in `objects`, out is untouched, and an +// error status is returned. +Status FindConcreteFunction(int node, RevivedObjects* objects, + TFConcreteFunction** out) { + auto func_iter = objects->concrete_functions.find(node); + if (func_iter == objects->concrete_functions.end()) { + return errors::FailedPrecondition( + "Failed to find ConcreteFunction with node id ", node, + " in revived objects"); + } + *out = func_iter->second.get(); + return Status(); +} + +Status BuildResources(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + PartiallyRevivedObjects* objects, + RevivedObjects* revived) { + for (auto& id_and_resource : objects->restored_resources) { + int node_id = id_and_resource.first; + RestoredResourceRevivalState& resource_revival_state = + id_and_resource.second; + + TFConcreteFunction* create_resource = nullptr; + + // Check all the functions associated with the resource have already been + // initialized in `revived` + if (resource_revival_state.create_resource != nullptr) { + TF_RETURN_IF_ERROR( + FindConcreteFunction(resource_revival_state.create_resource->node_id, + revived, &create_resource)); + } + + TFConcreteFunction* initialize = nullptr; + if (resource_revival_state.initialize != nullptr) { + TF_RETURN_IF_ERROR(FindConcreteFunction( + resource_revival_state.initialize->node_id, revived, &initialize)); + } + + TFConcreteFunction* destroy_resource = nullptr; + if (resource_revival_state.destroy_resource != nullptr) { + TF_RETURN_IF_ERROR( + FindConcreteFunction(resource_revival_state.destroy_resource->node_id, + revived, &destroy_resource)); + } + + if (resource_revival_state.resource_handle == nullptr) { + return errors::FailedPrecondition("Resource at node id ", node_id, + " does not have a resource handle."); + } + + revived->restored_resources.emplace( + node_id, RestoredResource( + /*device=*/resource_revival_state.device, + /*create_resource=*/create_resource, + /*initialize=*/initialize, + /*destroy_resource=*/destroy_resource, + /*resource_handle=*/ + std::move(resource_revival_state.resource_handle))); + } + return Status(); +} + +} // namespace + +Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + RevivedObjects* revived) { + // Step 1: We would like to initialize all functions; this requires setting up + // their captured tensorhandles, which may come from variables, assets, + // constants, or resources. The first three are trivial; However, + // tensorhandles that correspond to resources must be created by invoking + // their "create_resource" function. + // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240 + // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233 + // For now, we assert that all create_resource functions must have no + // captures. This aligns with the current behavior in python. + // https://github.com/tensorflow/tensorflow/blob/50eac986bf7a0ad12594e080f083181f277e0b49/tensorflow/python/saved_model/load.py#L152-L155 + // TODO(bmzhao): We should do a topological sort instead. + + // 1a. Make sure all CreateResource functions have no captures. + TF_RETURN_IF_ERROR(AssertAllCreateResourceFunctionsHaveNoCaptures(*this)); + + // 1b. Initialize all CreateResource functions, storing them in `revived` + TF_RETURN_IF_ERROR( + InitializeCreateResourceFunctions(ctx, obj_graph, *this, revived)); + + // 1c. Invoke all "CreateResource" functions and store their ResourceHandles + // https://github.com/tensorflow/tensorflow/blob/3b6b41b68a95dc70c26dc816b29d359bfb88c116/tensorflow/python/training/tracking/tracking.py#L241-L247 + // in *this->resources. + // TODO(bmzhao): Maybe store them separately, not in *this? + TF_RETURN_IF_ERROR(CreateAllResourceHandles(ctx, obj_graph, this, revived)); + + // 2. Initialize all the rest of the functions + TF_RETURN_IF_ERROR(InitializeAllFunctions(ctx, obj_graph, *this, revived)); + + // 3a. Move over all non-function, non-resource objects + revived->variables = std::move(variables); + revived->assets = std::move(assets); + revived->constants = std::move(constants); + + // 3b. Move over resources. + TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived)); + + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h new file mode 100644 index 00000000000..cce52748a1d --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ + +#include <memory> + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// Container for objects during the revival step in SavedModel's loading. +// Notably, resources and functions can be in a state where they reference +// other resources/functions that have not been constructed yet. We collect +// *all* objects in a partially valid state here, then properly initialize +// resources and functions. +struct PartiallyRevivedObjects { + gtl::FlatMap<int, std::unique_ptr<Variable>> variables; + gtl::FlatMap<int, std::unique_ptr<Asset>> assets; + gtl::FlatMap<int, std::unique_ptr<Constant>> constants; + gtl::FlatMap<int, TFConcreteFunctionRevivalState> concrete_functions; + gtl::FlatMap<int, TFSignatureDefFunctionRevivalState> signature_def_functions; + gtl::FlatMap<int, RestoredResourceRevivalState> restored_resources; + + Status Build(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, RevivedObjects* revived); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc new file mode 100644 index 00000000000..47860ce8b39 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace { + +Status ExecuteNoArgDummyReturnFunction(TFConcreteFunction* func) { + ImmediateOpPtr function_op; + TF_RETURN_IF_ERROR(func->MakeCallOp({}, &function_op)); + + AbstractTensorHandle* dummy_output = nullptr; + int num_retvals = 1; + TF_RETURN_IF_ERROR(function_op->Execute( + absl::MakeSpan(&dummy_output, num_retvals), &num_retvals)); + AbstractTensorHandlePtr owned_dummy_output(dummy_output); + return Status(); +} + +} // namespace + +RestoredResource::RestoredResource(const std::string& device, + TFConcreteFunction* create_resource, + TFConcreteFunction* initialize, + TFConcreteFunction* destroy_resource, + ImmediateTensorHandlePtr resource_handle) + : TensorHandleConvertible(std::move(resource_handle)), + device_(device), + create_resource_(create_resource), + initialize_(initialize), + destroy_resource_(destroy_resource) {} + +Status RestoredResource::Initialize() const { + return ExecuteNoArgDummyReturnFunction(initialize_); +} + +RestoredResource::~RestoredResource() { + // Note(bmzhao): SavedModels saved before + // https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3 + // did not have their destroy_resource function saved, meaning they will + // leak resources. + if (destroy_resource_ != nullptr) { + Status status = ExecuteNoArgDummyReturnFunction(destroy_resource_); + if (!status.ok()) { + LOG(WARNING) + << "Failed executing destroy_resource function for RestoredResource: " + << status.error_message(); + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h new file mode 100644 index 00000000000..7adbd563a6b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ + +#include <memory> +#include <string> + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" + +namespace tensorflow { + +// RestoredResource represents a TF2 "Resource" object loaded from a savedmodel, +// analogous to the Python _RestoredResource object: +// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/saved_model/load.py#L481 +// TF2 resource objects typically extend TrackableResource: +// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/training/tracking/tracking.py#L285 +// and are expected to implement "_create_resource", "_initialize", and +// "_destroy_resource" functions: +// https://github.com/tensorflow/tensorflow/blob/139ba9c5284799beafdd1d7f895127cf00e7c48f/tensorflow/python/training/tracking/tracking.py#L262-L281 +class RestoredResource : TensorHandleConvertible { + public: + // Note(bmzhao): RestoredResource stores non-owning pointers to its associated + // functions because SavedModel internally owns all functions and objects in + // the RevivedObjects struct (which owns all functions). One alternative would + // be to have RevivedObjects store shared_ptr<TFConcreteFunction> instead, and + // change RestoredResource's constructor take shared_ptr<TFConcreteFunction>. + // To keep things simple, I've stuck to raw pointers for now. + // + // Params: + // device - The device string associated with the SavedResource + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_object_graph.proto#L182 + // Conceptually, this is the same device used in CapturableResource: + // https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L222-L225 + // Implementation-wise, it is device used when invoking the + // create_resource function to produce the resource_handle + // associated with the object: + // https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L246-L247 + // create_resource - Non owning pointer to the create_resource function + // associated with this object. Must be NON-NULL. + // initialize - Non owning pointer to the initialize function associated with + // this object. Must be NON-NULL. + // destroy_resource - Non owning pointer to the destroy_resource function + // associated with this object. Ideally this should be + // NON-NULL, but in order to support models saved prior to + // https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3 + // we allow null here. This will, however, leak resources. + RestoredResource(const std::string& device, + TFConcreteFunction* create_resource, + TFConcreteFunction* initialize, + TFConcreteFunction* destroy_resource, + ImmediateTensorHandlePtr resource_handle); + + Status Initialize() const; + + // RestoredResource is movable, but not copyable. + RestoredResource(RestoredResource&& other) = default; + RestoredResource& operator=(RestoredResource&& other) = default; + + ~RestoredResource() override; + + private: + std::string device_; + TFConcreteFunction* create_resource_; + TFConcreteFunction* initialize_; + TFConcreteFunction* destroy_resource_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h new file mode 100644 index 00000000000..48d00308cc1 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ + +#include <string> + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" + +namespace tensorflow { + +// All "Resources" should have these 3 saved functions: +// https://github.com/tensorflow/tensorflow/blob/86dc281333d7d277ddc1882f2bca4b17e7ec40e5/tensorflow/python/training/tracking/tracking.py#L277-L281 +struct RestoredResourceRevivalState { + std::string device; + TFConcreteFunctionRevivalState* create_resource = nullptr; + TFConcreteFunctionRevivalState* initialize = nullptr; + TFConcreteFunctionRevivalState* destroy_resource = nullptr; + ImmediateTensorHandlePtr resource_handle = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h b/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h new file mode 100644 index 00000000000..3b8fb0f6a8a --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ + +#include <memory> +#include <unordered_map> + +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { + +// RevivedObjects is mainly used as a container for all the "state" owned by +// SavedModel. It stores all non-"user object" nodes from a SavedModel +// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62) +// in a "fully constructed" state. It is effectively a strongly typed map, where +// each member is a map from the node id in the SavedObjectGraph's nodes +// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29) +// to the revived object of the corresponding type. +struct RevivedObjects { + gtl::FlatMap<int, std::unique_ptr<Variable>> variables; + gtl::FlatMap<int, std::unique_ptr<Asset>> assets; + gtl::FlatMap<int, std::unique_ptr<Constant>> constants; + gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>> concrete_functions; + gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>> + signature_def_functions; + gtl::FlatMap<int, RestoredResource> restored_resources; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h new file mode 100644 index 00000000000..3dd7a6eecc4 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ + +#include <memory> +#include <vector> + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// TFConcreteFunctionRevivalState wraps the state needed for building a +// TF_ConcreteFunction. This is mainly used in PartiallyRevivedObjects, which +// wraps partially constructed Function and Resource objects. +struct TFConcreteFunctionRevivalState { + // Index of the node in the SavedObjectGraph it was loaded from. + int node_id; + + // Pointer to the original functiondef. fdef_ is guaranteed to be + // non-null. + const FunctionDef* fdef; + + // TensorHandle captures for this funtion + std::vector<ImmediateExecutionTensorHandle*> captures; + + // SavedConcreteFunction contains much of the metadata of the expected "types" + // of the inputs and outputs of a function. + // Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null. + const SavedConcreteFunction* saved_concrete_func; + + // This field is only present on TF2 ConcreteFunctions, and is useful for + // determining the original argument *names* of the function, (since the + // "canonicalized_input_signature" may append extra uniquifying integers). + // However, SavedBareConcreteFunctions do not have a FunctionSpec. + // Note(bmzhao): if function_spec_.has_value(), *function_spec_ is guaranteed + // to be non-null. + absl::optional<const FunctionSpec*> function_spec; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h new file mode 100644 index 00000000000..ac1b20e474b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ + +#include <memory> +#include <vector> + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// FunctionBuilder wraps the state needed for building a SignatureDefFunction. +// This is mainly used in PartiallyRevivedObjects, which wraps partially +// constructed Function and Resource objects. +struct TFSignatureDefFunctionRevivalState { + // Index of the node in the SavedObjectGraph it was loaded from. + int node_id = 0; + + // Pointer to the original functiondef. fdef_ is guaranteed to be + // non-null. + const FunctionDef* fdef = nullptr; + + // SavedConcreteFunction contains much of the metadata of the expected "types" + // of the inputs and outputs of a function. + // Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null. + const SavedConcreteFunction* saved_concrete_func = nullptr; + + // The name of the SignatureDef key. + std::string signature_key; + + // TensorHandle captures for this funtion + std::vector<ImmediateExecutionTensorHandle*> captures; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 9ae7778e0cf..ef75f1a382b 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -17,15 +17,22 @@ limitations under the License. #include <algorithm> #include <memory> +#include <unordered_map> +#include <unordered_set> #include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/c/tf_tensor_internal.h" +#include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" @@ -41,6 +48,83 @@ namespace { using StructuredValueDictEntry = protobuf::MapPair<std::string, StructuredValue>; +// Maps from a Nodedef's name to its corresponding AttrValues, for a given +// Graphdef +using NodeAttrMap = + gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher>; + +// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary +using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, + StringPieceHasher>; + +// Looks up a SavedConstant's associated tensorproto from the NodeAttrMap and +// returns a tensorflow::Constant. +Status ConstantFromSavedConstant( + ImmediateExecutionContext* ctx, + const tensorflow::SavedConstant& saved_constant, + const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) { + const std::string& const_op_name = saved_constant.operation(); + const auto& node_name_and_attrs = node_attr_map.find(const_op_name); + if (node_name_and_attrs == node_attr_map.end()) { + return errors::FailedPrecondition( + "Unable to find Const operation with name'", const_op_name, + "' in SavedModel graphdef"); + } + const AttrValueMap* attrs = node_name_and_attrs->second; + const auto& attr_name_and_value = attrs->find("value"); + if (attr_name_and_value == attrs->end()) { + return errors::FailedPrecondition("Unable to find Const operation '", + const_op_name, "'s value attribute"); + } + const TensorProto& tensor_proto = attr_name_and_value->second.tensor(); + return internal::TensorProtoToConstant(ctx, tensor_proto, output); +} + +// Finds the "signatures" object in the object graph, and fills a mapping of +// each signature's name to the corresponding function's node in the object +// graph. +Status GetSignaturesMap(const SavedObjectGraph& saved_objects, + gtl::FlatMap<std::string, int>* signatures_map) { + if (saved_objects.nodes().empty()) { + return errors::FailedPrecondition("Saved Object Graph was empty."); + } + const SavedObject& root = saved_objects.nodes(0); + const SavedObject* signatures = nullptr; + for (const auto& child : root.children()) { + if (child.local_name() == "signatures") { + if (child.node_id() >= saved_objects.nodes().size()) { + return errors::FailedPrecondition( + "Signature object had child node id ", child.node_id(), + " which exceeds the size of the set of nodes"); + } + signatures = &saved_objects.nodes(child.node_id()); + } + } + + // Some basic sanity checks that this object is actually our "signatures" map + if (signatures == nullptr) { + // This is where the "signatures" attribute is always set: + // https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109 + return errors::FailedPrecondition( + "SavedObjectGraph's root object must have a child 'signatures' object"); + } + if (signatures->kind_case() != SavedObject::kUserObject) { + return errors::FailedPrecondition( + "Signatures must be a SavedObject of type UserObject."); + } + if (signatures->user_object().identifier() != "signature_map") { + // This is where the string comes from: + // https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220 + return errors::FailedPrecondition( + "Signatures SavedObject must have identifier 'signature_map'."); + } + + for (const auto& child : signatures->children()) { + (*signatures_map)[child.local_name()] = child.node_id(); + } + return Status(); +} + // Perform some basic sanity checks on SavedConcreteFunction's input and // output signatures with respect to the corresponding FunctionDef's input // and output args. @@ -99,6 +183,21 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef( return Status(); } +Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) { + // We only allow loading functions that have an annotated input signature, + // which means there is 1:1 correspondence between tf.function + // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is + // the same restriction that MLIR has: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707 + if (saved_function.concrete_functions_size() != 1) { + return errors::FailedPrecondition( + "Only tf.functions annotated with an input signature are supported " + "by SavedModelAPI. This means that there should only be a single " + "ConcreteFunction per tf.function"); + } + return Status(); +} + } // namespace Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset, @@ -255,21 +354,18 @@ absl::optional<int> FindNodeAtPath(StringPiece path, return node_id; } -std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher> -NodeToAttrMap(const tensorflow::GraphDef& graphdef) { - std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher> - result; +gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap( + const tensorflow::GraphDef& graphdef) { + gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> result; for (const tensorflow::NodeDef& node : graphdef.node()) { result[node.name()] = &node.attr(); } return result; } -std::unordered_map<StringPiece, const tensorflow::FunctionDef*, - StringPieceHasher> +gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher> FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) { - std::unordered_map<StringPiece, const tensorflow::FunctionDef*, - StringPieceHasher> + gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher> result; for (const FunctionDef& function_def : library.function()) { result[function_def.signature().name()] = &function_def; @@ -277,5 +373,154 @@ FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) { return result; } +Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph, + ImmediateExecutionContext* context, + const std::string& directory, + PartiallyRevivedObjects* objects) { + // This is needed to restore "Constant" nodes by looking up their + // "Value" attribute. + NodeAttrMap node_attr_map = NodeToAttrMap(metagraph.graph_def()); + + // These are needed for creating "Assets", by looking up their filenames. + std::vector<AssetFileDef> assets; + TF_RETURN_IF_ERROR(GetAssetFileDefs(metagraph, &assets)); + + // Signatures are needed for determining whether a function is a + // SignatureDefFunction or not. + gtl::FlatMap<std::string, int> signatures_map; + TF_RETURN_IF_ERROR( + GetSignaturesMap(metagraph.object_graph_def(), &signatures_map)); + + gtl::FlatMap<int, std::string> reversed_signatures_map; + reversed_signatures_map.reserve(signatures_map.size()); + for (const auto& signature_key_and_node : signatures_map) { + reversed_signatures_map.emplace(signature_key_and_node.second, + signature_key_and_node.first); + } + + // FunctionDefs are needed to help construct + // TFConcreteFunction/SignatureDefFunctions + const FunctionDefMap function_def_map = + internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library()); + + // Iterate through all the saved objects, restoring objects (if we can) as we + // go. For objects that dependencies on other objects (resources/functions), + // we partially initialize "builders" that correspond to their currently known + // state, and gradually fill them out in subsequent passes. + for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) { + const SavedObject& node = metagraph.object_graph_def().nodes(i); + if (node.kind_case() == SavedObject::kVariable) { + std::unique_ptr<Variable> variable; + TF_RETURN_IF_ERROR( + LoadSavedVariable(context, node.variable(), &variable)); + objects->variables[i] = std::move(variable); + } else if (node.kind_case() == SavedObject::kConstant) { + std::unique_ptr<Constant> constant; + TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(), + node_attr_map, &constant)); + objects->constants[i] = std::move(constant); + } else if (node.kind_case() == SavedObject::kAsset) { + std::unique_ptr<Asset> asset; + TF_RETURN_IF_ERROR( + LoadSavedAsset(context, node.asset(), directory, assets, &asset)); + objects->assets[i] = std::move(asset); + } else if (node.kind_case() == SavedObject::kResource) { + RestoredResourceRevivalState resource_revival_state; + // We'll set the resource's functions in a subsequent pass, once we get + // all functions in a partially revived state. + resource_revival_state.device = node.resource().device(); + objects->restored_resources[i] = std::move(resource_revival_state); + } else if (node.kind_case() == SavedObject::kFunction) { + // Get the SavedFunction node and validate it has a single concrete func. + const SavedFunction& saved_function = node.function(); + TF_RETURN_IF_ERROR(ValidateSingleConcreteFunction(saved_function)); + + // Retrieve related function information. + const std::string& function_name = saved_function.concrete_functions(0); + const FunctionDef* function_def = function_def_map.at(function_name); + const SavedConcreteFunction& saved_concrete_func = + metagraph.object_graph_def().concrete_functions().at(function_name); + const FunctionSpec& function_spec = saved_function.function_spec(); + + // Construct either a SignatureDefFunctionBuilder or a + // ConcreteFunctionBuilder, depending on whether this node was a child + // of the "signatures" attribute from root object. + auto reverse_signature_iter = reversed_signatures_map.find(i); + if (reverse_signature_iter != reversed_signatures_map.end()) { + TFSignatureDefFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + func_revival_state.signature_key = reverse_signature_iter->second; + objects->signature_def_functions[i] = std::move(func_revival_state); + } else { + TFConcreteFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + func_revival_state.function_spec = &function_spec; + objects->concrete_functions[i] = std::move(func_revival_state); + } + } else if (node.kind_case() == SavedObject::kBareConcreteFunction) { + const SavedBareConcreteFunction& bare_cf = node.bare_concrete_function(); + + // Retrieve related function information. + const std::string& function_name = bare_cf.concrete_function_name(); + const FunctionDef* function_def = function_def_map.at(function_name); + const SavedConcreteFunction& saved_concrete_func = + metagraph.object_graph_def().concrete_functions().at(function_name); + + // Check whether this is a SignatureDefFunction, or not. + auto reverse_signature_iter = reversed_signatures_map.find(i); + if (reverse_signature_iter != reversed_signatures_map.end()) { + TFSignatureDefFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + func_revival_state.signature_key = reverse_signature_iter->second; + objects->signature_def_functions[i] = std::move(func_revival_state); + } else { + TFConcreteFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + objects->concrete_functions[i] = std::move(func_revival_state); + } + } + } + + // Now that we've partially restored all functions, we can have resources + // point to them + for (auto& node_and_resource_revival_state : objects->restored_resources) { + int node_id = node_and_resource_revival_state.first; + const SavedObjectGraph& obj_graph = metagraph.object_graph_def(); + const SavedObject& node = obj_graph.nodes(node_id); + RestoredResourceRevivalState& resource = + node_and_resource_revival_state.second; + for (const TrackableObjectGraph::TrackableObject::ObjectReference& child : + node.children()) { + int child_node_id = child.node_id(); + // Note(bmzhao): The expected functions saved by a resource object are: + // "_create_resource", "_initialize", and "_destroy_resource". + // https://github.com/tensorflow/tensorflow/blob/ad66f588c1666ade8051feb42811fa27b285271c/tensorflow/python/training/tracking/tracking.py#L277-L281 + if (child.local_name() == "_create_resource" && + obj_graph.nodes(child.node_id()).kind_case() == + SavedObject::kFunction) { + resource.create_resource = &objects->concrete_functions[child_node_id]; + } else if (child.local_name() == "_initialize" && + obj_graph.nodes(child.node_id()).kind_case() == + SavedObject::kFunction) { + resource.initialize = &objects->concrete_functions[child_node_id]; + } else if (child.local_name() == "_destroy_resource" && + obj_graph.nodes(child.node_id()).kind_case() == + SavedObject::kFunction) { + resource.destroy_resource = &objects->concrete_functions[child_node_id]; + } + } + } + + return Status(); +} + } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index 073d7072e9f..db45e28087f 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -27,10 +27,12 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" @@ -84,15 +86,22 @@ absl::optional<int> FindNodeAtPath(StringPiece path, // Maps each node in `graphdef` to its corresponding Attribute Map. // Callers must ensure that `graphdef` outlives the returned map. -std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher> -NodeToAttrMap(const tensorflow::GraphDef& graphdef); +gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap( + const tensorflow::GraphDef& graphdef); // Maps the name of each FunctionDef in `library` to its corresponding // FunctionDef. Callers must ensure `library` outlives the returned map. -std::unordered_map<StringPiece, const tensorflow::FunctionDef*, - StringPieceHasher> +gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher> FunctionNameToFunctionDefMap(const FunctionDefLibrary& library); +// Walks through the SavedObjectGraph in metagraph, and restores all nodes +// (except "UserDefinedObjects") with their corresponding type in +// "PartiallyRevivedObjects". +Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph, + ImmediateExecutionContext* context, + const std::string& directory, + PartiallyRevivedObjects* objects); + } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index 0662482e538..6386d0dbb79 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -17,7 +17,6 @@ limitations under the License. #include <memory> #include <string> -#include <unordered_map> #include <unordered_set> #include <vector> @@ -30,6 +29,9 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" @@ -37,7 +39,6 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/constants.h" -#include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" @@ -46,6 +47,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/errors.h" @@ -62,142 +64,15 @@ limitations under the License. namespace tensorflow { // Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary -using FunctionDefMap = - std::unordered_map<StringPiece, const tensorflow::FunctionDef*, - StringPieceHasher>; - -// Maps from a Nodedef's name to its corresponding AttrValues, for a given -// Graphdef -using NodeAttrMap = - std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>; - -// Maps from Node ID to an "Revived Object" implementing -// "TensorHandleConvertible" -using RevivedObjectMap = - std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>; +using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, + StringPieceHasher>; // Maps from a functiondef's name to the corresponding "TFConcreteFunction" -using ConcreteFunctionMap = - std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>; +using FlatTensorFunctionMap = + gtl::FlatMap<std::string, std::unique_ptr<FlatTensorFunction>>; namespace { -Status ConstantFromSavedConstant( - ImmediateExecutionContext* ctx, - const tensorflow::SavedConstant& saved_constant, - const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) { - const std::string& const_op_name = saved_constant.operation(); - const auto& node_name_and_attrs = node_attr_map.find(const_op_name); - if (node_name_and_attrs == node_attr_map.end()) { - return errors::FailedPrecondition( - "Unable to find Const operation with name'", const_op_name, - "' in SavedModel graphdef"); - } - const AttrValueMap* attrs = node_name_and_attrs->second; - const auto& attr_name_and_value = attrs->find("value"); - if (attr_name_and_value == attrs->end()) { - return errors::FailedPrecondition("Unable to find Const operation '", - const_op_name, "'s value attribute"); - } - const TensorProto& tensor_proto = attr_name_and_value->second.tensor(); - return internal::TensorProtoToConstant(ctx, tensor_proto, output); -} - -// Restores all non-function objects in the SavedModel's object graph. -// This function walks through the metagraph's saved object graph, and -// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and -// SavedResources. These are returned via the `out` parameter. -Status ReviveObjects( - const MetaGraphDef& metagraph, ImmediateExecutionContext* context, - const std::string& directory, - std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>* - revived_objects) { - // This is needed to restore "Constant" nodes by looking up their - // "Value" attribute. - NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def()); - - // These are needed for creating "Assets", by looking up their filenames. - std::vector<AssetFileDef> assets; - TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(metagraph, &assets)); - - // Iterate through all the saved objects, restoring objects as we go. - // We don't recreate functions until all other objects have been created. - for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) { - const SavedObject& node = metagraph.object_graph_def().nodes(i); - if (node.kind_case() == SavedObject::kVariable) { - std::unique_ptr<Variable> variable; - TF_RETURN_IF_ERROR( - internal::LoadSavedVariable(context, node.variable(), &variable)); - (*revived_objects)[i] = std::move(variable); - } else if (node.kind_case() == SavedObject::kConstant) { - std::unique_ptr<Constant> constant; - TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(), - node_attr_map, &constant)); - (*revived_objects)[i] = std::move(constant); - } else if (node.kind_case() == SavedObject::kAsset) { - std::unique_ptr<Asset> asset; - TF_RETURN_IF_ERROR(internal::LoadSavedAsset(context, node.asset(), - directory, assets, &asset)); - (*revived_objects)[i] = std::move(asset); - } else if (node.kind_case() == SavedObject::kResource) { - // TODO(bmzhao): Figure out how resource loading works and implement it - return errors::Unimplemented( - "SavedResource loading is not implemented yet"); - } - } - return Status(); -} - -Status ReviveFunctions(const MetaGraphDef& metagraph, - const RevivedObjectMap& revived_objects, - ImmediateExecutionContext* context, - ConcreteFunctionMap* restored_functions) { - const FunctionDefMap function_def_map = - internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library()); - - // Iterate through all objects, only examining functions. - for (const SavedObject& node : metagraph.object_graph_def().nodes()) { - if (node.kind_case() == SavedObject::kBareConcreteFunction) { - const std::string& function_name = - node.bare_concrete_function().concrete_function_name(); - - const SavedConcreteFunction& saved_concrete_function = - metagraph.object_graph_def().concrete_functions().at(function_name); - - const FunctionDef* function_def = function_def_map.at(function_name); - std::unique_ptr<TFConcreteFunction> concrete_function; - TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction( - saved_concrete_function, function_def, revived_objects, context, - &concrete_function)); - (*restored_functions)[function_name] = std::move(concrete_function); - } else if (node.kind_case() == SavedObject::kFunction) { - // We only allow loading functions that have an annotated input signature, - // which means there is 1:1 correspondence between tf.function - // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is - // the same restriction that MLIR has: - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707 - const SavedFunction& saved_function = node.function(); - if (saved_function.concrete_functions_size() != 1) { - return errors::FailedPrecondition( - "Only tf.functions annotated with an input signature are supported " - "by SavedModelAPI. This means that there should only be a single " - "ConcreteFunction per tf.function"); - } - const std::string& function_name = saved_function.concrete_functions(0); - const SavedConcreteFunction& saved_concrete_function = - metagraph.object_graph_def().concrete_functions().at(function_name); - - const FunctionDef* function_def = function_def_map.at(function_name); - - std::unique_ptr<TFConcreteFunction> concrete_function; - TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction( - saved_concrete_function, function_def, revived_objects, context, - &concrete_function)); - (*restored_functions)[function_name] = std::move(concrete_function); - } - } - return Status(); -} const TrackableObjectGraph::TrackableObject::SerializedTensor* FindSerializedTensorInTrackable( @@ -234,7 +109,7 @@ FindSerializedTensorInTrackable( // overridden "restore" method: // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85 Status RestoreCheckpoint(SavedModelV2Bundle* bundle, - const RevivedObjectMap& revived_objects, + const RevivedObjects& revived_objects, const std::string& directory, ImmediateExecutionContext* context) { // TODO(bmzhao): Batch up all the restores into a single restore op per @@ -254,8 +129,7 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle, return Status::OK(); } - Variable* variable = - down_cast<Variable*>(revived_objects.at(node).get()); + Variable* variable = revived_objects.variables.at(node).get(); // Restore the tensor's value from the checkpoint const TrackableObjectGraph::TrackableObject::SerializedTensor* @@ -289,6 +163,14 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle, return Status(); } +Status InitializeAllResources(const RevivedObjects& revived) { + for (const auto& node_and_resource : revived.restored_resources) { + const RestoredResource& resource = node_and_resource.second; + TF_RETURN_IF_ERROR(resource.Initialize()); + } + return Status(); +} + } // namespace Status TFSavedModelAPI::GetFunction(const std::string& function_path, @@ -299,20 +181,12 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path, return errors::NotFound("No saved object found at path ", function_path); } - const SavedObject& object = bundle_.saved_object_graph().nodes(*node); - if (object.kind_case() == SavedObject::kBareConcreteFunction) { - *function = - concrete_functions_ - .at(object.bare_concrete_function().concrete_function_name()) - .get(); - } else if (object.kind_case() == SavedObject::kFunction) { - *function = - concrete_functions_.at(object.function().concrete_functions(0)).get(); - } else { - return errors::InvalidArgument(function_path, - " is not a path to a Function."); + auto function_iter = revived_objects_.concrete_functions.find(*node); + if (function_iter == revived_objects_.concrete_functions.end()) { + return errors::NotFound("No function found at path ", function_path); } + *function = function_iter->second.get(); return Status(); } @@ -325,8 +199,8 @@ Status TFSavedModelAPI::GetSignatureDefFunction( std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() { std::vector<ConcreteFunction*> result; - result.reserve(concrete_functions_.size()); - for (auto& index_and_function : concrete_functions_) { + result.reserve(revived_objects_.concrete_functions.size()); + for (auto& index_and_function : revived_objects_.concrete_functions) { result.push_back(index_and_function.second.get()); } return result; @@ -340,34 +214,21 @@ Status TFSavedModelAPI::GetVariable(const std::string& variable_path, return errors::NotFound("No saved object found at path ", variable_path); } - const SavedObject& object = bundle_.saved_object_graph().nodes(*node); - - if (object.kind_case() == SavedObject::kVariable) { - auto iter = revived_objects_.find(*node); - if (iter == revived_objects_.end()) { - return errors::Internal("Variable ", variable_path, - " was not properly revived."); - } - *variable = static_cast<Variable*>(iter->second.get()); - return Status(); + auto variables_iter = revived_objects_.variables.find(*node); + if (variables_iter == revived_objects_.variables.end()) { + return errors::NotFound("No variable found at path ", variable_path); } - *variable = nullptr; - return errors::InvalidArgument( - variable_path, " is not a path to a Variable (kind=", object.kind_case(), - ")"); + *variable = variables_iter->second.get(); + return Status(); } -TFSavedModelAPI::TFSavedModelAPI( - const std::string& directory, SavedModelV2Bundle bundle, - std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> - revived_objects, - std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>> - concrete_functions) +TFSavedModelAPI::TFSavedModelAPI(const std::string& directory, + SavedModelV2Bundle bundle, + RevivedObjects revived_objects) : directory_(directory), bundle_(std::move(bundle)), - revived_objects_(std::move(revived_objects)), - concrete_functions_(std::move(concrete_functions)) {} + revived_objects_(std::move(revived_objects)) {} Status TFSavedModelAPI::Load( const std::string& directory, @@ -388,28 +249,25 @@ Status TFSavedModelAPI::Load( // This occurs in python here: // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454 - RevivedObjectMap revived_objects; - TF_RETURN_IF_ERROR(ReviveObjects(bundle.meta_graph_def(), context, directory, - &revived_objects)); + // Step 1: For each node in the graph, we should initialize an object of the + // corresponding type. For objects that depend on the initialization of other + // objects (like functions which capture resources), we will initialize them + // in step 2. + PartiallyRevivedObjects partially_revived_objects; + TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects( + bundle.meta_graph_def(), context, directory, &partially_revived_objects)); - // TODO(bmzhao): When we later add support for loading resources, we need to - // handle the case where materializing a function's captures requires invoking - // other functions. This occurs when retrieving the resource handle for a - // TrackableResource: - // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240 - // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233 - // This requires restoring functions in a topological sort order by capture - // dependencies. - ConcreteFunctionMap function_map; - TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects, - context, &function_map)); + RevivedObjects revived_objects; + TF_RETURN_IF_ERROR(partially_revived_objects.Build( + context, bundle.saved_object_graph(), &revived_objects)); TF_RETURN_IF_ERROR( RestoreCheckpoint(&bundle, revived_objects, directory, context)); + TF_RETURN_IF_ERROR(InitializeAllResources(revived_objects)); + out->reset(new TFSavedModelAPI(directory, std::move(bundle), - std::move(revived_objects), - std::move(function_map))); + std::move(revived_objects))); return Status(); } diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h index d108b4071e9..bc39a974ad2 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" @@ -72,19 +73,12 @@ class TFSavedModelAPI : public SavedModelAPI { Status GetVariable(const std::string& variable_path, Variable** variable); private: - TFSavedModelAPI( - const std::string& directory, SavedModelV2Bundle bundle, - std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> - revived_objects, - std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>> - concrete_functions); + TFSavedModelAPI(const std::string& directory, SavedModelV2Bundle bundle, + RevivedObjects revived_objects); std::string directory_; SavedModelV2Bundle bundle_; - std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> - revived_objects_; - std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>> - concrete_functions_; + RevivedObjects revived_objects_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index a55f232795b..ec5d65ea60d 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/tstring.h" @@ -196,6 +197,142 @@ TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) { TFE_DeleteContext(ctx); } +TEST_P(CSavedModelAPITest, LoadsStaticHashtableSavedModel) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = SavedModelPath("StaticHashTableModule"); + + TF_SavedModel* saved_model = + TF_LoadSavedModel(model_dir.c_str(), ctx, status); + + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TF_ConcreteFunction* lookup_fn = + TF_GetSavedModelConcreteFunction(saved_model, "lookup", status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // Note(bmzhao): Based on static_hashtable_asset.txt, we expect the following + // mapping: + // "foo" -> 0 + // "bar" -> 1 + // "baz" -> 2 + // "wombat" -> 3 + // all other strings -> -1 + + // Call lookup function with input "foo", expecting an output of 0 + { + std::vector<TFE_TensorHandle*> lookup_fn_inputs; + TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("foo")); + lookup_fn_inputs.push_back(input_foo); + + TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp( + lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + tensorflow::int64* output_value = + static_cast<tensorflow::int64*>(TF_TensorData(result)); + EXPECT_EQ(*output_value, 0); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(input_foo); + TFE_DeleteTensorHandle(lookup_fn_outputs[0]); + TFE_DeleteOp(lookup_op); + } + + // Call lookup function with input "baz", expecting an output of 2 + { + std::vector<TFE_TensorHandle*> lookup_fn_inputs; + TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("baz")); + lookup_fn_inputs.push_back(input_foo); + + TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp( + lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + tensorflow::int64* output_value = + static_cast<tensorflow::int64*>(TF_TensorData(result)); + EXPECT_EQ(*output_value, 2); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(input_foo); + TFE_DeleteTensorHandle(lookup_fn_outputs[0]); + TFE_DeleteOp(lookup_op); + } + + // Call lookup function w/input "NON-EXISTENT-KEY", expecting an output of -1 + { + std::vector<TFE_TensorHandle*> lookup_fn_inputs; + TFE_TensorHandle* input_foo = + TestScalarTensorHandle(ctx, tstring("NON-EXISTENT-KEY")); + lookup_fn_inputs.push_back(input_foo); + + TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp( + lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + tensorflow::int64* output_value = + static_cast<tensorflow::int64*>(TF_TensorData(result)); + EXPECT_EQ(*output_value, -1); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(input_foo); + TFE_DeleteTensorHandle(lookup_fn_outputs[0]); + TFE_DeleteOp(lookup_op); + } + + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 4efe56ff822..243f86ed787 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -213,6 +213,7 @@ py_binary( srcs = ["testdata/generate_saved_models.py"], data = [ ":saved_model_asset_data", + ":saved_model_static_hashtable_asset_data", ], python_version = "PY3", srcs_version = "PY3", @@ -220,6 +221,7 @@ py_binary( "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", "//tensorflow/python:tensor_spec", "//tensorflow/python:variables", "//tensorflow/python/compat:v2_compat", @@ -243,6 +245,7 @@ filegroup( "testdata/half_plus_two_v2/**", "testdata/x_plus_y_v2_debuginfo/**", "testdata/CyclicModule/**", + "testdata/StaticHashTableModule/**", "testdata/VarsAndArithmeticObjectGraph/**", "testdata/fuzz_generated/**", ]), @@ -260,6 +263,13 @@ filegroup( ], ) +filegroup( + name = "saved_model_static_hashtable_asset_data", + srcs = [ + "testdata/static_hashtable_asset.txt", + ], +) + exports_files( glob([ "testdata/half_plus_two_pbtxt/**", diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt new file mode 100644 index 00000000000..e79f591665f --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt @@ -0,0 +1,4 @@ +foo +bar +baz +wombat diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..04e8ba62bdbae4127ddd45e6617e908c126df22b GIT binary patch literal 14204 zcmeHO%X1t@8Q<CWeye9ow%0Aome+}~H?dY4*@?6SCzeQJo5WUR5fUDn8tsmxiB_{) zGb35%LI_nPn4+M73Plc35fleQAcPxGaHER90dAZ)aDhXLOTyRPJ5SB5B*zX3MXH>g zM^E?HUw@C^_fTFU?;7xpe)5k#n1z8$mZh&6&#R`cE;qI7ZPEjIQ?pu|W?d@{lRn67 zsm)EToN3s%+FH36#05#S+t(xsrXe%eGAx_+L$<YPw>NEPY=7rllC?F%YMG&r;Y|sO zmsM+Rk*)l^Ob4LHW|zvEYBZX-die+x>Mg@obwe{r*0N3ASfxplC?AEQwWgX{Low71 zE%53j9NM%r#b`C~fA~^i1MxeOUT(FTG<hab9)!ZWc1z(q5*sNGL6Ps%|B&r&9P(Eh znqlkqEdr3y4YogNkrv&}N{i(JNQ+_~__G9~R}9_O)u#SsjqmI!v$e6PZQD=jY?ox$ zE*ZyH@Es$zlq-)4;&h{-ZIhI4*k`1P_ulXPCzVe)gvJec2w<yKSC<tF3E&gZV#)db zeF<mtZptMnG_<C^q1!$Nr-)Qwd%p>T&$L?Wn{BZTTriUSBBYU<Rynn(8{`o5HnbIW zvuP_rZqkb=6toYr{MIV-`*IQZXf8zM9^el>k<x%XtC_3Ft+_R=zTR#j7c6oFhU)IE z(l)gfeOt3iIdT|CL$hq%P;I?sblxB*N1z`!r`ej)YBrRHZd$GkNNj&uNMCDRZ^N^Y zo>%MZDC-W9(qcur7Chq~kghq$A^A?c+z)AWTek$KlcfaAL2<sR*0nWURWo<1MNoyx z!;nD<ZEHmQ&~z<XlHdvGThP=-oL(YB_tMK26c&&xEt3fyKvpme*~J8q>myeW#=v11 z;3O!hcZ${2>x`{*5OQLiC1<&+SST)?D*{d3vd=IZa26$sv$F}=8|v0wu0{^J+T*wy z9&X(xz1&DNQ0H1k2}l+srlIUt>!zufq5lS2hKAN)ipk9gD&n$gbrg{l_`#TsQVJyl zjLi*hNg7ruMFScOe~{qF4O_LfmCfe7YND+`O0>qD+H4Y34%J+>WG3!R3xtJ~g%%76 zEeP-u<w0me@F|*@6_Q@rH0n4Yi$-xpUk#W&3O(W}%74pl1>TK7uePn#H&GZAc4YpS zOc9=dVMf1RU(i;y?WZ=I&1Y4+zQ#Df4IT;u3XmYZ1ZX~UZPiALVK@`PHGmu3I$uRy z#S_W$?@)@E9(mv9P}luxg!bLi+N!A{WzJm8zQgP28c_Bxb4PH!%^2=ErMq$qA+1CG z@oQcA6!hb9wQizEX=qza14s`Sap82%(j+7pt7d>8T~{p6#y7$~lM?9DEk$>|K%)fs zX*pwQ=9X5;$m8e7C*<6wv2L_(8q*T_@pJGwIILQhW-G|pHQ%Qu@!OKhVdzuDF)&$d zBavxH))JMX<CT*WId&tqhHv1j6675tLi{YPj6hbYH*r-|y?)Z$Y}IQ|I3o|fP$lG% zfw;?p3TK^(s{zhcBUl;1Ug}O2M;9zWzq^D$G-MaD@us0M%ntAoz68hIIhif4h75AZ zKzq~2Q5l^6FA~5Sq92809^ywru@m(BfeSn9uc4mcVY@)cCZ;#1@>py6CA5%J<Mx_v zHc+-q`_{Oy=_@Gf<Lz7aTFV%>O$1nokl00iJ#c}%)a)@a&6sMBO>959a=xxUBCkwq z%hMN?bFy-7vNGsk5`n0a69sq?O3?zXOny+2eia$#91NQcuN^y48Fhtv2nI=iK&X|; zSiZ*U4{Aw(zdgjaG-&XUN66iYwr{#P>fi+!^eq!B-oB}5Ec1ghsKRmI%=-~?-y(7J z2A7!cga<@K3H$k|fLs^ohLG8Te1Rkuz{9M_^^oR~&^`4LNy(CLhk|$ep&<vaC*!?( zoV%yHDDkwh(kr|@qE0!P8^&AV=QEI!B?%?>h*EDg8KTj*G-X3I)K$$?^oB*y^~91q z22?T8I4Lcoc?+*Z;D#Oq;}r4bxezcIB_oi7)ZF}M2PAg!G+rd+<kiHq^oo?Ij4_}w zvTzj1!C6)`Q#0zC@>o;dSZ=5nmGh54O~7Wt!80u_#Xz4Z&chvCl1GcExI!Sxy;b3G zYe-RTB6gzz6>T{3@rBtJ^Er~3y#c48ke5h?8y&n%k_@-2(3ejUU?Hw*Zz6jj0R#r} zY0~5OviP(gdh;1lP@G5%A6<oE=CaWsA|?<q&CQnC2->gGE|-i4%piPF4St=3qcMy^ zR>aCvBmvatH)uz#`h;-hAS0)8!-b%9750*@_)Xk%`8vzw>w$$bjQQouJs{<#skYlF z__7Swq0hpA!gS9h$oYeVp&xcR1Udc?LR3&kf`b~J3-Ua-7tiC!zvmo(fN}i9!YGXR zIj%5gk7$>1Sec%L?{?wM)V_IB=@D|sjoL?+BcEU#DKCt{@hFa<m0(e`GBZiun}L63 z;ddC{`3L%pAvp=7Hs*G?h;oaf2-w5S2*dIiU4IH@LPNaJfZM1QOc1dV?3*oF;6XPB z#zG&xlYuI+ID1EWI5=JPiW_N&L`@*yM1PWm1ARcc=dbA9Z-#bfRb_aWDP^7yQ&cQj zC(>QTl9vjrPzQE8?>vwoj~{H)^v1zqic~q|xS^g|K<AUFq!WpkAD7Vxb00@v#^B4% z9=cbwkH4*nK(01H&n+F~Kr1<xZA?x`ajsGZZ&l1Dy3#`i9a`L_F>32Y&LUQD;>kQ@ zYYFDQ<S!nKgXBO$Ry&qYkP^h$K=^A49u4Z8Qb&}-@)z=>0gV{u8sT?{fB#>OtK=Sk z&`5Rfq=_eAM?KJ4;|8W<Oi<K{OWr$2kcRJtAPt5EDIu5qq!Ql|!M`(L?d|_b9?}eN zp$RR5lP<xM1m^oXxf5lgR0v1|OIz<qv1sb;yMgqa2I{O&y|pVJ8_ac_`MaYFN8pg( zjWv00Uzwbg5$wlcQ~H%i2^)(U?i!Z(u)vYJ0vh&kGOL4}4`RSZ!*YiT!ps-{mp~91 zx3gGCQ@e<!sRpss!EqQ?u)<_2R$Hs<E4o{RB3x9^v?aTkws#4<)D?-5ekd{|;MOq+ z<`Wr8Luh~*UdvV{F&~Q=X$KXsdoK48MHBjMDJHiYnenCmm4lI4CC*Bq7j)CXnDtw- zrfSvdctt)Oy)dl@J?(y13B-Sg9mm@{TjcRyW}GFrJkR^BJ>tyx-2~Qo%von@c9O62 zkYZ~aZRLVIQ=eW|<(UhOnez=zn?66iQlGl8GNVmSPd{>DvdXC`xUnw%CRUp|<%-TP z9c4NcPPuyu6b1<cov_&*5oTNm7Ubg2Nk-Q$NT7BgTkp%<2hN}Ni{af?yhf6xWbBHl zi(XZLe%@b1IY2rnRUytX6lQZ6&L?V;#2lell0h?d=ti)ZQ)Ow<#@#0wqE=1FXhI{` zp+A#$my^QvB|7AB!(}Qx9yy`+e#qOV`!IO+u}KpdGy9?0>5n8B5@ouozX}+`D7wo= zGk};T$$T)sQJHj(=*;YC)D%8q_Lwqv`N_G@%wNB9ZSm6VGfyhlXRkgvx2Qb5aB2Rs zd?=C-dYe_#cpZ@)VO#SRB|j*U0le-Inc(g(dO`YornmeZ>HJa};%Lpu7$Lkm7QtEf z(CApqw!9G`J0Cvg=l}`(E<|~jh^U(@I;WZbR)W*JE9upUGCJWU@UEN7EeZO=RH5Nv zW<YX0dkzD$uxEFrMn1D&is+kCWE((~m#gk?lOChv{oF)!7%C4(5kt*Lc;ug*4lV5_ zE`o(Dq>H!RzL1V?062jhrVc&WzdO8(T-!UQnCOd{R&RmwNk7scL*5&go0TJB5XOL9 zqL#qJjyNq6@vJt7pzCpnp$Hy~fn2@B&uT<oQWxa)uPppg3PCR~6Jsh<=Rf3ZPU`>J zocol`xr?Gxl)sleKzUzNIjL9#<<t9*X)5Qf9DI7uJZm`Qi3o4#PbE0K;50P-x3LSR zuA*t7-;scNuRbIMbFAB%^li!C3g&f%1=0Qqk7_?Kbgeu+<qZgL%N-PUY_U;vWfy8G z`iP`6{u*O=(nuI2qt9X8e6+E*OU^8@k>P3g44{0<NMG>#$gWQKKKNT~vJ7Tzc^A+< z@uBebEc`}_Nfh}qa^p}ju*ps7B$33AsAga`epeNPRN<lS<DCjCZxfA7>y2Jx3yO}^ zs`wH04&Inb|Aw8nT}P1J-K_Kgo1h{Qc#v&W3ZfdukR5c&{n#ASCG&M+k3FzNf{XZk zVt2ah&TFFg76X*yDEVEO=*Vg}5`9{*mrL3sId^f7&64itd7l$?uO#6t9Q7wmu?qmZ z0b92)^*1#W$UCR#y+<6+MILl0E#w^7JwXt=bNJ;<{`={H9{LP62xCO)RpM(CWF#gB ztjH*5vAqsc(@AV$A7ZmLlnw0BZ(^&9ZYZrb8kdd<uUusI$3KL@bT;Hcaw_f>>K8Zv zED(JSlQE&r4YG+HB3y&XaX1`$sjxCbZDJrt@oO|y`A(1J1TGi5%}M;Az&@bVYpJVt zX0BqlbwjV)@-)!6PB`QtwRjE?>oeAxsMPT4@wmNh%QZl+E{e(Frcedf?)Gi%^&-sg zvmXZXS&|jzyL=48tW^#F)<Um~47X_fBk8<^y+X-K0{bj?_s$GszGpzx=QOti3HtMS zQsfnJ@uv`Q2phQu9TjrNEhOM!7|H_)A6OZkoYxhyp|(3zoD22o>_K$8f~hL}s|@*1 zu}<w568u1VC-psSrDUzom(0~0nApcIDc-qFhTcdGK=!4yr8QUBFK+?rh1?dlEn{al z>(0XZ6!xo`>o@sLALKjEOuQTfSii~Iv-t4T8>s{~-1mt`fq?{ys$sO4NMN3yr0%2= z15%}LCxy`Hb~4LyETtU$Qj&ijYDn-EM+qkUU&Y<9#!&xmI(_JbFaKdtUrFtx67Y^B z|7wr8`=dwn+mig2)W!UN1o`~%Y-JxsRW85y2PTE9PMZRmiY>L167Ahd0q}EV@>tH@ zPNqo$tFRZkF(>Y`6!~QL1k5p#%p_yhOp(iRxU@==$s#?zlNy2HofKxhZzu7Oinyi< zrQx)@N>m_Q`{2|9pY-%j>Xg|3u{$X&ly;^KuTXyAPU<9_@VSnm+sOgnEu*9GABZ;D Aj{pDw literal 0 HcmV?d00001 diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..f6d62d9a51c7440c26c963708ac9ea5eb5fbed66 GIT binary patch literal 88 zcmWH3w)^xiE;TMr4n`r?#Ny)A5+IXFh_xg!DJPYSkAqo=t2i?~FR`SwD7Bc2flG)> pfP+PdCq6STvm`SyC$lQG2q?<W#mm7e#2KHKnUk8An48MQ000`+7Ipvt literal 0 HcmV?d00001 diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..df6c85e578347341989e42a877ca8d166bc92972 GIT binary patch literal 144 zcmZQzVB=tvV&Y(Akl~JZ_HcFf4)FK%3vqPvagFzP@^W<!iFXfj4DjG!7h=$eFkS1~ xF2Dc;j37c_)`2=W=fJ0h8yFaw5*V!ELN4D9dOC3U@__^hf$-lA-72N-w*e!H7l;4= literal 0 HcmV?d00001 diff --git a/tensorflow/cc/saved_model/testdata/generate_saved_models.py b/tensorflow/cc/saved_model/testdata/generate_saved_models.py index 91c09d33bb0..2b64cf52096 100644 --- a/tensorflow/cc/saved_model/testdata/generate_saved_models.py +++ b/tensorflow/cc/saved_model/testdata/generate_saved_models.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.module import module from tensorflow.python.ops import io_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.saved_model import save_options @@ -82,10 +83,31 @@ class AssetModule(module.Module): return io_ops.read_file(self.asset) +class StaticHashTableModule(module.Module): + """A module with an Asset, StaticHashTable, and a lookup function.""" + + def __init__(self): + self.asset = tracking.Asset( + test.test_src_dir_path( + "cc/saved_model/testdata/static_hashtable_asset.txt")) + self.table = lookup_ops.StaticHashTable( + lookup_ops.TextFileInitializer(self.asset, dtypes.string, + lookup_ops.TextFileIndex.WHOLE_LINE, + dtypes.int64, + lookup_ops.TextFileIndex.LINE_NUMBER), + -1) + + @def_function.function( + input_signature=[tensor_spec.TensorSpec(shape=None, dtype=dtypes.string)]) + def lookup(self, word): + return self.table.lookup(word) + + MODULE_CTORS = { "VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph, "CyclicModule": CyclicModule, "AssetModule": AssetModule, + "StaticHashTableModule": StaticHashTableModule, } diff --git a/tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt b/tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt new file mode 100644 index 00000000000..e79f591665f --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt @@ -0,0 +1,4 @@ +foo +bar +baz +wombat