diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index d6c3613d194..bc532506a46 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -66,8 +66,13 @@ cc_library( "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/experimental/saved_model/core/revived_types:asset", "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects", + "//tensorflow/c/experimental/saved_model/core/revived_types:restored_resource_revival_state", "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function_revival_state", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_signature_def_function_revival_state", "//tensorflow/c/experimental/saved_model/core/revived_types:variable", + "//tensorflow/cc/saved_model:loader_util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -147,12 +152,14 @@ cc_library( "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/experimental/saved_model/core/ops:restore_ops", "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:flat_tensor_function", + "//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects", + "//tensorflow/c/experimental/saved_model/core/revived_types:revived_objects", "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible", "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", "//tensorflow/c/experimental/saved_model/core/revived_types:variable", "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/cc/saved_model:constants", - "//tensorflow/cc/saved_model:loader_util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 1205f12b948..eaccc2fac69 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -69,6 +69,84 @@ cc_library( ], ) +cc_library( + name = "partially_revived_objects", + srcs = [ + "partially_revived_objects.cc", + ], + hdrs = [ + "partially_revived_objects.h", + ], + deps = [ + ":asset", + ":constant", + ":restored_resource", + ":restored_resource_revival_state", + ":revived_objects", + ":tf_concrete_function", + ":tf_concrete_function_revival_state", + ":tf_signature_def_function", + ":tf_signature_def_function_revival_state", + ":variable", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "restored_resource", + srcs = [ + "restored_resource.cc", + ], + hdrs = [ + "restored_resource.h", + ], + deps = [ + ":tensorhandle_convertible", + ":tf_concrete_function", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "restored_resource_revival_state", + hdrs = [ + "restored_resource_revival_state.h", + ], + deps = [ + ":tf_concrete_function_revival_state", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + ], +) + +cc_library( + name = "revived_objects", + hdrs = [ + "revived_objects.h", + ], + deps = [ + ":asset", + ":constant", + ":restored_resource", + ":tf_concrete_function", + ":tf_signature_def_function", + ":variable", + "//tensorflow/core:lib", + ], +) + cc_library( name = "variable", srcs = [ @@ -123,6 +201,21 @@ cc_library( ], ) +cc_library( + name = "tf_concrete_function_revival_state", + hdrs = [ + "tf_concrete_function_revival_state.h", + ], + deps = [ + ":tf_concrete_function", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "tf_signature_def_function", srcs = [ @@ -145,3 +238,17 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "tf_signature_def_function_revival_state", + hdrs = [ + "tf_signature_def_function_revival_state.h", + ], + deps = [ + ":tf_signature_def_function", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc new file mode 100644 index 00000000000..5cc06e6c54f --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc @@ -0,0 +1,388 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +namespace { + +Status AssertAllCreateResourceFunctionsHaveNoCaptures( + const PartiallyRevivedObjects& objects) { + for (const auto& id_and_resource : objects.restored_resources) { + int node_id = id_and_resource.first; + const RestoredResourceRevivalState& resource = id_and_resource.second; + const TFConcreteFunctionRevivalState* create_resource_fn = + resource.create_resource; + if (create_resource_fn == nullptr) { + return errors::FailedPrecondition( + "Resource at node ", node_id, + " did not have a create_resource() function"); + } + const SavedConcreteFunction* saved_create_resource_fn = + create_resource_fn->saved_concrete_func; + if (!saved_create_resource_fn->bound_inputs().empty()) { + // TODO(b/124045874): Support loading resource functions via a top sort + return errors::Unimplemented( + "Create Resource functions with captures are currently unsupported."); + } + } + return Status(); +} + +// Retrieves the TensorHandle associated with `node_id` from `obj_graph`, and +// set `*handle` to point to it. +Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + ImmediateExecutionTensorHandle** handle) { + const SavedObject& node = obj_graph.nodes(node_id); + SavedObject::KindCase kind = node.kind_case(); + switch (kind) { + case SavedObject::kVariable: { + const auto& variables_iter = objects.variables.find(node_id); + if (variables_iter == objects.variables.end()) { + return errors::FailedPrecondition( + "Tried to convert node id ", node_id, + " of type variable to tensor but the variable wasn't initialized"); + } + *handle = variables_iter->second->handle(); + return Status(); + } + case SavedObject::kConstant: { + const auto& constants_iter = objects.constants.find(node_id); + if (constants_iter == objects.constants.end()) { + return errors::FailedPrecondition("Tried to convert node id ", node_id, + " of type constant to tensor but the " + "constant wasn't initialized"); + } + *handle = constants_iter->second->handle(); + return Status(); + } + case SavedObject::kAsset: { + const auto& assets_iter = objects.assets.find(node_id); + if (assets_iter == objects.assets.end()) { + return errors::FailedPrecondition( + "Tried to convert node id ", node_id, + " of type asset to tensor but the asset wasn't initialized"); + } + *handle = assets_iter->second->handle(); + return Status(); + } + case SavedObject::kResource: { + const auto& resource_iter = objects.restored_resources.find(node_id); + if (resource_iter == objects.restored_resources.end()) { + return errors::FailedPrecondition( + "Tried to convert node id ", node_id, + " of type Resource to tensor but the Resource wasn't initialized"); + } + const RestoredResourceRevivalState& resource = resource_iter->second; + if (resource.resource_handle == nullptr) { + return errors::FailedPrecondition( + "Resource with node id ", node_id, + " should have its resource_handle created, but was nullptr."); + } + *handle = resource.resource_handle.get(); + return Status(); + } + default: { + return errors::FailedPrecondition( + "Only objects of type variable, constant, asset, and resources have " + "capturable tensorhandles. Encountered object of kind ", + node.kind_case(), " at node id: ", node_id); + } + } +} + +// This function finds the necessary captures, then forwards to the builder +// method +Status CreateConcreteFunction(ImmediateExecutionContext* ctx, + const TFConcreteFunctionRevivalState& builder, + const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + std::unique_ptr* out) { + const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs(); + std::vector captures; + captures.reserve(capture_node_ids.size()); + for (int capture_node_id : capture_node_ids) { + ImmediateExecutionTensorHandle* capture_handle; + TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects, + &capture_handle)); + captures.push_back(capture_handle); + } + // TODO(bmzhao): Create Metadata here + return TFConcreteFunction::Create(/*function_def=*/builder.fdef, + /*captures=*/std::move(captures), + /*metadata=*/{}, + /*ctx=*/ctx, + /*out=*/out); +} + +Status CreateSignatureDefFunction( + ImmediateExecutionContext* ctx, + const TFSignatureDefFunctionRevivalState& builder, + const SavedObjectGraph& obj_graph, const PartiallyRevivedObjects& objects, + std::unique_ptr* out) { + const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs(); + std::vector captures; + captures.reserve(capture_node_ids.size()); + for (int capture_node_id : capture_node_ids) { + ImmediateExecutionTensorHandle* capture_handle; + TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects, + &capture_handle)); + captures.push_back(capture_handle); + } + // TODO(bmzhao): Create Metadata here + return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef, + /*captures=*/std::move(captures), + /*metadata=*/{}, + /*ctx=*/ctx, + /*out=*/out); +} + +Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + RevivedObjects* revived) { + for (const auto& id_and_resource : objects.restored_resources) { + const RestoredResourceRevivalState& resource = id_and_resource.second; + const TFConcreteFunctionRevivalState* create_resource_fn = + resource.create_resource; + + const SavedConcreteFunction* saved_create_resource_fn = + create_resource_fn->saved_concrete_func; + if (!saved_create_resource_fn->bound_inputs().empty()) { + // TODO(b/124045874): Load resource functions via a topological sort + return errors::Unimplemented( + "Create Resource functions with captures are currently unsupported."); + } + std::unique_ptr out; + TF_RETURN_IF_ERROR(CreateConcreteFunction(ctx, *create_resource_fn, + obj_graph, objects, &out)); + revived->concrete_functions[create_resource_fn->node_id] = std::move(out); + } + return Status(); +} + +Status InitializeAllFunctions(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + const PartiallyRevivedObjects& objects, + RevivedObjects* revived) { + gtl::FlatMap>* destination_func_map = + &revived->concrete_functions; + gtl::FlatMap>* + destination_sig_map = &revived->signature_def_functions; + + for (const auto& id_and_func : objects.concrete_functions) { + int node_id = id_and_func.first; + const TFConcreteFunctionRevivalState& func = id_and_func.second; + + if (destination_func_map->find(node_id) != destination_func_map->end()) { + // The function has already been initialized in the destination_map, + // so we can skip this node. This can occur because we initialize + // CreateResource functions before calling this function. + continue; + } + + std::unique_ptr out; + TF_RETURN_IF_ERROR( + CreateConcreteFunction(ctx, func, obj_graph, objects, &out)); + (*destination_func_map)[node_id] = std::move(out); + } + + for (const auto& id_and_func : objects.signature_def_functions) { + int node_id = id_and_func.first; + const TFSignatureDefFunctionRevivalState& func = id_and_func.second; + + if (destination_sig_map->find(node_id) != destination_sig_map->end()) { + continue; + } + + std::unique_ptr out; + TF_RETURN_IF_ERROR( + CreateSignatureDefFunction(ctx, func, obj_graph, objects, &out)); + (*destination_sig_map)[node_id] = std::move(out); + } + + return Status(); +} + +Status CreateAllResourceHandles(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + PartiallyRevivedObjects* objects, + RevivedObjects* revived) { + for (auto& id_and_resource : objects->restored_resources) { + RestoredResourceRevivalState& resource = id_and_resource.second; + int create_resource_fn_node = resource.create_resource->node_id; + const gtl::FlatMap>& + revived_functions = revived->concrete_functions; + + const auto& revived_functions_iter = + revived_functions.find(create_resource_fn_node); + if (revived_functions_iter == revived_functions.end()) { + return errors::FailedPrecondition( + "ConcreteFunction at node ", create_resource_fn_node, + " should have been initialized prior to being called."); + } + const TFConcreteFunction& create_resource_fn = + *revived_functions_iter->second; + ImmediateOpPtr function_op; + TF_RETURN_IF_ERROR(create_resource_fn.MakeCallOp({}, &function_op)); + TF_RETURN_IF_ERROR(function_op->SetDeviceName(resource.device.c_str())); + + AbstractTensorHandle* resource_handle = nullptr; + int num_retvals = 1; + TF_RETURN_IF_ERROR(function_op->Execute( + absl::MakeSpan(&resource_handle, num_retvals), &num_retvals)); + AbstractTensorHandlePtr owned_resource_handle(resource_handle); + if (!tensorflow::isa( + owned_resource_handle.get())) { + return errors::Internal("Unexpected tensor handle kind."); + } + ImmediateTensorHandlePtr result( + reinterpret_cast( + owned_resource_handle.release())); + resource.resource_handle = std::move(result); + } + return Status(); +} + +// Finds a ConcreteFunction with node id `node` in `objects`, and sets *out to +// point to it. If node doesn't exist in `objects`, out is untouched, and an +// error status is returned. +Status FindConcreteFunction(int node, RevivedObjects* objects, + TFConcreteFunction** out) { + auto func_iter = objects->concrete_functions.find(node); + if (func_iter == objects->concrete_functions.end()) { + return errors::FailedPrecondition( + "Failed to find ConcreteFunction with node id ", node, + " in revived objects"); + } + *out = func_iter->second.get(); + return Status(); +} + +Status BuildResources(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + PartiallyRevivedObjects* objects, + RevivedObjects* revived) { + for (auto& id_and_resource : objects->restored_resources) { + int node_id = id_and_resource.first; + RestoredResourceRevivalState& resource_revival_state = + id_and_resource.second; + + TFConcreteFunction* create_resource = nullptr; + + // Check all the functions associated with the resource have already been + // initialized in `revived` + if (resource_revival_state.create_resource != nullptr) { + TF_RETURN_IF_ERROR( + FindConcreteFunction(resource_revival_state.create_resource->node_id, + revived, &create_resource)); + } + + TFConcreteFunction* initialize = nullptr; + if (resource_revival_state.initialize != nullptr) { + TF_RETURN_IF_ERROR(FindConcreteFunction( + resource_revival_state.initialize->node_id, revived, &initialize)); + } + + TFConcreteFunction* destroy_resource = nullptr; + if (resource_revival_state.destroy_resource != nullptr) { + TF_RETURN_IF_ERROR( + FindConcreteFunction(resource_revival_state.destroy_resource->node_id, + revived, &destroy_resource)); + } + + if (resource_revival_state.resource_handle == nullptr) { + return errors::FailedPrecondition("Resource at node id ", node_id, + " does not have a resource handle."); + } + + revived->restored_resources.emplace( + node_id, RestoredResource( + /*device=*/resource_revival_state.device, + /*create_resource=*/create_resource, + /*initialize=*/initialize, + /*destroy_resource=*/destroy_resource, + /*resource_handle=*/ + std::move(resource_revival_state.resource_handle))); + } + return Status(); +} + +} // namespace + +Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + RevivedObjects* revived) { + // Step 1: We would like to initialize all functions; this requires setting up + // their captured tensorhandles, which may come from variables, assets, + // constants, or resources. The first three are trivial; However, + // tensorhandles that correspond to resources must be created by invoking + // their "create_resource" function. + // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240 + // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233 + // For now, we assert that all create_resource functions must have no + // captures. This aligns with the current behavior in python. + // https://github.com/tensorflow/tensorflow/blob/50eac986bf7a0ad12594e080f083181f277e0b49/tensorflow/python/saved_model/load.py#L152-L155 + // TODO(bmzhao): We should do a topological sort instead. + + // 1a. Make sure all CreateResource functions have no captures. + TF_RETURN_IF_ERROR(AssertAllCreateResourceFunctionsHaveNoCaptures(*this)); + + // 1b. Initialize all CreateResource functions, storing them in `revived` + TF_RETURN_IF_ERROR( + InitializeCreateResourceFunctions(ctx, obj_graph, *this, revived)); + + // 1c. Invoke all "CreateResource" functions and store their ResourceHandles + // https://github.com/tensorflow/tensorflow/blob/3b6b41b68a95dc70c26dc816b29d359bfb88c116/tensorflow/python/training/tracking/tracking.py#L241-L247 + // in *this->resources. + // TODO(bmzhao): Maybe store them separately, not in *this? + TF_RETURN_IF_ERROR(CreateAllResourceHandles(ctx, obj_graph, this, revived)); + + // 2. Initialize all the rest of the functions + TF_RETURN_IF_ERROR(InitializeAllFunctions(ctx, obj_graph, *this, revived)); + + // 3a. Move over all non-function, non-resource objects + revived->variables = std::move(variables); + revived->assets = std::move(assets); + revived->constants = std::move(constants); + + // 3b. Move over resources. + TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived)); + + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h new file mode 100644 index 00000000000..cce52748a1d --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ + +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// Container for objects during the revival step in SavedModel's loading. +// Notably, resources and functions can be in a state where they reference +// other resources/functions that have not been constructed yet. We collect +// *all* objects in a partially valid state here, then properly initialize +// resources and functions. +struct PartiallyRevivedObjects { + gtl::FlatMap> variables; + gtl::FlatMap> assets; + gtl::FlatMap> constants; + gtl::FlatMap concrete_functions; + gtl::FlatMap signature_def_functions; + gtl::FlatMap restored_resources; + + Status Build(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, RevivedObjects* revived); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc new file mode 100644 index 00000000000..47860ce8b39 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace { + +Status ExecuteNoArgDummyReturnFunction(TFConcreteFunction* func) { + ImmediateOpPtr function_op; + TF_RETURN_IF_ERROR(func->MakeCallOp({}, &function_op)); + + AbstractTensorHandle* dummy_output = nullptr; + int num_retvals = 1; + TF_RETURN_IF_ERROR(function_op->Execute( + absl::MakeSpan(&dummy_output, num_retvals), &num_retvals)); + AbstractTensorHandlePtr owned_dummy_output(dummy_output); + return Status(); +} + +} // namespace + +RestoredResource::RestoredResource(const std::string& device, + TFConcreteFunction* create_resource, + TFConcreteFunction* initialize, + TFConcreteFunction* destroy_resource, + ImmediateTensorHandlePtr resource_handle) + : TensorHandleConvertible(std::move(resource_handle)), + device_(device), + create_resource_(create_resource), + initialize_(initialize), + destroy_resource_(destroy_resource) {} + +Status RestoredResource::Initialize() const { + return ExecuteNoArgDummyReturnFunction(initialize_); +} + +RestoredResource::~RestoredResource() { + // Note(bmzhao): SavedModels saved before + // https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3 + // did not have their destroy_resource function saved, meaning they will + // leak resources. + if (destroy_resource_ != nullptr) { + Status status = ExecuteNoArgDummyReturnFunction(destroy_resource_); + if (!status.ok()) { + LOG(WARNING) + << "Failed executing destroy_resource function for RestoredResource: " + << status.error_message(); + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h new file mode 100644 index 00000000000..7adbd563a6b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" + +namespace tensorflow { + +// RestoredResource represents a TF2 "Resource" object loaded from a savedmodel, +// analogous to the Python _RestoredResource object: +// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/saved_model/load.py#L481 +// TF2 resource objects typically extend TrackableResource: +// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/training/tracking/tracking.py#L285 +// and are expected to implement "_create_resource", "_initialize", and +// "_destroy_resource" functions: +// https://github.com/tensorflow/tensorflow/blob/139ba9c5284799beafdd1d7f895127cf00e7c48f/tensorflow/python/training/tracking/tracking.py#L262-L281 +class RestoredResource : TensorHandleConvertible { + public: + // Note(bmzhao): RestoredResource stores non-owning pointers to its associated + // functions because SavedModel internally owns all functions and objects in + // the RevivedObjects struct (which owns all functions). One alternative would + // be to have RevivedObjects store shared_ptr instead, and + // change RestoredResource's constructor take shared_ptr. + // To keep things simple, I've stuck to raw pointers for now. + // + // Params: + // device - The device string associated with the SavedResource + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_object_graph.proto#L182 + // Conceptually, this is the same device used in CapturableResource: + // https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L222-L225 + // Implementation-wise, it is device used when invoking the + // create_resource function to produce the resource_handle + // associated with the object: + // https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L246-L247 + // create_resource - Non owning pointer to the create_resource function + // associated with this object. Must be NON-NULL. + // initialize - Non owning pointer to the initialize function associated with + // this object. Must be NON-NULL. + // destroy_resource - Non owning pointer to the destroy_resource function + // associated with this object. Ideally this should be + // NON-NULL, but in order to support models saved prior to + // https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3 + // we allow null here. This will, however, leak resources. + RestoredResource(const std::string& device, + TFConcreteFunction* create_resource, + TFConcreteFunction* initialize, + TFConcreteFunction* destroy_resource, + ImmediateTensorHandlePtr resource_handle); + + Status Initialize() const; + + // RestoredResource is movable, but not copyable. + RestoredResource(RestoredResource&& other) = default; + RestoredResource& operator=(RestoredResource&& other) = default; + + ~RestoredResource() override; + + private: + std::string device_; + TFConcreteFunction* create_resource_; + TFConcreteFunction* initialize_; + TFConcreteFunction* destroy_resource_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h new file mode 100644 index 00000000000..48d00308cc1 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ + +#include + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" + +namespace tensorflow { + +// All "Resources" should have these 3 saved functions: +// https://github.com/tensorflow/tensorflow/blob/86dc281333d7d277ddc1882f2bca4b17e7ec40e5/tensorflow/python/training/tracking/tracking.py#L277-L281 +struct RestoredResourceRevivalState { + std::string device; + TFConcreteFunctionRevivalState* create_resource = nullptr; + TFConcreteFunctionRevivalState* initialize = nullptr; + TFConcreteFunctionRevivalState* destroy_resource = nullptr; + ImmediateTensorHandlePtr resource_handle = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h b/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h new file mode 100644 index 00000000000..3b8fb0f6a8a --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ + +#include +#include + +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { + +// RevivedObjects is mainly used as a container for all the "state" owned by +// SavedModel. It stores all non-"user object" nodes from a SavedModel +// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62) +// in a "fully constructed" state. It is effectively a strongly typed map, where +// each member is a map from the node id in the SavedObjectGraph's nodes +// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29) +// to the revived object of the corresponding type. +struct RevivedObjects { + gtl::FlatMap> variables; + gtl::FlatMap> assets; + gtl::FlatMap> constants; + gtl::FlatMap> concrete_functions; + gtl::FlatMap> + signature_def_functions; + gtl::FlatMap restored_resources; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h new file mode 100644 index 00000000000..3dd7a6eecc4 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// TFConcreteFunctionRevivalState wraps the state needed for building a +// TF_ConcreteFunction. This is mainly used in PartiallyRevivedObjects, which +// wraps partially constructed Function and Resource objects. +struct TFConcreteFunctionRevivalState { + // Index of the node in the SavedObjectGraph it was loaded from. + int node_id; + + // Pointer to the original functiondef. fdef_ is guaranteed to be + // non-null. + const FunctionDef* fdef; + + // TensorHandle captures for this funtion + std::vector captures; + + // SavedConcreteFunction contains much of the metadata of the expected "types" + // of the inputs and outputs of a function. + // Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null. + const SavedConcreteFunction* saved_concrete_func; + + // This field is only present on TF2 ConcreteFunctions, and is useful for + // determining the original argument *names* of the function, (since the + // "canonicalized_input_signature" may append extra uniquifying integers). + // However, SavedBareConcreteFunctions do not have a FunctionSpec. + // Note(bmzhao): if function_spec_.has_value(), *function_spec_ is guaranteed + // to be non-null. + absl::optional function_spec; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h new file mode 100644 index 00000000000..ac1b20e474b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// FunctionBuilder wraps the state needed for building a SignatureDefFunction. +// This is mainly used in PartiallyRevivedObjects, which wraps partially +// constructed Function and Resource objects. +struct TFSignatureDefFunctionRevivalState { + // Index of the node in the SavedObjectGraph it was loaded from. + int node_id = 0; + + // Pointer to the original functiondef. fdef_ is guaranteed to be + // non-null. + const FunctionDef* fdef = nullptr; + + // SavedConcreteFunction contains much of the metadata of the expected "types" + // of the inputs and outputs of a function. + // Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null. + const SavedConcreteFunction* saved_concrete_func = nullptr; + + // The name of the SignatureDef key. + std::string signature_key; + + // TensorHandle captures for this funtion + std::vector captures; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 9ae7778e0cf..ef75f1a382b 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -17,15 +17,22 @@ limitations under the License. #include #include +#include +#include #include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/c/tf_tensor_internal.h" +#include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" @@ -41,6 +48,83 @@ namespace { using StructuredValueDictEntry = protobuf::MapPair; +// Maps from a Nodedef's name to its corresponding AttrValues, for a given +// Graphdef +using NodeAttrMap = + gtl::FlatMap; + +// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary +using FunctionDefMap = gtl::FlatMap; + +// Looks up a SavedConstant's associated tensorproto from the NodeAttrMap and +// returns a tensorflow::Constant. +Status ConstantFromSavedConstant( + ImmediateExecutionContext* ctx, + const tensorflow::SavedConstant& saved_constant, + const NodeAttrMap& node_attr_map, std::unique_ptr* output) { + const std::string& const_op_name = saved_constant.operation(); + const auto& node_name_and_attrs = node_attr_map.find(const_op_name); + if (node_name_and_attrs == node_attr_map.end()) { + return errors::FailedPrecondition( + "Unable to find Const operation with name'", const_op_name, + "' in SavedModel graphdef"); + } + const AttrValueMap* attrs = node_name_and_attrs->second; + const auto& attr_name_and_value = attrs->find("value"); + if (attr_name_and_value == attrs->end()) { + return errors::FailedPrecondition("Unable to find Const operation '", + const_op_name, "'s value attribute"); + } + const TensorProto& tensor_proto = attr_name_and_value->second.tensor(); + return internal::TensorProtoToConstant(ctx, tensor_proto, output); +} + +// Finds the "signatures" object in the object graph, and fills a mapping of +// each signature's name to the corresponding function's node in the object +// graph. +Status GetSignaturesMap(const SavedObjectGraph& saved_objects, + gtl::FlatMap* signatures_map) { + if (saved_objects.nodes().empty()) { + return errors::FailedPrecondition("Saved Object Graph was empty."); + } + const SavedObject& root = saved_objects.nodes(0); + const SavedObject* signatures = nullptr; + for (const auto& child : root.children()) { + if (child.local_name() == "signatures") { + if (child.node_id() >= saved_objects.nodes().size()) { + return errors::FailedPrecondition( + "Signature object had child node id ", child.node_id(), + " which exceeds the size of the set of nodes"); + } + signatures = &saved_objects.nodes(child.node_id()); + } + } + + // Some basic sanity checks that this object is actually our "signatures" map + if (signatures == nullptr) { + // This is where the "signatures" attribute is always set: + // https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109 + return errors::FailedPrecondition( + "SavedObjectGraph's root object must have a child 'signatures' object"); + } + if (signatures->kind_case() != SavedObject::kUserObject) { + return errors::FailedPrecondition( + "Signatures must be a SavedObject of type UserObject."); + } + if (signatures->user_object().identifier() != "signature_map") { + // This is where the string comes from: + // https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220 + return errors::FailedPrecondition( + "Signatures SavedObject must have identifier 'signature_map'."); + } + + for (const auto& child : signatures->children()) { + (*signatures_map)[child.local_name()] = child.node_id(); + } + return Status(); +} + // Perform some basic sanity checks on SavedConcreteFunction's input and // output signatures with respect to the corresponding FunctionDef's input // and output args. @@ -99,6 +183,21 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef( return Status(); } +Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) { + // We only allow loading functions that have an annotated input signature, + // which means there is 1:1 correspondence between tf.function + // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is + // the same restriction that MLIR has: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707 + if (saved_function.concrete_functions_size() != 1) { + return errors::FailedPrecondition( + "Only tf.functions annotated with an input signature are supported " + "by SavedModelAPI. This means that there should only be a single " + "ConcreteFunction per tf.function"); + } + return Status(); +} + } // namespace Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset, @@ -255,21 +354,18 @@ absl::optional FindNodeAtPath(StringPiece path, return node_id; } -std::unordered_map -NodeToAttrMap(const tensorflow::GraphDef& graphdef) { - std::unordered_map - result; +gtl::FlatMap NodeToAttrMap( + const tensorflow::GraphDef& graphdef) { + gtl::FlatMap result; for (const tensorflow::NodeDef& node : graphdef.node()) { result[node.name()] = &node.attr(); } return result; } -std::unordered_map +gtl::FlatMap FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) { - std::unordered_map + gtl::FlatMap result; for (const FunctionDef& function_def : library.function()) { result[function_def.signature().name()] = &function_def; @@ -277,5 +373,154 @@ FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) { return result; } +Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph, + ImmediateExecutionContext* context, + const std::string& directory, + PartiallyRevivedObjects* objects) { + // This is needed to restore "Constant" nodes by looking up their + // "Value" attribute. + NodeAttrMap node_attr_map = NodeToAttrMap(metagraph.graph_def()); + + // These are needed for creating "Assets", by looking up their filenames. + std::vector assets; + TF_RETURN_IF_ERROR(GetAssetFileDefs(metagraph, &assets)); + + // Signatures are needed for determining whether a function is a + // SignatureDefFunction or not. + gtl::FlatMap signatures_map; + TF_RETURN_IF_ERROR( + GetSignaturesMap(metagraph.object_graph_def(), &signatures_map)); + + gtl::FlatMap reversed_signatures_map; + reversed_signatures_map.reserve(signatures_map.size()); + for (const auto& signature_key_and_node : signatures_map) { + reversed_signatures_map.emplace(signature_key_and_node.second, + signature_key_and_node.first); + } + + // FunctionDefs are needed to help construct + // TFConcreteFunction/SignatureDefFunctions + const FunctionDefMap function_def_map = + internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library()); + + // Iterate through all the saved objects, restoring objects (if we can) as we + // go. For objects that dependencies on other objects (resources/functions), + // we partially initialize "builders" that correspond to their currently known + // state, and gradually fill them out in subsequent passes. + for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) { + const SavedObject& node = metagraph.object_graph_def().nodes(i); + if (node.kind_case() == SavedObject::kVariable) { + std::unique_ptr variable; + TF_RETURN_IF_ERROR( + LoadSavedVariable(context, node.variable(), &variable)); + objects->variables[i] = std::move(variable); + } else if (node.kind_case() == SavedObject::kConstant) { + std::unique_ptr constant; + TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(), + node_attr_map, &constant)); + objects->constants[i] = std::move(constant); + } else if (node.kind_case() == SavedObject::kAsset) { + std::unique_ptr asset; + TF_RETURN_IF_ERROR( + LoadSavedAsset(context, node.asset(), directory, assets, &asset)); + objects->assets[i] = std::move(asset); + } else if (node.kind_case() == SavedObject::kResource) { + RestoredResourceRevivalState resource_revival_state; + // We'll set the resource's functions in a subsequent pass, once we get + // all functions in a partially revived state. + resource_revival_state.device = node.resource().device(); + objects->restored_resources[i] = std::move(resource_revival_state); + } else if (node.kind_case() == SavedObject::kFunction) { + // Get the SavedFunction node and validate it has a single concrete func. + const SavedFunction& saved_function = node.function(); + TF_RETURN_IF_ERROR(ValidateSingleConcreteFunction(saved_function)); + + // Retrieve related function information. + const std::string& function_name = saved_function.concrete_functions(0); + const FunctionDef* function_def = function_def_map.at(function_name); + const SavedConcreteFunction& saved_concrete_func = + metagraph.object_graph_def().concrete_functions().at(function_name); + const FunctionSpec& function_spec = saved_function.function_spec(); + + // Construct either a SignatureDefFunctionBuilder or a + // ConcreteFunctionBuilder, depending on whether this node was a child + // of the "signatures" attribute from root object. + auto reverse_signature_iter = reversed_signatures_map.find(i); + if (reverse_signature_iter != reversed_signatures_map.end()) { + TFSignatureDefFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + func_revival_state.signature_key = reverse_signature_iter->second; + objects->signature_def_functions[i] = std::move(func_revival_state); + } else { + TFConcreteFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + func_revival_state.function_spec = &function_spec; + objects->concrete_functions[i] = std::move(func_revival_state); + } + } else if (node.kind_case() == SavedObject::kBareConcreteFunction) { + const SavedBareConcreteFunction& bare_cf = node.bare_concrete_function(); + + // Retrieve related function information. + const std::string& function_name = bare_cf.concrete_function_name(); + const FunctionDef* function_def = function_def_map.at(function_name); + const SavedConcreteFunction& saved_concrete_func = + metagraph.object_graph_def().concrete_functions().at(function_name); + + // Check whether this is a SignatureDefFunction, or not. + auto reverse_signature_iter = reversed_signatures_map.find(i); + if (reverse_signature_iter != reversed_signatures_map.end()) { + TFSignatureDefFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + func_revival_state.signature_key = reverse_signature_iter->second; + objects->signature_def_functions[i] = std::move(func_revival_state); + } else { + TFConcreteFunctionRevivalState func_revival_state; + func_revival_state.node_id = i; + func_revival_state.fdef = function_def; + func_revival_state.saved_concrete_func = &saved_concrete_func; + objects->concrete_functions[i] = std::move(func_revival_state); + } + } + } + + // Now that we've partially restored all functions, we can have resources + // point to them + for (auto& node_and_resource_revival_state : objects->restored_resources) { + int node_id = node_and_resource_revival_state.first; + const SavedObjectGraph& obj_graph = metagraph.object_graph_def(); + const SavedObject& node = obj_graph.nodes(node_id); + RestoredResourceRevivalState& resource = + node_and_resource_revival_state.second; + for (const TrackableObjectGraph::TrackableObject::ObjectReference& child : + node.children()) { + int child_node_id = child.node_id(); + // Note(bmzhao): The expected functions saved by a resource object are: + // "_create_resource", "_initialize", and "_destroy_resource". + // https://github.com/tensorflow/tensorflow/blob/ad66f588c1666ade8051feb42811fa27b285271c/tensorflow/python/training/tracking/tracking.py#L277-L281 + if (child.local_name() == "_create_resource" && + obj_graph.nodes(child.node_id()).kind_case() == + SavedObject::kFunction) { + resource.create_resource = &objects->concrete_functions[child_node_id]; + } else if (child.local_name() == "_initialize" && + obj_graph.nodes(child.node_id()).kind_case() == + SavedObject::kFunction) { + resource.initialize = &objects->concrete_functions[child_node_id]; + } else if (child.local_name() == "_destroy_resource" && + obj_graph.nodes(child.node_id()).kind_case() == + SavedObject::kFunction) { + resource.destroy_resource = &objects->concrete_functions[child_node_id]; + } + } + } + + return Status(); +} + } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index 073d7072e9f..db45e28087f 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -27,10 +27,12 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" @@ -84,15 +86,22 @@ absl::optional FindNodeAtPath(StringPiece path, // Maps each node in `graphdef` to its corresponding Attribute Map. // Callers must ensure that `graphdef` outlives the returned map. -std::unordered_map -NodeToAttrMap(const tensorflow::GraphDef& graphdef); +gtl::FlatMap NodeToAttrMap( + const tensorflow::GraphDef& graphdef); // Maps the name of each FunctionDef in `library` to its corresponding // FunctionDef. Callers must ensure `library` outlives the returned map. -std::unordered_map +gtl::FlatMap FunctionNameToFunctionDefMap(const FunctionDefLibrary& library); +// Walks through the SavedObjectGraph in metagraph, and restores all nodes +// (except "UserDefinedObjects") with their corresponding type in +// "PartiallyRevivedObjects". +Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph, + ImmediateExecutionContext* context, + const std::string& directory, + PartiallyRevivedObjects* objects); + } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index 0662482e538..6386d0dbb79 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -30,6 +29,9 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" @@ -37,7 +39,6 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/constants.h" -#include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" @@ -46,6 +47,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/errors.h" @@ -62,142 +64,15 @@ limitations under the License. namespace tensorflow { // Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary -using FunctionDefMap = - std::unordered_map; - -// Maps from a Nodedef's name to its corresponding AttrValues, for a given -// Graphdef -using NodeAttrMap = - std::unordered_map; - -// Maps from Node ID to an "Revived Object" implementing -// "TensorHandleConvertible" -using RevivedObjectMap = - std::unordered_map>; +using FunctionDefMap = gtl::FlatMap; // Maps from a functiondef's name to the corresponding "TFConcreteFunction" -using ConcreteFunctionMap = - std::unordered_map>; +using FlatTensorFunctionMap = + gtl::FlatMap>; namespace { -Status ConstantFromSavedConstant( - ImmediateExecutionContext* ctx, - const tensorflow::SavedConstant& saved_constant, - const NodeAttrMap& node_attr_map, std::unique_ptr* output) { - const std::string& const_op_name = saved_constant.operation(); - const auto& node_name_and_attrs = node_attr_map.find(const_op_name); - if (node_name_and_attrs == node_attr_map.end()) { - return errors::FailedPrecondition( - "Unable to find Const operation with name'", const_op_name, - "' in SavedModel graphdef"); - } - const AttrValueMap* attrs = node_name_and_attrs->second; - const auto& attr_name_and_value = attrs->find("value"); - if (attr_name_and_value == attrs->end()) { - return errors::FailedPrecondition("Unable to find Const operation '", - const_op_name, "'s value attribute"); - } - const TensorProto& tensor_proto = attr_name_and_value->second.tensor(); - return internal::TensorProtoToConstant(ctx, tensor_proto, output); -} - -// Restores all non-function objects in the SavedModel's object graph. -// This function walks through the metagraph's saved object graph, and -// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and -// SavedResources. These are returned via the `out` parameter. -Status ReviveObjects( - const MetaGraphDef& metagraph, ImmediateExecutionContext* context, - const std::string& directory, - std::unordered_map>* - revived_objects) { - // This is needed to restore "Constant" nodes by looking up their - // "Value" attribute. - NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def()); - - // These are needed for creating "Assets", by looking up their filenames. - std::vector assets; - TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(metagraph, &assets)); - - // Iterate through all the saved objects, restoring objects as we go. - // We don't recreate functions until all other objects have been created. - for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) { - const SavedObject& node = metagraph.object_graph_def().nodes(i); - if (node.kind_case() == SavedObject::kVariable) { - std::unique_ptr variable; - TF_RETURN_IF_ERROR( - internal::LoadSavedVariable(context, node.variable(), &variable)); - (*revived_objects)[i] = std::move(variable); - } else if (node.kind_case() == SavedObject::kConstant) { - std::unique_ptr constant; - TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(), - node_attr_map, &constant)); - (*revived_objects)[i] = std::move(constant); - } else if (node.kind_case() == SavedObject::kAsset) { - std::unique_ptr asset; - TF_RETURN_IF_ERROR(internal::LoadSavedAsset(context, node.asset(), - directory, assets, &asset)); - (*revived_objects)[i] = std::move(asset); - } else if (node.kind_case() == SavedObject::kResource) { - // TODO(bmzhao): Figure out how resource loading works and implement it - return errors::Unimplemented( - "SavedResource loading is not implemented yet"); - } - } - return Status(); -} - -Status ReviveFunctions(const MetaGraphDef& metagraph, - const RevivedObjectMap& revived_objects, - ImmediateExecutionContext* context, - ConcreteFunctionMap* restored_functions) { - const FunctionDefMap function_def_map = - internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library()); - - // Iterate through all objects, only examining functions. - for (const SavedObject& node : metagraph.object_graph_def().nodes()) { - if (node.kind_case() == SavedObject::kBareConcreteFunction) { - const std::string& function_name = - node.bare_concrete_function().concrete_function_name(); - - const SavedConcreteFunction& saved_concrete_function = - metagraph.object_graph_def().concrete_functions().at(function_name); - - const FunctionDef* function_def = function_def_map.at(function_name); - std::unique_ptr concrete_function; - TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction( - saved_concrete_function, function_def, revived_objects, context, - &concrete_function)); - (*restored_functions)[function_name] = std::move(concrete_function); - } else if (node.kind_case() == SavedObject::kFunction) { - // We only allow loading functions that have an annotated input signature, - // which means there is 1:1 correspondence between tf.function - // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is - // the same restriction that MLIR has: - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707 - const SavedFunction& saved_function = node.function(); - if (saved_function.concrete_functions_size() != 1) { - return errors::FailedPrecondition( - "Only tf.functions annotated with an input signature are supported " - "by SavedModelAPI. This means that there should only be a single " - "ConcreteFunction per tf.function"); - } - const std::string& function_name = saved_function.concrete_functions(0); - const SavedConcreteFunction& saved_concrete_function = - metagraph.object_graph_def().concrete_functions().at(function_name); - - const FunctionDef* function_def = function_def_map.at(function_name); - - std::unique_ptr concrete_function; - TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction( - saved_concrete_function, function_def, revived_objects, context, - &concrete_function)); - (*restored_functions)[function_name] = std::move(concrete_function); - } - } - return Status(); -} const TrackableObjectGraph::TrackableObject::SerializedTensor* FindSerializedTensorInTrackable( @@ -234,7 +109,7 @@ FindSerializedTensorInTrackable( // overridden "restore" method: // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85 Status RestoreCheckpoint(SavedModelV2Bundle* bundle, - const RevivedObjectMap& revived_objects, + const RevivedObjects& revived_objects, const std::string& directory, ImmediateExecutionContext* context) { // TODO(bmzhao): Batch up all the restores into a single restore op per @@ -254,8 +129,7 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle, return Status::OK(); } - Variable* variable = - down_cast(revived_objects.at(node).get()); + Variable* variable = revived_objects.variables.at(node).get(); // Restore the tensor's value from the checkpoint const TrackableObjectGraph::TrackableObject::SerializedTensor* @@ -289,6 +163,14 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle, return Status(); } +Status InitializeAllResources(const RevivedObjects& revived) { + for (const auto& node_and_resource : revived.restored_resources) { + const RestoredResource& resource = node_and_resource.second; + TF_RETURN_IF_ERROR(resource.Initialize()); + } + return Status(); +} + } // namespace Status TFSavedModelAPI::GetFunction(const std::string& function_path, @@ -299,20 +181,12 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path, return errors::NotFound("No saved object found at path ", function_path); } - const SavedObject& object = bundle_.saved_object_graph().nodes(*node); - if (object.kind_case() == SavedObject::kBareConcreteFunction) { - *function = - concrete_functions_ - .at(object.bare_concrete_function().concrete_function_name()) - .get(); - } else if (object.kind_case() == SavedObject::kFunction) { - *function = - concrete_functions_.at(object.function().concrete_functions(0)).get(); - } else { - return errors::InvalidArgument(function_path, - " is not a path to a Function."); + auto function_iter = revived_objects_.concrete_functions.find(*node); + if (function_iter == revived_objects_.concrete_functions.end()) { + return errors::NotFound("No function found at path ", function_path); } + *function = function_iter->second.get(); return Status(); } @@ -325,8 +199,8 @@ Status TFSavedModelAPI::GetSignatureDefFunction( std::vector TFSavedModelAPI::ListFunctions() { std::vector result; - result.reserve(concrete_functions_.size()); - for (auto& index_and_function : concrete_functions_) { + result.reserve(revived_objects_.concrete_functions.size()); + for (auto& index_and_function : revived_objects_.concrete_functions) { result.push_back(index_and_function.second.get()); } return result; @@ -340,34 +214,21 @@ Status TFSavedModelAPI::GetVariable(const std::string& variable_path, return errors::NotFound("No saved object found at path ", variable_path); } - const SavedObject& object = bundle_.saved_object_graph().nodes(*node); - - if (object.kind_case() == SavedObject::kVariable) { - auto iter = revived_objects_.find(*node); - if (iter == revived_objects_.end()) { - return errors::Internal("Variable ", variable_path, - " was not properly revived."); - } - *variable = static_cast(iter->second.get()); - return Status(); + auto variables_iter = revived_objects_.variables.find(*node); + if (variables_iter == revived_objects_.variables.end()) { + return errors::NotFound("No variable found at path ", variable_path); } - *variable = nullptr; - return errors::InvalidArgument( - variable_path, " is not a path to a Variable (kind=", object.kind_case(), - ")"); + *variable = variables_iter->second.get(); + return Status(); } -TFSavedModelAPI::TFSavedModelAPI( - const std::string& directory, SavedModelV2Bundle bundle, - std::unordered_map> - revived_objects, - std::unordered_map> - concrete_functions) +TFSavedModelAPI::TFSavedModelAPI(const std::string& directory, + SavedModelV2Bundle bundle, + RevivedObjects revived_objects) : directory_(directory), bundle_(std::move(bundle)), - revived_objects_(std::move(revived_objects)), - concrete_functions_(std::move(concrete_functions)) {} + revived_objects_(std::move(revived_objects)) {} Status TFSavedModelAPI::Load( const std::string& directory, @@ -388,28 +249,25 @@ Status TFSavedModelAPI::Load( // This occurs in python here: // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454 - RevivedObjectMap revived_objects; - TF_RETURN_IF_ERROR(ReviveObjects(bundle.meta_graph_def(), context, directory, - &revived_objects)); + // Step 1: For each node in the graph, we should initialize an object of the + // corresponding type. For objects that depend on the initialization of other + // objects (like functions which capture resources), we will initialize them + // in step 2. + PartiallyRevivedObjects partially_revived_objects; + TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects( + bundle.meta_graph_def(), context, directory, &partially_revived_objects)); - // TODO(bmzhao): When we later add support for loading resources, we need to - // handle the case where materializing a function's captures requires invoking - // other functions. This occurs when retrieving the resource handle for a - // TrackableResource: - // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240 - // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233 - // This requires restoring functions in a topological sort order by capture - // dependencies. - ConcreteFunctionMap function_map; - TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects, - context, &function_map)); + RevivedObjects revived_objects; + TF_RETURN_IF_ERROR(partially_revived_objects.Build( + context, bundle.saved_object_graph(), &revived_objects)); TF_RETURN_IF_ERROR( RestoreCheckpoint(&bundle, revived_objects, directory, context)); + TF_RETURN_IF_ERROR(InitializeAllResources(revived_objects)); + out->reset(new TFSavedModelAPI(directory, std::move(bundle), - std::move(revived_objects), - std::move(function_map))); + std::move(revived_objects))); return Status(); } diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h index d108b4071e9..bc39a974ad2 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" @@ -72,19 +73,12 @@ class TFSavedModelAPI : public SavedModelAPI { Status GetVariable(const std::string& variable_path, Variable** variable); private: - TFSavedModelAPI( - const std::string& directory, SavedModelV2Bundle bundle, - std::unordered_map> - revived_objects, - std::unordered_map> - concrete_functions); + TFSavedModelAPI(const std::string& directory, SavedModelV2Bundle bundle, + RevivedObjects revived_objects); std::string directory_; SavedModelV2Bundle bundle_; - std::unordered_map> - revived_objects_; - std::unordered_map> - concrete_functions_; + RevivedObjects revived_objects_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index a55f232795b..ec5d65ea60d 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/tstring.h" @@ -196,6 +197,142 @@ TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) { TFE_DeleteContext(ctx); } +TEST_P(CSavedModelAPITest, LoadsStaticHashtableSavedModel) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = SavedModelPath("StaticHashTableModule"); + + TF_SavedModel* saved_model = + TF_LoadSavedModel(model_dir.c_str(), ctx, status); + + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TF_ConcreteFunction* lookup_fn = + TF_GetSavedModelConcreteFunction(saved_model, "lookup", status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // Note(bmzhao): Based on static_hashtable_asset.txt, we expect the following + // mapping: + // "foo" -> 0 + // "bar" -> 1 + // "baz" -> 2 + // "wombat" -> 3 + // all other strings -> -1 + + // Call lookup function with input "foo", expecting an output of 0 + { + std::vector lookup_fn_inputs; + TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("foo")); + lookup_fn_inputs.push_back(input_foo); + + TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp( + lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + tensorflow::int64* output_value = + static_cast(TF_TensorData(result)); + EXPECT_EQ(*output_value, 0); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(input_foo); + TFE_DeleteTensorHandle(lookup_fn_outputs[0]); + TFE_DeleteOp(lookup_op); + } + + // Call lookup function with input "baz", expecting an output of 2 + { + std::vector lookup_fn_inputs; + TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("baz")); + lookup_fn_inputs.push_back(input_foo); + + TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp( + lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + tensorflow::int64* output_value = + static_cast(TF_TensorData(result)); + EXPECT_EQ(*output_value, 2); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(input_foo); + TFE_DeleteTensorHandle(lookup_fn_outputs[0]); + TFE_DeleteOp(lookup_op); + } + + // Call lookup function w/input "NON-EXISTENT-KEY", expecting an output of -1 + { + std::vector lookup_fn_inputs; + TFE_TensorHandle* input_foo = + TestScalarTensorHandle(ctx, tstring("NON-EXISTENT-KEY")); + lookup_fn_inputs.push_back(input_foo); + + TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp( + lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + tensorflow::int64* output_value = + static_cast(TF_TensorData(result)); + EXPECT_EQ(*output_value, -1); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(input_foo); + TFE_DeleteTensorHandle(lookup_fn_outputs[0]); + TFE_DeleteOp(lookup_op); + } + + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 4efe56ff822..243f86ed787 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -213,6 +213,7 @@ py_binary( srcs = ["testdata/generate_saved_models.py"], data = [ ":saved_model_asset_data", + ":saved_model_static_hashtable_asset_data", ], python_version = "PY3", srcs_version = "PY3", @@ -220,6 +221,7 @@ py_binary( "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", "//tensorflow/python:tensor_spec", "//tensorflow/python:variables", "//tensorflow/python/compat:v2_compat", @@ -243,6 +245,7 @@ filegroup( "testdata/half_plus_two_v2/**", "testdata/x_plus_y_v2_debuginfo/**", "testdata/CyclicModule/**", + "testdata/StaticHashTableModule/**", "testdata/VarsAndArithmeticObjectGraph/**", "testdata/fuzz_generated/**", ]), @@ -260,6 +263,13 @@ filegroup( ], ) +filegroup( + name = "saved_model_static_hashtable_asset_data", + srcs = [ + "testdata/static_hashtable_asset.txt", + ], +) + exports_files( glob([ "testdata/half_plus_two_pbtxt/**", diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt new file mode 100644 index 00000000000..e79f591665f --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt @@ -0,0 +1,4 @@ +foo +bar +baz +wombat diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb new file mode 100644 index 00000000000..04e8ba62bdb Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..f6d62d9a51c Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index new file mode 100644 index 00000000000..df6c85e5783 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index differ diff --git a/tensorflow/cc/saved_model/testdata/generate_saved_models.py b/tensorflow/cc/saved_model/testdata/generate_saved_models.py index 91c09d33bb0..2b64cf52096 100644 --- a/tensorflow/cc/saved_model/testdata/generate_saved_models.py +++ b/tensorflow/cc/saved_model/testdata/generate_saved_models.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.module import module from tensorflow.python.ops import io_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.saved_model import save_options @@ -82,10 +83,31 @@ class AssetModule(module.Module): return io_ops.read_file(self.asset) +class StaticHashTableModule(module.Module): + """A module with an Asset, StaticHashTable, and a lookup function.""" + + def __init__(self): + self.asset = tracking.Asset( + test.test_src_dir_path( + "cc/saved_model/testdata/static_hashtable_asset.txt")) + self.table = lookup_ops.StaticHashTable( + lookup_ops.TextFileInitializer(self.asset, dtypes.string, + lookup_ops.TextFileIndex.WHOLE_LINE, + dtypes.int64, + lookup_ops.TextFileIndex.LINE_NUMBER), + -1) + + @def_function.function( + input_signature=[tensor_spec.TensorSpec(shape=None, dtype=dtypes.string)]) + def lookup(self, word): + return self.table.lookup(word) + + MODULE_CTORS = { "VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph, "CyclicModule": CyclicModule, "AssetModule": AssetModule, + "StaticHashTableModule": StaticHashTableModule, } diff --git a/tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt b/tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt new file mode 100644 index 00000000000..e79f591665f --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt @@ -0,0 +1,4 @@ +foo +bar +baz +wombat