Refactoring the guts of SavedModel loading to experimentally support loading resources that do not override _gather_saveables_for_checkpoint. This allows us to support simple resources like StaticHashTable. Note that additional logic related to checkpoints will need to be added to support df6b21c13c.

PiperOrigin-RevId: 333221978
Change-Id: Ib724b6cee9a57bf1b3ee98d2a8eaf9f394a8d64d
This commit is contained in:
Brian Zhao 2020-09-22 21:45:45 -07:00 committed by TensorFlower Gardener
parent bb24a4b88f
commit e143ada4a6
22 changed files with 1418 additions and 211 deletions

View File

@ -66,8 +66,13 @@ cc_library(
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/experimental/saved_model/core/revived_types:asset",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects",
"//tensorflow/c/experimental/saved_model/core/revived_types:restored_resource_revival_state",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function_revival_state",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_signature_def_function_revival_state",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
"//tensorflow/cc/saved_model:loader_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -147,12 +152,14 @@ cc_library(
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:flat_tensor_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects",
"//tensorflow/c/experimental/saved_model/core/revived_types:revived_objects",
"//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
"//tensorflow/cc/saved_model:bundle_v2",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/cc/saved_model:loader_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -69,6 +69,84 @@ cc_library(
],
)
cc_library(
name = "partially_revived_objects",
srcs = [
"partially_revived_objects.cc",
],
hdrs = [
"partially_revived_objects.h",
],
deps = [
":asset",
":constant",
":restored_resource",
":restored_resource_revival_state",
":revived_objects",
":tf_concrete_function",
":tf_concrete_function_revival_state",
":tf_signature_def_function",
":tf_signature_def_function_revival_state",
":variable",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "restored_resource",
srcs = [
"restored_resource.cc",
],
hdrs = [
"restored_resource.h",
],
deps = [
":tensorhandle_convertible",
":tf_concrete_function",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "restored_resource_revival_state",
hdrs = [
"restored_resource_revival_state.h",
],
deps = [
":tf_concrete_function_revival_state",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)
cc_library(
name = "revived_objects",
hdrs = [
"revived_objects.h",
],
deps = [
":asset",
":constant",
":restored_resource",
":tf_concrete_function",
":tf_signature_def_function",
":variable",
"//tensorflow/core:lib",
],
)
cc_library(
name = "variable",
srcs = [
@ -123,6 +201,21 @@ cc_library(
],
)
cc_library(
name = "tf_concrete_function_revival_state",
hdrs = [
"tf_concrete_function_revival_state.h",
],
deps = [
":tf_concrete_function",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "tf_signature_def_function",
srcs = [
@ -145,3 +238,17 @@ cc_library(
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tf_signature_def_function_revival_state",
hdrs = [
"tf_signature_def_function_revival_state.h",
],
deps = [
":tf_signature_def_function",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -0,0 +1,388 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
#include <memory>
#include <utility>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace {
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
const PartiallyRevivedObjects& objects) {
for (const auto& id_and_resource : objects.restored_resources) {
int node_id = id_and_resource.first;
const RestoredResourceRevivalState& resource = id_and_resource.second;
const TFConcreteFunctionRevivalState* create_resource_fn =
resource.create_resource;
if (create_resource_fn == nullptr) {
return errors::FailedPrecondition(
"Resource at node ", node_id,
" did not have a create_resource() function");
}
const SavedConcreteFunction* saved_create_resource_fn =
create_resource_fn->saved_concrete_func;
if (!saved_create_resource_fn->bound_inputs().empty()) {
// TODO(b/124045874): Support loading resource functions via a top sort
return errors::Unimplemented(
"Create Resource functions with captures are currently unsupported.");
}
}
return Status();
}
// Retrieves the TensorHandle associated with `node_id` from `obj_graph`, and
// set `*handle` to point to it.
Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
const PartiallyRevivedObjects& objects,
ImmediateExecutionTensorHandle** handle) {
const SavedObject& node = obj_graph.nodes(node_id);
SavedObject::KindCase kind = node.kind_case();
switch (kind) {
case SavedObject::kVariable: {
const auto& variables_iter = objects.variables.find(node_id);
if (variables_iter == objects.variables.end()) {
return errors::FailedPrecondition(
"Tried to convert node id ", node_id,
" of type variable to tensor but the variable wasn't initialized");
}
*handle = variables_iter->second->handle();
return Status();
}
case SavedObject::kConstant: {
const auto& constants_iter = objects.constants.find(node_id);
if (constants_iter == objects.constants.end()) {
return errors::FailedPrecondition("Tried to convert node id ", node_id,
" of type constant to tensor but the "
"constant wasn't initialized");
}
*handle = constants_iter->second->handle();
return Status();
}
case SavedObject::kAsset: {
const auto& assets_iter = objects.assets.find(node_id);
if (assets_iter == objects.assets.end()) {
return errors::FailedPrecondition(
"Tried to convert node id ", node_id,
" of type asset to tensor but the asset wasn't initialized");
}
*handle = assets_iter->second->handle();
return Status();
}
case SavedObject::kResource: {
const auto& resource_iter = objects.restored_resources.find(node_id);
if (resource_iter == objects.restored_resources.end()) {
return errors::FailedPrecondition(
"Tried to convert node id ", node_id,
" of type Resource to tensor but the Resource wasn't initialized");
}
const RestoredResourceRevivalState& resource = resource_iter->second;
if (resource.resource_handle == nullptr) {
return errors::FailedPrecondition(
"Resource with node id ", node_id,
" should have its resource_handle created, but was nullptr.");
}
*handle = resource.resource_handle.get();
return Status();
}
default: {
return errors::FailedPrecondition(
"Only objects of type variable, constant, asset, and resources have "
"capturable tensorhandles. Encountered object of kind ",
node.kind_case(), " at node id: ", node_id);
}
}
}
// This function finds the necessary captures, then forwards to the builder
// method
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
const TFConcreteFunctionRevivalState& builder,
const SavedObjectGraph& obj_graph,
const PartiallyRevivedObjects& objects,
std::unique_ptr<TFConcreteFunction>* out) {
const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs();
std::vector<ImmediateExecutionTensorHandle*> captures;
captures.reserve(capture_node_ids.size());
for (int capture_node_id : capture_node_ids) {
ImmediateExecutionTensorHandle* capture_handle;
TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects,
&capture_handle));
captures.push_back(capture_handle);
}
// TODO(bmzhao): Create Metadata here
return TFConcreteFunction::Create(/*function_def=*/builder.fdef,
/*captures=*/std::move(captures),
/*metadata=*/{},
/*ctx=*/ctx,
/*out=*/out);
}
Status CreateSignatureDefFunction(
ImmediateExecutionContext* ctx,
const TFSignatureDefFunctionRevivalState& builder,
const SavedObjectGraph& obj_graph, const PartiallyRevivedObjects& objects,
std::unique_ptr<TFSignatureDefFunction>* out) {
const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs();
std::vector<ImmediateExecutionTensorHandle*> captures;
captures.reserve(capture_node_ids.size());
for (int capture_node_id : capture_node_ids) {
ImmediateExecutionTensorHandle* capture_handle;
TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects,
&capture_handle));
captures.push_back(capture_handle);
}
// TODO(bmzhao): Create Metadata here
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
/*captures=*/std::move(captures),
/*metadata=*/{},
/*ctx=*/ctx,
/*out=*/out);
}
Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx,
const SavedObjectGraph& obj_graph,
const PartiallyRevivedObjects& objects,
RevivedObjects* revived) {
for (const auto& id_and_resource : objects.restored_resources) {
const RestoredResourceRevivalState& resource = id_and_resource.second;
const TFConcreteFunctionRevivalState* create_resource_fn =
resource.create_resource;
const SavedConcreteFunction* saved_create_resource_fn =
create_resource_fn->saved_concrete_func;
if (!saved_create_resource_fn->bound_inputs().empty()) {
// TODO(b/124045874): Load resource functions via a topological sort
return errors::Unimplemented(
"Create Resource functions with captures are currently unsupported.");
}
std::unique_ptr<TFConcreteFunction> out;
TF_RETURN_IF_ERROR(CreateConcreteFunction(ctx, *create_resource_fn,
obj_graph, objects, &out));
revived->concrete_functions[create_resource_fn->node_id] = std::move(out);
}
return Status();
}
Status InitializeAllFunctions(ImmediateExecutionContext* ctx,
const SavedObjectGraph& obj_graph,
const PartiallyRevivedObjects& objects,
RevivedObjects* revived) {
gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>* destination_func_map =
&revived->concrete_functions;
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>*
destination_sig_map = &revived->signature_def_functions;
for (const auto& id_and_func : objects.concrete_functions) {
int node_id = id_and_func.first;
const TFConcreteFunctionRevivalState& func = id_and_func.second;
if (destination_func_map->find(node_id) != destination_func_map->end()) {
// The function has already been initialized in the destination_map,
// so we can skip this node. This can occur because we initialize
// CreateResource functions before calling this function.
continue;
}
std::unique_ptr<TFConcreteFunction> out;
TF_RETURN_IF_ERROR(
CreateConcreteFunction(ctx, func, obj_graph, objects, &out));
(*destination_func_map)[node_id] = std::move(out);
}
for (const auto& id_and_func : objects.signature_def_functions) {
int node_id = id_and_func.first;
const TFSignatureDefFunctionRevivalState& func = id_and_func.second;
if (destination_sig_map->find(node_id) != destination_sig_map->end()) {
continue;
}
std::unique_ptr<TFSignatureDefFunction> out;
TF_RETURN_IF_ERROR(
CreateSignatureDefFunction(ctx, func, obj_graph, objects, &out));
(*destination_sig_map)[node_id] = std::move(out);
}
return Status();
}
Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
const SavedObjectGraph& obj_graph,
PartiallyRevivedObjects* objects,
RevivedObjects* revived) {
for (auto& id_and_resource : objects->restored_resources) {
RestoredResourceRevivalState& resource = id_and_resource.second;
int create_resource_fn_node = resource.create_resource->node_id;
const gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>&
revived_functions = revived->concrete_functions;
const auto& revived_functions_iter =
revived_functions.find(create_resource_fn_node);
if (revived_functions_iter == revived_functions.end()) {
return errors::FailedPrecondition(
"ConcreteFunction at node ", create_resource_fn_node,
" should have been initialized prior to being called.");
}
const TFConcreteFunction& create_resource_fn =
*revived_functions_iter->second;
ImmediateOpPtr function_op;
TF_RETURN_IF_ERROR(create_resource_fn.MakeCallOp({}, &function_op));
TF_RETURN_IF_ERROR(function_op->SetDeviceName(resource.device.c_str()));
AbstractTensorHandle* resource_handle = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(function_op->Execute(
absl::MakeSpan(&resource_handle, num_retvals), &num_retvals));
AbstractTensorHandlePtr owned_resource_handle(resource_handle);
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
owned_resource_handle.get())) {
return errors::Internal("Unexpected tensor handle kind.");
}
ImmediateTensorHandlePtr result(
reinterpret_cast<ImmediateExecutionTensorHandle*>(
owned_resource_handle.release()));
resource.resource_handle = std::move(result);
}
return Status();
}
// Finds a ConcreteFunction with node id `node` in `objects`, and sets *out to
// point to it. If node doesn't exist in `objects`, out is untouched, and an
// error status is returned.
Status FindConcreteFunction(int node, RevivedObjects* objects,
TFConcreteFunction** out) {
auto func_iter = objects->concrete_functions.find(node);
if (func_iter == objects->concrete_functions.end()) {
return errors::FailedPrecondition(
"Failed to find ConcreteFunction with node id ", node,
" in revived objects");
}
*out = func_iter->second.get();
return Status();
}
Status BuildResources(ImmediateExecutionContext* ctx,
const SavedObjectGraph& obj_graph,
PartiallyRevivedObjects* objects,
RevivedObjects* revived) {
for (auto& id_and_resource : objects->restored_resources) {
int node_id = id_and_resource.first;
RestoredResourceRevivalState& resource_revival_state =
id_and_resource.second;
TFConcreteFunction* create_resource = nullptr;
// Check all the functions associated with the resource have already been
// initialized in `revived`
if (resource_revival_state.create_resource != nullptr) {
TF_RETURN_IF_ERROR(
FindConcreteFunction(resource_revival_state.create_resource->node_id,
revived, &create_resource));
}
TFConcreteFunction* initialize = nullptr;
if (resource_revival_state.initialize != nullptr) {
TF_RETURN_IF_ERROR(FindConcreteFunction(
resource_revival_state.initialize->node_id, revived, &initialize));
}
TFConcreteFunction* destroy_resource = nullptr;
if (resource_revival_state.destroy_resource != nullptr) {
TF_RETURN_IF_ERROR(
FindConcreteFunction(resource_revival_state.destroy_resource->node_id,
revived, &destroy_resource));
}
if (resource_revival_state.resource_handle == nullptr) {
return errors::FailedPrecondition("Resource at node id ", node_id,
" does not have a resource handle.");
}
revived->restored_resources.emplace(
node_id, RestoredResource(
/*device=*/resource_revival_state.device,
/*create_resource=*/create_resource,
/*initialize=*/initialize,
/*destroy_resource=*/destroy_resource,
/*resource_handle=*/
std::move(resource_revival_state.resource_handle)));
}
return Status();
}
} // namespace
Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx,
const SavedObjectGraph& obj_graph,
RevivedObjects* revived) {
// Step 1: We would like to initialize all functions; this requires setting up
// their captured tensorhandles, which may come from variables, assets,
// constants, or resources. The first three are trivial; However,
// tensorhandles that correspond to resources must be created by invoking
// their "create_resource" function.
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
// For now, we assert that all create_resource functions must have no
// captures. This aligns with the current behavior in python.
// https://github.com/tensorflow/tensorflow/blob/50eac986bf7a0ad12594e080f083181f277e0b49/tensorflow/python/saved_model/load.py#L152-L155
// TODO(bmzhao): We should do a topological sort instead.
// 1a. Make sure all CreateResource functions have no captures.
TF_RETURN_IF_ERROR(AssertAllCreateResourceFunctionsHaveNoCaptures(*this));
// 1b. Initialize all CreateResource functions, storing them in `revived`
TF_RETURN_IF_ERROR(
InitializeCreateResourceFunctions(ctx, obj_graph, *this, revived));
// 1c. Invoke all "CreateResource" functions and store their ResourceHandles
// https://github.com/tensorflow/tensorflow/blob/3b6b41b68a95dc70c26dc816b29d359bfb88c116/tensorflow/python/training/tracking/tracking.py#L241-L247
// in *this->resources.
// TODO(bmzhao): Maybe store them separately, not in *this?
TF_RETURN_IF_ERROR(CreateAllResourceHandles(ctx, obj_graph, this, revived));
// 2. Initialize all the rest of the functions
TF_RETURN_IF_ERROR(InitializeAllFunctions(ctx, obj_graph, *this, revived));
// 3a. Move over all non-function, non-resource objects
revived->variables = std::move(variables);
revived->assets = std::move(assets);
revived->constants = std::move(constants);
// 3b. Move over resources.
TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
// Container for objects during the revival step in SavedModel's loading.
// Notably, resources and functions can be in a state where they reference
// other resources/functions that have not been constructed yet. We collect
// *all* objects in a partially valid state here, then properly initialize
// resources and functions.
struct PartiallyRevivedObjects {
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
gtl::FlatMap<int, std::unique_ptr<Constant>> constants;
gtl::FlatMap<int, TFConcreteFunctionRevivalState> concrete_functions;
gtl::FlatMap<int, TFSignatureDefFunctionRevivalState> signature_def_functions;
gtl::FlatMap<int, RestoredResourceRevivalState> restored_resources;
Status Build(ImmediateExecutionContext* ctx,
const SavedObjectGraph& obj_graph, RevivedObjects* revived);
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
Status ExecuteNoArgDummyReturnFunction(TFConcreteFunction* func) {
ImmediateOpPtr function_op;
TF_RETURN_IF_ERROR(func->MakeCallOp({}, &function_op));
AbstractTensorHandle* dummy_output = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(function_op->Execute(
absl::MakeSpan(&dummy_output, num_retvals), &num_retvals));
AbstractTensorHandlePtr owned_dummy_output(dummy_output);
return Status();
}
} // namespace
RestoredResource::RestoredResource(const std::string& device,
TFConcreteFunction* create_resource,
TFConcreteFunction* initialize,
TFConcreteFunction* destroy_resource,
ImmediateTensorHandlePtr resource_handle)
: TensorHandleConvertible(std::move(resource_handle)),
device_(device),
create_resource_(create_resource),
initialize_(initialize),
destroy_resource_(destroy_resource) {}
Status RestoredResource::Initialize() const {
return ExecuteNoArgDummyReturnFunction(initialize_);
}
RestoredResource::~RestoredResource() {
// Note(bmzhao): SavedModels saved before
// https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3
// did not have their destroy_resource function saved, meaning they will
// leak resources.
if (destroy_resource_ != nullptr) {
Status status = ExecuteNoArgDummyReturnFunction(destroy_resource_);
if (!status.ok()) {
LOG(WARNING)
<< "Failed executing destroy_resource function for RestoredResource: "
<< status.error_message();
}
}
}
} // namespace tensorflow

View File

@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_
#include <memory>
#include <string>
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
namespace tensorflow {
// RestoredResource represents a TF2 "Resource" object loaded from a savedmodel,
// analogous to the Python _RestoredResource object:
// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/saved_model/load.py#L481
// TF2 resource objects typically extend TrackableResource:
// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/training/tracking/tracking.py#L285
// and are expected to implement "_create_resource", "_initialize", and
// "_destroy_resource" functions:
// https://github.com/tensorflow/tensorflow/blob/139ba9c5284799beafdd1d7f895127cf00e7c48f/tensorflow/python/training/tracking/tracking.py#L262-L281
class RestoredResource : TensorHandleConvertible {
public:
// Note(bmzhao): RestoredResource stores non-owning pointers to its associated
// functions because SavedModel internally owns all functions and objects in
// the RevivedObjects struct (which owns all functions). One alternative would
// be to have RevivedObjects store shared_ptr<TFConcreteFunction> instead, and
// change RestoredResource's constructor take shared_ptr<TFConcreteFunction>.
// To keep things simple, I've stuck to raw pointers for now.
//
// Params:
// device - The device string associated with the SavedResource
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_object_graph.proto#L182
// Conceptually, this is the same device used in CapturableResource:
// https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L222-L225
// Implementation-wise, it is device used when invoking the
// create_resource function to produce the resource_handle
// associated with the object:
// https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L246-L247
// create_resource - Non owning pointer to the create_resource function
// associated with this object. Must be NON-NULL.
// initialize - Non owning pointer to the initialize function associated with
// this object. Must be NON-NULL.
// destroy_resource - Non owning pointer to the destroy_resource function
// associated with this object. Ideally this should be
// NON-NULL, but in order to support models saved prior to
// https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3
// we allow null here. This will, however, leak resources.
RestoredResource(const std::string& device,
TFConcreteFunction* create_resource,
TFConcreteFunction* initialize,
TFConcreteFunction* destroy_resource,
ImmediateTensorHandlePtr resource_handle);
Status Initialize() const;
// RestoredResource is movable, but not copyable.
RestoredResource(RestoredResource&& other) = default;
RestoredResource& operator=(RestoredResource&& other) = default;
~RestoredResource() override;
private:
std::string device_;
TFConcreteFunction* create_resource_;
TFConcreteFunction* initialize_;
TFConcreteFunction* destroy_resource_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_
#include <string>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
namespace tensorflow {
// All "Resources" should have these 3 saved functions:
// https://github.com/tensorflow/tensorflow/blob/86dc281333d7d277ddc1882f2bca4b17e7ec40e5/tensorflow/python/training/tracking/tracking.py#L277-L281
struct RestoredResourceRevivalState {
std::string device;
TFConcreteFunctionRevivalState* create_resource = nullptr;
TFConcreteFunctionRevivalState* initialize = nullptr;
TFConcreteFunctionRevivalState* destroy_resource = nullptr;
ImmediateTensorHandlePtr resource_handle = nullptr;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_
#include <memory>
#include <unordered_map>
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace tensorflow {
// RevivedObjects is mainly used as a container for all the "state" owned by
// SavedModel. It stores all non-"user object" nodes from a SavedModel
// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62)
// in a "fully constructed" state. It is effectively a strongly typed map, where
// each member is a map from the node id in the SavedObjectGraph's nodes
// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29)
// to the revived object of the corresponding type.
struct RevivedObjects {
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
gtl::FlatMap<int, std::unique_ptr<Constant>> constants;
gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>> concrete_functions;
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
signature_def_functions;
gtl::FlatMap<int, RestoredResource> restored_resources;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_

View File

@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
// TFConcreteFunctionRevivalState wraps the state needed for building a
// TF_ConcreteFunction. This is mainly used in PartiallyRevivedObjects, which
// wraps partially constructed Function and Resource objects.
struct TFConcreteFunctionRevivalState {
// Index of the node in the SavedObjectGraph it was loaded from.
int node_id;
// Pointer to the original functiondef. fdef_ is guaranteed to be
// non-null.
const FunctionDef* fdef;
// TensorHandle captures for this funtion
std::vector<ImmediateExecutionTensorHandle*> captures;
// SavedConcreteFunction contains much of the metadata of the expected "types"
// of the inputs and outputs of a function.
// Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null.
const SavedConcreteFunction* saved_concrete_func;
// This field is only present on TF2 ConcreteFunctions, and is useful for
// determining the original argument *names* of the function, (since the
// "canonicalized_input_signature" may append extra uniquifying integers).
// However, SavedBareConcreteFunctions do not have a FunctionSpec.
// Note(bmzhao): if function_spec_.has_value(), *function_spec_ is guaranteed
// to be non-null.
absl::optional<const FunctionSpec*> function_spec;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
// FunctionBuilder wraps the state needed for building a SignatureDefFunction.
// This is mainly used in PartiallyRevivedObjects, which wraps partially
// constructed Function and Resource objects.
struct TFSignatureDefFunctionRevivalState {
// Index of the node in the SavedObjectGraph it was loaded from.
int node_id = 0;
// Pointer to the original functiondef. fdef_ is guaranteed to be
// non-null.
const FunctionDef* fdef = nullptr;
// SavedConcreteFunction contains much of the metadata of the expected "types"
// of the inputs and outputs of a function.
// Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null.
const SavedConcreteFunction* saved_concrete_func = nullptr;
// The name of the SignatureDef key.
std::string signature_key;
// TensorHandle captures for this funtion
std::vector<ImmediateExecutionTensorHandle*> captures;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_

View File

@ -17,15 +17,22 @@ limitations under the License.
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
@ -41,6 +48,83 @@ namespace {
using StructuredValueDictEntry =
protobuf::MapPair<std::string, StructuredValue>;
// Maps from a Nodedef's name to its corresponding AttrValues, for a given
// Graphdef
using NodeAttrMap =
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher>;
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
StringPieceHasher>;
// Looks up a SavedConstant's associated tensorproto from the NodeAttrMap and
// returns a tensorflow::Constant.
Status ConstantFromSavedConstant(
ImmediateExecutionContext* ctx,
const tensorflow::SavedConstant& saved_constant,
const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
const std::string& const_op_name = saved_constant.operation();
const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
if (node_name_and_attrs == node_attr_map.end()) {
return errors::FailedPrecondition(
"Unable to find Const operation with name'", const_op_name,
"' in SavedModel graphdef");
}
const AttrValueMap* attrs = node_name_and_attrs->second;
const auto& attr_name_and_value = attrs->find("value");
if (attr_name_and_value == attrs->end()) {
return errors::FailedPrecondition("Unable to find Const operation '",
const_op_name, "'s value attribute");
}
const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
}
// Finds the "signatures" object in the object graph, and fills a mapping of
// each signature's name to the corresponding function's node in the object
// graph.
Status GetSignaturesMap(const SavedObjectGraph& saved_objects,
gtl::FlatMap<std::string, int>* signatures_map) {
if (saved_objects.nodes().empty()) {
return errors::FailedPrecondition("Saved Object Graph was empty.");
}
const SavedObject& root = saved_objects.nodes(0);
const SavedObject* signatures = nullptr;
for (const auto& child : root.children()) {
if (child.local_name() == "signatures") {
if (child.node_id() >= saved_objects.nodes().size()) {
return errors::FailedPrecondition(
"Signature object had child node id ", child.node_id(),
" which exceeds the size of the set of nodes");
}
signatures = &saved_objects.nodes(child.node_id());
}
}
// Some basic sanity checks that this object is actually our "signatures" map
if (signatures == nullptr) {
// This is where the "signatures" attribute is always set:
// https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109
return errors::FailedPrecondition(
"SavedObjectGraph's root object must have a child 'signatures' object");
}
if (signatures->kind_case() != SavedObject::kUserObject) {
return errors::FailedPrecondition(
"Signatures must be a SavedObject of type UserObject.");
}
if (signatures->user_object().identifier() != "signature_map") {
// This is where the string comes from:
// https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220
return errors::FailedPrecondition(
"Signatures SavedObject must have identifier 'signature_map'.");
}
for (const auto& child : signatures->children()) {
(*signatures_map)[child.local_name()] = child.node_id();
}
return Status();
}
// Perform some basic sanity checks on SavedConcreteFunction's input and
// output signatures with respect to the corresponding FunctionDef's input
// and output args.
@ -99,6 +183,21 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
return Status();
}
Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) {
// We only allow loading functions that have an annotated input signature,
// which means there is 1:1 correspondence between tf.function
// <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
// the same restriction that MLIR has:
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
if (saved_function.concrete_functions_size() != 1) {
return errors::FailedPrecondition(
"Only tf.functions annotated with an input signature are supported "
"by SavedModelAPI. This means that there should only be a single "
"ConcreteFunction per tf.function");
}
return Status();
}
} // namespace
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
@ -255,21 +354,18 @@ absl::optional<int> FindNodeAtPath(StringPiece path,
return node_id;
}
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>
NodeToAttrMap(const tensorflow::GraphDef& graphdef) {
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>
result;
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap(
const tensorflow::GraphDef& graphdef) {
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> result;
for (const tensorflow::NodeDef& node : graphdef.node()) {
result[node.name()] = &node.attr();
}
return result;
}
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
StringPieceHasher>
gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) {
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
StringPieceHasher>
gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
result;
for (const FunctionDef& function_def : library.function()) {
result[function_def.signature().name()] = &function_def;
@ -277,5 +373,154 @@ FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) {
return result;
}
Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
ImmediateExecutionContext* context,
const std::string& directory,
PartiallyRevivedObjects* objects) {
// This is needed to restore "Constant" nodes by looking up their
// "Value" attribute.
NodeAttrMap node_attr_map = NodeToAttrMap(metagraph.graph_def());
// These are needed for creating "Assets", by looking up their filenames.
std::vector<AssetFileDef> assets;
TF_RETURN_IF_ERROR(GetAssetFileDefs(metagraph, &assets));
// Signatures are needed for determining whether a function is a
// SignatureDefFunction or not.
gtl::FlatMap<std::string, int> signatures_map;
TF_RETURN_IF_ERROR(
GetSignaturesMap(metagraph.object_graph_def(), &signatures_map));
gtl::FlatMap<int, std::string> reversed_signatures_map;
reversed_signatures_map.reserve(signatures_map.size());
for (const auto& signature_key_and_node : signatures_map) {
reversed_signatures_map.emplace(signature_key_and_node.second,
signature_key_and_node.first);
}
// FunctionDefs are needed to help construct
// TFConcreteFunction/SignatureDefFunctions
const FunctionDefMap function_def_map =
internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
// Iterate through all the saved objects, restoring objects (if we can) as we
// go. For objects that dependencies on other objects (resources/functions),
// we partially initialize "builders" that correspond to their currently known
// state, and gradually fill them out in subsequent passes.
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
const SavedObject& node = metagraph.object_graph_def().nodes(i);
if (node.kind_case() == SavedObject::kVariable) {
std::unique_ptr<Variable> variable;
TF_RETURN_IF_ERROR(
LoadSavedVariable(context, node.variable(), &variable));
objects->variables[i] = std::move(variable);
} else if (node.kind_case() == SavedObject::kConstant) {
std::unique_ptr<Constant> constant;
TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
node_attr_map, &constant));
objects->constants[i] = std::move(constant);
} else if (node.kind_case() == SavedObject::kAsset) {
std::unique_ptr<Asset> asset;
TF_RETURN_IF_ERROR(
LoadSavedAsset(context, node.asset(), directory, assets, &asset));
objects->assets[i] = std::move(asset);
} else if (node.kind_case() == SavedObject::kResource) {
RestoredResourceRevivalState resource_revival_state;
// We'll set the resource's functions in a subsequent pass, once we get
// all functions in a partially revived state.
resource_revival_state.device = node.resource().device();
objects->restored_resources[i] = std::move(resource_revival_state);
} else if (node.kind_case() == SavedObject::kFunction) {
// Get the SavedFunction node and validate it has a single concrete func.
const SavedFunction& saved_function = node.function();
TF_RETURN_IF_ERROR(ValidateSingleConcreteFunction(saved_function));
// Retrieve related function information.
const std::string& function_name = saved_function.concrete_functions(0);
const FunctionDef* function_def = function_def_map.at(function_name);
const SavedConcreteFunction& saved_concrete_func =
metagraph.object_graph_def().concrete_functions().at(function_name);
const FunctionSpec& function_spec = saved_function.function_spec();
// Construct either a SignatureDefFunctionBuilder or a
// ConcreteFunctionBuilder, depending on whether this node was a child
// of the "signatures" attribute from root object.
auto reverse_signature_iter = reversed_signatures_map.find(i);
if (reverse_signature_iter != reversed_signatures_map.end()) {
TFSignatureDefFunctionRevivalState func_revival_state;
func_revival_state.node_id = i;
func_revival_state.fdef = function_def;
func_revival_state.saved_concrete_func = &saved_concrete_func;
func_revival_state.signature_key = reverse_signature_iter->second;
objects->signature_def_functions[i] = std::move(func_revival_state);
} else {
TFConcreteFunctionRevivalState func_revival_state;
func_revival_state.node_id = i;
func_revival_state.fdef = function_def;
func_revival_state.saved_concrete_func = &saved_concrete_func;
func_revival_state.function_spec = &function_spec;
objects->concrete_functions[i] = std::move(func_revival_state);
}
} else if (node.kind_case() == SavedObject::kBareConcreteFunction) {
const SavedBareConcreteFunction& bare_cf = node.bare_concrete_function();
// Retrieve related function information.
const std::string& function_name = bare_cf.concrete_function_name();
const FunctionDef* function_def = function_def_map.at(function_name);
const SavedConcreteFunction& saved_concrete_func =
metagraph.object_graph_def().concrete_functions().at(function_name);
// Check whether this is a SignatureDefFunction, or not.
auto reverse_signature_iter = reversed_signatures_map.find(i);
if (reverse_signature_iter != reversed_signatures_map.end()) {
TFSignatureDefFunctionRevivalState func_revival_state;
func_revival_state.node_id = i;
func_revival_state.fdef = function_def;
func_revival_state.saved_concrete_func = &saved_concrete_func;
func_revival_state.signature_key = reverse_signature_iter->second;
objects->signature_def_functions[i] = std::move(func_revival_state);
} else {
TFConcreteFunctionRevivalState func_revival_state;
func_revival_state.node_id = i;
func_revival_state.fdef = function_def;
func_revival_state.saved_concrete_func = &saved_concrete_func;
objects->concrete_functions[i] = std::move(func_revival_state);
}
}
}
// Now that we've partially restored all functions, we can have resources
// point to them
for (auto& node_and_resource_revival_state : objects->restored_resources) {
int node_id = node_and_resource_revival_state.first;
const SavedObjectGraph& obj_graph = metagraph.object_graph_def();
const SavedObject& node = obj_graph.nodes(node_id);
RestoredResourceRevivalState& resource =
node_and_resource_revival_state.second;
for (const TrackableObjectGraph::TrackableObject::ObjectReference& child :
node.children()) {
int child_node_id = child.node_id();
// Note(bmzhao): The expected functions saved by a resource object are:
// "_create_resource", "_initialize", and "_destroy_resource".
// https://github.com/tensorflow/tensorflow/blob/ad66f588c1666ade8051feb42811fa27b285271c/tensorflow/python/training/tracking/tracking.py#L277-L281
if (child.local_name() == "_create_resource" &&
obj_graph.nodes(child.node_id()).kind_case() ==
SavedObject::kFunction) {
resource.create_resource = &objects->concrete_functions[child_node_id];
} else if (child.local_name() == "_initialize" &&
obj_graph.nodes(child.node_id()).kind_case() ==
SavedObject::kFunction) {
resource.initialize = &objects->concrete_functions[child_node_id];
} else if (child.local_name() == "_destroy_resource" &&
obj_graph.nodes(child.node_id()).kind_case() ==
SavedObject::kFunction) {
resource.destroy_resource = &objects->concrete_functions[child_node_id];
}
}
}
return Status();
}
} // namespace internal
} // namespace tensorflow

View File

@ -27,10 +27,12 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
@ -84,15 +86,22 @@ absl::optional<int> FindNodeAtPath(StringPiece path,
// Maps each node in `graphdef` to its corresponding Attribute Map.
// Callers must ensure that `graphdef` outlives the returned map.
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>
NodeToAttrMap(const tensorflow::GraphDef& graphdef);
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap(
const tensorflow::GraphDef& graphdef);
// Maps the name of each FunctionDef in `library` to its corresponding
// FunctionDef. Callers must ensure `library` outlives the returned map.
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
StringPieceHasher>
gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
FunctionNameToFunctionDefMap(const FunctionDefLibrary& library);
// Walks through the SavedObjectGraph in metagraph, and restores all nodes
// (except "UserDefinedObjects") with their corresponding type in
// "PartiallyRevivedObjects".
Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
ImmediateExecutionContext* context,
const std::string& directory,
PartiallyRevivedObjects* objects);
} // namespace internal
} // namespace tensorflow

View File

@ -17,7 +17,6 @@ limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@ -30,6 +29,9 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
@ -37,7 +39,6 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -46,6 +47,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
@ -62,142 +64,15 @@ limitations under the License.
namespace tensorflow {
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
using FunctionDefMap =
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
StringPieceHasher>;
// Maps from a Nodedef's name to its corresponding AttrValues, for a given
// Graphdef
using NodeAttrMap =
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>;
// Maps from Node ID to an "Revived Object" implementing
// "TensorHandleConvertible"
using RevivedObjectMap =
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>;
using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
StringPieceHasher>;
// Maps from a functiondef's name to the corresponding "TFConcreteFunction"
using ConcreteFunctionMap =
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>;
using FlatTensorFunctionMap =
gtl::FlatMap<std::string, std::unique_ptr<FlatTensorFunction>>;
namespace {
Status ConstantFromSavedConstant(
ImmediateExecutionContext* ctx,
const tensorflow::SavedConstant& saved_constant,
const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
const std::string& const_op_name = saved_constant.operation();
const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
if (node_name_and_attrs == node_attr_map.end()) {
return errors::FailedPrecondition(
"Unable to find Const operation with name'", const_op_name,
"' in SavedModel graphdef");
}
const AttrValueMap* attrs = node_name_and_attrs->second;
const auto& attr_name_and_value = attrs->find("value");
if (attr_name_and_value == attrs->end()) {
return errors::FailedPrecondition("Unable to find Const operation '",
const_op_name, "'s value attribute");
}
const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
}
// Restores all non-function objects in the SavedModel's object graph.
// This function walks through the metagraph's saved object graph, and
// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and
// SavedResources. These are returned via the `out` parameter.
Status ReviveObjects(
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
const std::string& directory,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
revived_objects) {
// This is needed to restore "Constant" nodes by looking up their
// "Value" attribute.
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
// These are needed for creating "Assets", by looking up their filenames.
std::vector<AssetFileDef> assets;
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(metagraph, &assets));
// Iterate through all the saved objects, restoring objects as we go.
// We don't recreate functions until all other objects have been created.
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
const SavedObject& node = metagraph.object_graph_def().nodes(i);
if (node.kind_case() == SavedObject::kVariable) {
std::unique_ptr<Variable> variable;
TF_RETURN_IF_ERROR(
internal::LoadSavedVariable(context, node.variable(), &variable));
(*revived_objects)[i] = std::move(variable);
} else if (node.kind_case() == SavedObject::kConstant) {
std::unique_ptr<Constant> constant;
TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
node_attr_map, &constant));
(*revived_objects)[i] = std::move(constant);
} else if (node.kind_case() == SavedObject::kAsset) {
std::unique_ptr<Asset> asset;
TF_RETURN_IF_ERROR(internal::LoadSavedAsset(context, node.asset(),
directory, assets, &asset));
(*revived_objects)[i] = std::move(asset);
} else if (node.kind_case() == SavedObject::kResource) {
// TODO(bmzhao): Figure out how resource loading works and implement it
return errors::Unimplemented(
"SavedResource loading is not implemented yet");
}
}
return Status();
}
Status ReviveFunctions(const MetaGraphDef& metagraph,
const RevivedObjectMap& revived_objects,
ImmediateExecutionContext* context,
ConcreteFunctionMap* restored_functions) {
const FunctionDefMap function_def_map =
internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
// Iterate through all objects, only examining functions.
for (const SavedObject& node : metagraph.object_graph_def().nodes()) {
if (node.kind_case() == SavedObject::kBareConcreteFunction) {
const std::string& function_name =
node.bare_concrete_function().concrete_function_name();
const SavedConcreteFunction& saved_concrete_function =
metagraph.object_graph_def().concrete_functions().at(function_name);
const FunctionDef* function_def = function_def_map.at(function_name);
std::unique_ptr<TFConcreteFunction> concrete_function;
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
saved_concrete_function, function_def, revived_objects, context,
&concrete_function));
(*restored_functions)[function_name] = std::move(concrete_function);
} else if (node.kind_case() == SavedObject::kFunction) {
// We only allow loading functions that have an annotated input signature,
// which means there is 1:1 correspondence between tf.function
// <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
// the same restriction that MLIR has:
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
const SavedFunction& saved_function = node.function();
if (saved_function.concrete_functions_size() != 1) {
return errors::FailedPrecondition(
"Only tf.functions annotated with an input signature are supported "
"by SavedModelAPI. This means that there should only be a single "
"ConcreteFunction per tf.function");
}
const std::string& function_name = saved_function.concrete_functions(0);
const SavedConcreteFunction& saved_concrete_function =
metagraph.object_graph_def().concrete_functions().at(function_name);
const FunctionDef* function_def = function_def_map.at(function_name);
std::unique_ptr<TFConcreteFunction> concrete_function;
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
saved_concrete_function, function_def, revived_objects, context,
&concrete_function));
(*restored_functions)[function_name] = std::move(concrete_function);
}
}
return Status();
}
const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(
@ -234,7 +109,7 @@ FindSerializedTensorInTrackable(
// overridden "restore" method:
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
const RevivedObjectMap& revived_objects,
const RevivedObjects& revived_objects,
const std::string& directory,
ImmediateExecutionContext* context) {
// TODO(bmzhao): Batch up all the restores into a single restore op per
@ -254,8 +129,7 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
return Status::OK();
}
Variable* variable =
down_cast<Variable*>(revived_objects.at(node).get());
Variable* variable = revived_objects.variables.at(node).get();
// Restore the tensor's value from the checkpoint
const TrackableObjectGraph::TrackableObject::SerializedTensor*
@ -289,6 +163,14 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
return Status();
}
Status InitializeAllResources(const RevivedObjects& revived) {
for (const auto& node_and_resource : revived.restored_resources) {
const RestoredResource& resource = node_and_resource.second;
TF_RETURN_IF_ERROR(resource.Initialize());
}
return Status();
}
} // namespace
Status TFSavedModelAPI::GetFunction(const std::string& function_path,
@ -299,20 +181,12 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path,
return errors::NotFound("No saved object found at path ", function_path);
}
const SavedObject& object = bundle_.saved_object_graph().nodes(*node);
if (object.kind_case() == SavedObject::kBareConcreteFunction) {
*function =
concrete_functions_
.at(object.bare_concrete_function().concrete_function_name())
.get();
} else if (object.kind_case() == SavedObject::kFunction) {
*function =
concrete_functions_.at(object.function().concrete_functions(0)).get();
} else {
return errors::InvalidArgument(function_path,
" is not a path to a Function.");
auto function_iter = revived_objects_.concrete_functions.find(*node);
if (function_iter == revived_objects_.concrete_functions.end()) {
return errors::NotFound("No function found at path ", function_path);
}
*function = function_iter->second.get();
return Status();
}
@ -325,8 +199,8 @@ Status TFSavedModelAPI::GetSignatureDefFunction(
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
std::vector<ConcreteFunction*> result;
result.reserve(concrete_functions_.size());
for (auto& index_and_function : concrete_functions_) {
result.reserve(revived_objects_.concrete_functions.size());
for (auto& index_and_function : revived_objects_.concrete_functions) {
result.push_back(index_and_function.second.get());
}
return result;
@ -340,34 +214,21 @@ Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
return errors::NotFound("No saved object found at path ", variable_path);
}
const SavedObject& object = bundle_.saved_object_graph().nodes(*node);
if (object.kind_case() == SavedObject::kVariable) {
auto iter = revived_objects_.find(*node);
if (iter == revived_objects_.end()) {
return errors::Internal("Variable ", variable_path,
" was not properly revived.");
}
*variable = static_cast<Variable*>(iter->second.get());
return Status();
auto variables_iter = revived_objects_.variables.find(*node);
if (variables_iter == revived_objects_.variables.end()) {
return errors::NotFound("No variable found at path ", variable_path);
}
*variable = nullptr;
return errors::InvalidArgument(
variable_path, " is not a path to a Variable (kind=", object.kind_case(),
")");
*variable = variables_iter->second.get();
return Status();
}
TFSavedModelAPI::TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects,
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions)
TFSavedModelAPI::TFSavedModelAPI(const std::string& directory,
SavedModelV2Bundle bundle,
RevivedObjects revived_objects)
: directory_(directory),
bundle_(std::move(bundle)),
revived_objects_(std::move(revived_objects)),
concrete_functions_(std::move(concrete_functions)) {}
revived_objects_(std::move(revived_objects)) {}
Status TFSavedModelAPI::Load(
const std::string& directory,
@ -388,28 +249,25 @@ Status TFSavedModelAPI::Load(
// This occurs in python here:
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
RevivedObjectMap revived_objects;
TF_RETURN_IF_ERROR(ReviveObjects(bundle.meta_graph_def(), context, directory,
&revived_objects));
// Step 1: For each node in the graph, we should initialize an object of the
// corresponding type. For objects that depend on the initialization of other
// objects (like functions which capture resources), we will initialize them
// in step 2.
PartiallyRevivedObjects partially_revived_objects;
TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
bundle.meta_graph_def(), context, directory, &partially_revived_objects));
// TODO(bmzhao): When we later add support for loading resources, we need to
// handle the case where materializing a function's captures requires invoking
// other functions. This occurs when retrieving the resource handle for a
// TrackableResource:
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
// This requires restoring functions in a topological sort order by capture
// dependencies.
ConcreteFunctionMap function_map;
TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects,
context, &function_map));
RevivedObjects revived_objects;
TF_RETURN_IF_ERROR(partially_revived_objects.Build(
context, bundle.saved_object_graph(), &revived_objects));
TF_RETURN_IF_ERROR(
RestoreCheckpoint(&bundle, revived_objects, directory, context));
TF_RETURN_IF_ERROR(InitializeAllResources(revived_objects));
out->reset(new TFSavedModelAPI(directory, std::move(bundle),
std::move(revived_objects),
std::move(function_map)));
std::move(revived_objects)));
return Status();
}

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
@ -72,19 +73,12 @@ class TFSavedModelAPI : public SavedModelAPI {
Status GetVariable(const std::string& variable_path, Variable** variable);
private:
TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects,
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions);
TFSavedModelAPI(const std::string& directory, SavedModelV2Bundle bundle,
RevivedObjects revived_objects);
std::string directory_;
SavedModelV2Bundle bundle_;
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects_;
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions_;
RevivedObjects revived_objects_;
};
} // namespace tensorflow

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
@ -196,6 +197,142 @@ TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
TFE_DeleteContext(ctx);
}
TEST_P(CSavedModelAPITest, LoadsStaticHashtableSavedModel) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
bool use_tfrt = GetParam();
if (use_tfrt) {
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::string model_dir = SavedModelPath("StaticHashTableModule");
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_ConcreteFunction* lookup_fn =
TF_GetSavedModelConcreteFunction(saved_model, "lookup", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// Note(bmzhao): Based on static_hashtable_asset.txt, we expect the following
// mapping:
// "foo" -> 0
// "bar" -> 1
// "baz" -> 2
// "wombat" -> 3
// all other strings -> -1
// Call lookup function with input "foo", expecting an output of 0
{
std::vector<TFE_TensorHandle*> lookup_fn_inputs;
TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("foo"));
lookup_fn_inputs.push_back(input_foo);
TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
// inputs + outputs a function has.
TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_NumDims(result), 0);
tensorflow::int64* output_value =
static_cast<tensorflow::int64*>(TF_TensorData(result));
EXPECT_EQ(*output_value, 0);
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(input_foo);
TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
TFE_DeleteOp(lookup_op);
}
// Call lookup function with input "baz", expecting an output of 2
{
std::vector<TFE_TensorHandle*> lookup_fn_inputs;
TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("baz"));
lookup_fn_inputs.push_back(input_foo);
TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
// inputs + outputs a function has.
TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_NumDims(result), 0);
tensorflow::int64* output_value =
static_cast<tensorflow::int64*>(TF_TensorData(result));
EXPECT_EQ(*output_value, 2);
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(input_foo);
TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
TFE_DeleteOp(lookup_op);
}
// Call lookup function w/input "NON-EXISTENT-KEY", expecting an output of -1
{
std::vector<TFE_TensorHandle*> lookup_fn_inputs;
TFE_TensorHandle* input_foo =
TestScalarTensorHandle(ctx, tstring("NON-EXISTENT-KEY"));
lookup_fn_inputs.push_back(input_foo);
TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
// inputs + outputs a function has.
TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_NumDims(result), 0);
tensorflow::int64* output_value =
static_cast<tensorflow::int64*>(TF_TensorData(result));
EXPECT_EQ(*output_value, -1);
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(input_foo);
TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
TFE_DeleteOp(lookup_op);
}
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();

View File

@ -213,6 +213,7 @@ py_binary(
srcs = ["testdata/generate_saved_models.py"],
data = [
":saved_model_asset_data",
":saved_model_static_hashtable_asset_data",
],
python_version = "PY3",
srcs_version = "PY3",
@ -220,6 +221,7 @@ py_binary(
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:variables",
"//tensorflow/python/compat:v2_compat",
@ -243,6 +245,7 @@ filegroup(
"testdata/half_plus_two_v2/**",
"testdata/x_plus_y_v2_debuginfo/**",
"testdata/CyclicModule/**",
"testdata/StaticHashTableModule/**",
"testdata/VarsAndArithmeticObjectGraph/**",
"testdata/fuzz_generated/**",
]),
@ -260,6 +263,13 @@ filegroup(
],
)
filegroup(
name = "saved_model_static_hashtable_asset_data",
srcs = [
"testdata/static_hashtable_asset.txt",
],
)
exports_files(
glob([
"testdata/half_plus_two_pbtxt/**",

View File

@ -0,0 +1,4 @@
foo
bar
baz
wombat

View File

@ -30,6 +30,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_options
@ -82,10 +83,31 @@ class AssetModule(module.Module):
return io_ops.read_file(self.asset)
class StaticHashTableModule(module.Module):
"""A module with an Asset, StaticHashTable, and a lookup function."""
def __init__(self):
self.asset = tracking.Asset(
test.test_src_dir_path(
"cc/saved_model/testdata/static_hashtable_asset.txt"))
self.table = lookup_ops.StaticHashTable(
lookup_ops.TextFileInitializer(self.asset, dtypes.string,
lookup_ops.TextFileIndex.WHOLE_LINE,
dtypes.int64,
lookup_ops.TextFileIndex.LINE_NUMBER),
-1)
@def_function.function(
input_signature=[tensor_spec.TensorSpec(shape=None, dtype=dtypes.string)])
def lookup(self, word):
return self.table.lookup(word)
MODULE_CTORS = {
"VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph,
"CyclicModule": CyclicModule,
"AssetModule": AssetModule,
"StaticHashTableModule": StaticHashTableModule,
}

View File

@ -0,0 +1,4 @@
foo
bar
baz
wombat