Refactoring the guts of SavedModel loading to experimentally support loading resources that do not override _gather_saveables_for_checkpoint. This allows us to support simple resources like StaticHashTable. Note that additional logic related to checkpoints will need to be added to support df6b21c13c
.
PiperOrigin-RevId: 333221978 Change-Id: Ib724b6cee9a57bf1b3ee98d2a8eaf9f394a8d64d
This commit is contained in:
parent
bb24a4b88f
commit
e143ada4a6
tensorflow
c/experimental/saved_model
core
BUILD
revived_types
BUILDpartially_revived_objects.ccpartially_revived_objects.hrestored_resource.ccrestored_resource.hrestored_resource_revival_state.hrevived_objects.htf_concrete_function_revival_state.htf_signature_def_function_revival_state.h
saved_model_utils.ccsaved_model_utils.htf_saved_model_api.cctf_saved_model_api.hinternal
cc/saved_model
@ -66,8 +66,13 @@ cc_library(
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:asset",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:restored_resource_revival_state",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function_revival_state",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_signature_def_function_revival_state",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||
"//tensorflow/cc/saved_model:loader_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -147,12 +152,14 @@ cc_library(
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:flat_tensor_function",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:partially_revived_objects",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:revived_objects",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||
"//tensorflow/cc/saved_model:bundle_v2",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/cc/saved_model:loader_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -69,6 +69,84 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "partially_revived_objects",
|
||||
srcs = [
|
||||
"partially_revived_objects.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"partially_revived_objects.h",
|
||||
],
|
||||
deps = [
|
||||
":asset",
|
||||
":constant",
|
||||
":restored_resource",
|
||||
":restored_resource_revival_state",
|
||||
":revived_objects",
|
||||
":tf_concrete_function",
|
||||
":tf_concrete_function_revival_state",
|
||||
":tf_signature_def_function",
|
||||
":tf_signature_def_function_revival_state",
|
||||
":variable",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "restored_resource",
|
||||
srcs = [
|
||||
"restored_resource.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"restored_resource.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_convertible",
|
||||
":tf_concrete_function",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "restored_resource_revival_state",
|
||||
hdrs = [
|
||||
"restored_resource_revival_state.h",
|
||||
],
|
||||
deps = [
|
||||
":tf_concrete_function_revival_state",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "revived_objects",
|
||||
hdrs = [
|
||||
"revived_objects.h",
|
||||
],
|
||||
deps = [
|
||||
":asset",
|
||||
":constant",
|
||||
":restored_resource",
|
||||
":tf_concrete_function",
|
||||
":tf_signature_def_function",
|
||||
":variable",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable",
|
||||
srcs = [
|
||||
@ -123,6 +201,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_concrete_function_revival_state",
|
||||
hdrs = [
|
||||
"tf_concrete_function_revival_state.h",
|
||||
],
|
||||
deps = [
|
||||
":tf_concrete_function",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_signature_def_function",
|
||||
srcs = [
|
||||
@ -145,3 +238,17 @@ cc_library(
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_signature_def_function_revival_state",
|
||||
hdrs = [
|
||||
"tf_signature_def_function_revival_state.h",
|
||||
],
|
||||
deps = [
|
||||
":tf_signature_def_function",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,388 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
|
||||
const PartiallyRevivedObjects& objects) {
|
||||
for (const auto& id_and_resource : objects.restored_resources) {
|
||||
int node_id = id_and_resource.first;
|
||||
const RestoredResourceRevivalState& resource = id_and_resource.second;
|
||||
const TFConcreteFunctionRevivalState* create_resource_fn =
|
||||
resource.create_resource;
|
||||
if (create_resource_fn == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"Resource at node ", node_id,
|
||||
" did not have a create_resource() function");
|
||||
}
|
||||
const SavedConcreteFunction* saved_create_resource_fn =
|
||||
create_resource_fn->saved_concrete_func;
|
||||
if (!saved_create_resource_fn->bound_inputs().empty()) {
|
||||
// TODO(b/124045874): Support loading resource functions via a top sort
|
||||
return errors::Unimplemented(
|
||||
"Create Resource functions with captures are currently unsupported.");
|
||||
}
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
// Retrieves the TensorHandle associated with `node_id` from `obj_graph`, and
|
||||
// set `*handle` to point to it.
|
||||
Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
|
||||
const PartiallyRevivedObjects& objects,
|
||||
ImmediateExecutionTensorHandle** handle) {
|
||||
const SavedObject& node = obj_graph.nodes(node_id);
|
||||
SavedObject::KindCase kind = node.kind_case();
|
||||
switch (kind) {
|
||||
case SavedObject::kVariable: {
|
||||
const auto& variables_iter = objects.variables.find(node_id);
|
||||
if (variables_iter == objects.variables.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Tried to convert node id ", node_id,
|
||||
" of type variable to tensor but the variable wasn't initialized");
|
||||
}
|
||||
*handle = variables_iter->second->handle();
|
||||
return Status();
|
||||
}
|
||||
case SavedObject::kConstant: {
|
||||
const auto& constants_iter = objects.constants.find(node_id);
|
||||
if (constants_iter == objects.constants.end()) {
|
||||
return errors::FailedPrecondition("Tried to convert node id ", node_id,
|
||||
" of type constant to tensor but the "
|
||||
"constant wasn't initialized");
|
||||
}
|
||||
*handle = constants_iter->second->handle();
|
||||
return Status();
|
||||
}
|
||||
case SavedObject::kAsset: {
|
||||
const auto& assets_iter = objects.assets.find(node_id);
|
||||
if (assets_iter == objects.assets.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Tried to convert node id ", node_id,
|
||||
" of type asset to tensor but the asset wasn't initialized");
|
||||
}
|
||||
*handle = assets_iter->second->handle();
|
||||
return Status();
|
||||
}
|
||||
case SavedObject::kResource: {
|
||||
const auto& resource_iter = objects.restored_resources.find(node_id);
|
||||
if (resource_iter == objects.restored_resources.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Tried to convert node id ", node_id,
|
||||
" of type Resource to tensor but the Resource wasn't initialized");
|
||||
}
|
||||
const RestoredResourceRevivalState& resource = resource_iter->second;
|
||||
if (resource.resource_handle == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"Resource with node id ", node_id,
|
||||
" should have its resource_handle created, but was nullptr.");
|
||||
}
|
||||
*handle = resource.resource_handle.get();
|
||||
return Status();
|
||||
}
|
||||
default: {
|
||||
return errors::FailedPrecondition(
|
||||
"Only objects of type variable, constant, asset, and resources have "
|
||||
"capturable tensorhandles. Encountered object of kind ",
|
||||
node.kind_case(), " at node id: ", node_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This function finds the necessary captures, then forwards to the builder
|
||||
// method
|
||||
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
|
||||
const TFConcreteFunctionRevivalState& builder,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
const PartiallyRevivedObjects& objects,
|
||||
std::unique_ptr<TFConcreteFunction>* out) {
|
||||
const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs();
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures;
|
||||
captures.reserve(capture_node_ids.size());
|
||||
for (int capture_node_id : capture_node_ids) {
|
||||
ImmediateExecutionTensorHandle* capture_handle;
|
||||
TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects,
|
||||
&capture_handle));
|
||||
captures.push_back(capture_handle);
|
||||
}
|
||||
// TODO(bmzhao): Create Metadata here
|
||||
return TFConcreteFunction::Create(/*function_def=*/builder.fdef,
|
||||
/*captures=*/std::move(captures),
|
||||
/*metadata=*/{},
|
||||
/*ctx=*/ctx,
|
||||
/*out=*/out);
|
||||
}
|
||||
|
||||
Status CreateSignatureDefFunction(
|
||||
ImmediateExecutionContext* ctx,
|
||||
const TFSignatureDefFunctionRevivalState& builder,
|
||||
const SavedObjectGraph& obj_graph, const PartiallyRevivedObjects& objects,
|
||||
std::unique_ptr<TFSignatureDefFunction>* out) {
|
||||
const auto& capture_node_ids = builder.saved_concrete_func->bound_inputs();
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures;
|
||||
captures.reserve(capture_node_ids.size());
|
||||
for (int capture_node_id : capture_node_ids) {
|
||||
ImmediateExecutionTensorHandle* capture_handle;
|
||||
TF_RETURN_IF_ERROR(TensorHandleFromNode(capture_node_id, obj_graph, objects,
|
||||
&capture_handle));
|
||||
captures.push_back(capture_handle);
|
||||
}
|
||||
// TODO(bmzhao): Create Metadata here
|
||||
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
|
||||
/*captures=*/std::move(captures),
|
||||
/*metadata=*/{},
|
||||
/*ctx=*/ctx,
|
||||
/*out=*/out);
|
||||
}
|
||||
|
||||
Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
const PartiallyRevivedObjects& objects,
|
||||
RevivedObjects* revived) {
|
||||
for (const auto& id_and_resource : objects.restored_resources) {
|
||||
const RestoredResourceRevivalState& resource = id_and_resource.second;
|
||||
const TFConcreteFunctionRevivalState* create_resource_fn =
|
||||
resource.create_resource;
|
||||
|
||||
const SavedConcreteFunction* saved_create_resource_fn =
|
||||
create_resource_fn->saved_concrete_func;
|
||||
if (!saved_create_resource_fn->bound_inputs().empty()) {
|
||||
// TODO(b/124045874): Load resource functions via a topological sort
|
||||
return errors::Unimplemented(
|
||||
"Create Resource functions with captures are currently unsupported.");
|
||||
}
|
||||
std::unique_ptr<TFConcreteFunction> out;
|
||||
TF_RETURN_IF_ERROR(CreateConcreteFunction(ctx, *create_resource_fn,
|
||||
obj_graph, objects, &out));
|
||||
revived->concrete_functions[create_resource_fn->node_id] = std::move(out);
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status InitializeAllFunctions(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
const PartiallyRevivedObjects& objects,
|
||||
RevivedObjects* revived) {
|
||||
gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>* destination_func_map =
|
||||
&revived->concrete_functions;
|
||||
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>*
|
||||
destination_sig_map = &revived->signature_def_functions;
|
||||
|
||||
for (const auto& id_and_func : objects.concrete_functions) {
|
||||
int node_id = id_and_func.first;
|
||||
const TFConcreteFunctionRevivalState& func = id_and_func.second;
|
||||
|
||||
if (destination_func_map->find(node_id) != destination_func_map->end()) {
|
||||
// The function has already been initialized in the destination_map,
|
||||
// so we can skip this node. This can occur because we initialize
|
||||
// CreateResource functions before calling this function.
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unique_ptr<TFConcreteFunction> out;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateConcreteFunction(ctx, func, obj_graph, objects, &out));
|
||||
(*destination_func_map)[node_id] = std::move(out);
|
||||
}
|
||||
|
||||
for (const auto& id_and_func : objects.signature_def_functions) {
|
||||
int node_id = id_and_func.first;
|
||||
const TFSignatureDefFunctionRevivalState& func = id_and_func.second;
|
||||
|
||||
if (destination_sig_map->find(node_id) != destination_sig_map->end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unique_ptr<TFSignatureDefFunction> out;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateSignatureDefFunction(ctx, func, obj_graph, objects, &out));
|
||||
(*destination_sig_map)[node_id] = std::move(out);
|
||||
}
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
PartiallyRevivedObjects* objects,
|
||||
RevivedObjects* revived) {
|
||||
for (auto& id_and_resource : objects->restored_resources) {
|
||||
RestoredResourceRevivalState& resource = id_and_resource.second;
|
||||
int create_resource_fn_node = resource.create_resource->node_id;
|
||||
const gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>>&
|
||||
revived_functions = revived->concrete_functions;
|
||||
|
||||
const auto& revived_functions_iter =
|
||||
revived_functions.find(create_resource_fn_node);
|
||||
if (revived_functions_iter == revived_functions.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"ConcreteFunction at node ", create_resource_fn_node,
|
||||
" should have been initialized prior to being called.");
|
||||
}
|
||||
const TFConcreteFunction& create_resource_fn =
|
||||
*revived_functions_iter->second;
|
||||
ImmediateOpPtr function_op;
|
||||
TF_RETURN_IF_ERROR(create_resource_fn.MakeCallOp({}, &function_op));
|
||||
TF_RETURN_IF_ERROR(function_op->SetDeviceName(resource.device.c_str()));
|
||||
|
||||
AbstractTensorHandle* resource_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(function_op->Execute(
|
||||
absl::MakeSpan(&resource_handle, num_retvals), &num_retvals));
|
||||
AbstractTensorHandlePtr owned_resource_handle(resource_handle);
|
||||
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
|
||||
owned_resource_handle.get())) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
ImmediateTensorHandlePtr result(
|
||||
reinterpret_cast<ImmediateExecutionTensorHandle*>(
|
||||
owned_resource_handle.release()));
|
||||
resource.resource_handle = std::move(result);
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
// Finds a ConcreteFunction with node id `node` in `objects`, and sets *out to
|
||||
// point to it. If node doesn't exist in `objects`, out is untouched, and an
|
||||
// error status is returned.
|
||||
Status FindConcreteFunction(int node, RevivedObjects* objects,
|
||||
TFConcreteFunction** out) {
|
||||
auto func_iter = objects->concrete_functions.find(node);
|
||||
if (func_iter == objects->concrete_functions.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Failed to find ConcreteFunction with node id ", node,
|
||||
" in revived objects");
|
||||
}
|
||||
*out = func_iter->second.get();
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status BuildResources(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
PartiallyRevivedObjects* objects,
|
||||
RevivedObjects* revived) {
|
||||
for (auto& id_and_resource : objects->restored_resources) {
|
||||
int node_id = id_and_resource.first;
|
||||
RestoredResourceRevivalState& resource_revival_state =
|
||||
id_and_resource.second;
|
||||
|
||||
TFConcreteFunction* create_resource = nullptr;
|
||||
|
||||
// Check all the functions associated with the resource have already been
|
||||
// initialized in `revived`
|
||||
if (resource_revival_state.create_resource != nullptr) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
FindConcreteFunction(resource_revival_state.create_resource->node_id,
|
||||
revived, &create_resource));
|
||||
}
|
||||
|
||||
TFConcreteFunction* initialize = nullptr;
|
||||
if (resource_revival_state.initialize != nullptr) {
|
||||
TF_RETURN_IF_ERROR(FindConcreteFunction(
|
||||
resource_revival_state.initialize->node_id, revived, &initialize));
|
||||
}
|
||||
|
||||
TFConcreteFunction* destroy_resource = nullptr;
|
||||
if (resource_revival_state.destroy_resource != nullptr) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
FindConcreteFunction(resource_revival_state.destroy_resource->node_id,
|
||||
revived, &destroy_resource));
|
||||
}
|
||||
|
||||
if (resource_revival_state.resource_handle == nullptr) {
|
||||
return errors::FailedPrecondition("Resource at node id ", node_id,
|
||||
" does not have a resource handle.");
|
||||
}
|
||||
|
||||
revived->restored_resources.emplace(
|
||||
node_id, RestoredResource(
|
||||
/*device=*/resource_revival_state.device,
|
||||
/*create_resource=*/create_resource,
|
||||
/*initialize=*/initialize,
|
||||
/*destroy_resource=*/destroy_resource,
|
||||
/*resource_handle=*/
|
||||
std::move(resource_revival_state.resource_handle)));
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph,
|
||||
RevivedObjects* revived) {
|
||||
// Step 1: We would like to initialize all functions; this requires setting up
|
||||
// their captured tensorhandles, which may come from variables, assets,
|
||||
// constants, or resources. The first three are trivial; However,
|
||||
// tensorhandles that correspond to resources must be created by invoking
|
||||
// their "create_resource" function.
|
||||
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
|
||||
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
|
||||
// For now, we assert that all create_resource functions must have no
|
||||
// captures. This aligns with the current behavior in python.
|
||||
// https://github.com/tensorflow/tensorflow/blob/50eac986bf7a0ad12594e080f083181f277e0b49/tensorflow/python/saved_model/load.py#L152-L155
|
||||
// TODO(bmzhao): We should do a topological sort instead.
|
||||
|
||||
// 1a. Make sure all CreateResource functions have no captures.
|
||||
TF_RETURN_IF_ERROR(AssertAllCreateResourceFunctionsHaveNoCaptures(*this));
|
||||
|
||||
// 1b. Initialize all CreateResource functions, storing them in `revived`
|
||||
TF_RETURN_IF_ERROR(
|
||||
InitializeCreateResourceFunctions(ctx, obj_graph, *this, revived));
|
||||
|
||||
// 1c. Invoke all "CreateResource" functions and store their ResourceHandles
|
||||
// https://github.com/tensorflow/tensorflow/blob/3b6b41b68a95dc70c26dc816b29d359bfb88c116/tensorflow/python/training/tracking/tracking.py#L241-L247
|
||||
// in *this->resources.
|
||||
// TODO(bmzhao): Maybe store them separately, not in *this?
|
||||
TF_RETURN_IF_ERROR(CreateAllResourceHandles(ctx, obj_graph, this, revived));
|
||||
|
||||
// 2. Initialize all the rest of the functions
|
||||
TF_RETURN_IF_ERROR(InitializeAllFunctions(ctx, obj_graph, *this, revived));
|
||||
|
||||
// 3a. Move over all non-function, non-resource objects
|
||||
revived->variables = std::move(variables);
|
||||
revived->assets = std::move(assets);
|
||||
revived->constants = std::move(constants);
|
||||
|
||||
// 3b. Move over resources.
|
||||
TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived));
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,54 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Container for objects during the revival step in SavedModel's loading.
|
||||
// Notably, resources and functions can be in a state where they reference
|
||||
// other resources/functions that have not been constructed yet. We collect
|
||||
// *all* objects in a partially valid state here, then properly initialize
|
||||
// resources and functions.
|
||||
struct PartiallyRevivedObjects {
|
||||
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
|
||||
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
|
||||
gtl::FlatMap<int, std::unique_ptr<Constant>> constants;
|
||||
gtl::FlatMap<int, TFConcreteFunctionRevivalState> concrete_functions;
|
||||
gtl::FlatMap<int, TFSignatureDefFunctionRevivalState> signature_def_functions;
|
||||
gtl::FlatMap<int, RestoredResourceRevivalState> restored_resources;
|
||||
|
||||
Status Build(ImmediateExecutionContext* ctx,
|
||||
const SavedObjectGraph& obj_graph, RevivedObjects* revived);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
Status ExecuteNoArgDummyReturnFunction(TFConcreteFunction* func) {
|
||||
ImmediateOpPtr function_op;
|
||||
TF_RETURN_IF_ERROR(func->MakeCallOp({}, &function_op));
|
||||
|
||||
AbstractTensorHandle* dummy_output = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(function_op->Execute(
|
||||
absl::MakeSpan(&dummy_output, num_retvals), &num_retvals));
|
||||
AbstractTensorHandlePtr owned_dummy_output(dummy_output);
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
RestoredResource::RestoredResource(const std::string& device,
|
||||
TFConcreteFunction* create_resource,
|
||||
TFConcreteFunction* initialize,
|
||||
TFConcreteFunction* destroy_resource,
|
||||
ImmediateTensorHandlePtr resource_handle)
|
||||
: TensorHandleConvertible(std::move(resource_handle)),
|
||||
device_(device),
|
||||
create_resource_(create_resource),
|
||||
initialize_(initialize),
|
||||
destroy_resource_(destroy_resource) {}
|
||||
|
||||
Status RestoredResource::Initialize() const {
|
||||
return ExecuteNoArgDummyReturnFunction(initialize_);
|
||||
}
|
||||
|
||||
RestoredResource::~RestoredResource() {
|
||||
// Note(bmzhao): SavedModels saved before
|
||||
// https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3
|
||||
// did not have their destroy_resource function saved, meaning they will
|
||||
// leak resources.
|
||||
if (destroy_resource_ != nullptr) {
|
||||
Status status = ExecuteNoArgDummyReturnFunction(destroy_resource_);
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING)
|
||||
<< "Failed executing destroy_resource function for RestoredResource: "
|
||||
<< status.error_message();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,87 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// RestoredResource represents a TF2 "Resource" object loaded from a savedmodel,
|
||||
// analogous to the Python _RestoredResource object:
|
||||
// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/saved_model/load.py#L481
|
||||
// TF2 resource objects typically extend TrackableResource:
|
||||
// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/training/tracking/tracking.py#L285
|
||||
// and are expected to implement "_create_resource", "_initialize", and
|
||||
// "_destroy_resource" functions:
|
||||
// https://github.com/tensorflow/tensorflow/blob/139ba9c5284799beafdd1d7f895127cf00e7c48f/tensorflow/python/training/tracking/tracking.py#L262-L281
|
||||
class RestoredResource : TensorHandleConvertible {
|
||||
public:
|
||||
// Note(bmzhao): RestoredResource stores non-owning pointers to its associated
|
||||
// functions because SavedModel internally owns all functions and objects in
|
||||
// the RevivedObjects struct (which owns all functions). One alternative would
|
||||
// be to have RevivedObjects store shared_ptr<TFConcreteFunction> instead, and
|
||||
// change RestoredResource's constructor take shared_ptr<TFConcreteFunction>.
|
||||
// To keep things simple, I've stuck to raw pointers for now.
|
||||
//
|
||||
// Params:
|
||||
// device - The device string associated with the SavedResource
|
||||
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_object_graph.proto#L182
|
||||
// Conceptually, this is the same device used in CapturableResource:
|
||||
// https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L222-L225
|
||||
// Implementation-wise, it is device used when invoking the
|
||||
// create_resource function to produce the resource_handle
|
||||
// associated with the object:
|
||||
// https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L246-L247
|
||||
// create_resource - Non owning pointer to the create_resource function
|
||||
// associated with this object. Must be NON-NULL.
|
||||
// initialize - Non owning pointer to the initialize function associated with
|
||||
// this object. Must be NON-NULL.
|
||||
// destroy_resource - Non owning pointer to the destroy_resource function
|
||||
// associated with this object. Ideally this should be
|
||||
// NON-NULL, but in order to support models saved prior to
|
||||
// https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3
|
||||
// we allow null here. This will, however, leak resources.
|
||||
RestoredResource(const std::string& device,
|
||||
TFConcreteFunction* create_resource,
|
||||
TFConcreteFunction* initialize,
|
||||
TFConcreteFunction* destroy_resource,
|
||||
ImmediateTensorHandlePtr resource_handle);
|
||||
|
||||
Status Initialize() const;
|
||||
|
||||
// RestoredResource is movable, but not copyable.
|
||||
RestoredResource(RestoredResource&& other) = default;
|
||||
RestoredResource& operator=(RestoredResource&& other) = default;
|
||||
|
||||
~RestoredResource() override;
|
||||
|
||||
private:
|
||||
std::string device_;
|
||||
TFConcreteFunction* create_resource_;
|
||||
TFConcreteFunction* initialize_;
|
||||
TFConcreteFunction* destroy_resource_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_
|
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// All "Resources" should have these 3 saved functions:
|
||||
// https://github.com/tensorflow/tensorflow/blob/86dc281333d7d277ddc1882f2bca4b17e7ec40e5/tensorflow/python/training/tracking/tracking.py#L277-L281
|
||||
struct RestoredResourceRevivalState {
|
||||
std::string device;
|
||||
TFConcreteFunctionRevivalState* create_resource = nullptr;
|
||||
TFConcreteFunctionRevivalState* initialize = nullptr;
|
||||
TFConcreteFunctionRevivalState* destroy_resource = nullptr;
|
||||
ImmediateTensorHandlePtr resource_handle = nullptr;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_
|
@ -0,0 +1,51 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// RevivedObjects is mainly used as a container for all the "state" owned by
|
||||
// SavedModel. It stores all non-"user object" nodes from a SavedModel
|
||||
// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62)
|
||||
// in a "fully constructed" state. It is effectively a strongly typed map, where
|
||||
// each member is a map from the node id in the SavedObjectGraph's nodes
|
||||
// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29)
|
||||
// to the revived object of the corresponding type.
|
||||
struct RevivedObjects {
|
||||
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
|
||||
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
|
||||
gtl::FlatMap<int, std::unique_ptr<Constant>> constants;
|
||||
gtl::FlatMap<int, std::unique_ptr<TFConcreteFunction>> concrete_functions;
|
||||
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
|
||||
signature_def_functions;
|
||||
gtl::FlatMap<int, RestoredResource> restored_resources;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_
|
@ -0,0 +1,61 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// TFConcreteFunctionRevivalState wraps the state needed for building a
|
||||
// TF_ConcreteFunction. This is mainly used in PartiallyRevivedObjects, which
|
||||
// wraps partially constructed Function and Resource objects.
|
||||
struct TFConcreteFunctionRevivalState {
|
||||
// Index of the node in the SavedObjectGraph it was loaded from.
|
||||
int node_id;
|
||||
|
||||
// Pointer to the original functiondef. fdef_ is guaranteed to be
|
||||
// non-null.
|
||||
const FunctionDef* fdef;
|
||||
|
||||
// TensorHandle captures for this funtion
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures;
|
||||
|
||||
// SavedConcreteFunction contains much of the metadata of the expected "types"
|
||||
// of the inputs and outputs of a function.
|
||||
// Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null.
|
||||
const SavedConcreteFunction* saved_concrete_func;
|
||||
|
||||
// This field is only present on TF2 ConcreteFunctions, and is useful for
|
||||
// determining the original argument *names* of the function, (since the
|
||||
// "canonicalized_input_signature" may append extra uniquifying integers).
|
||||
// However, SavedBareConcreteFunctions do not have a FunctionSpec.
|
||||
// Note(bmzhao): if function_spec_.has_value(), *function_spec_ is guaranteed
|
||||
// to be non-null.
|
||||
absl::optional<const FunctionSpec*> function_spec;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// FunctionBuilder wraps the state needed for building a SignatureDefFunction.
|
||||
// This is mainly used in PartiallyRevivedObjects, which wraps partially
|
||||
// constructed Function and Resource objects.
|
||||
struct TFSignatureDefFunctionRevivalState {
|
||||
// Index of the node in the SavedObjectGraph it was loaded from.
|
||||
int node_id = 0;
|
||||
|
||||
// Pointer to the original functiondef. fdef_ is guaranteed to be
|
||||
// non-null.
|
||||
const FunctionDef* fdef = nullptr;
|
||||
|
||||
// SavedConcreteFunction contains much of the metadata of the expected "types"
|
||||
// of the inputs and outputs of a function.
|
||||
// Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null.
|
||||
const SavedConcreteFunction* saved_concrete_func = nullptr;
|
||||
|
||||
// The name of the SignatureDef key.
|
||||
std::string signature_key;
|
||||
|
||||
// TensorHandle captures for this funtion
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_
|
@ -17,15 +17,22 @@ limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
@ -41,6 +48,83 @@ namespace {
|
||||
using StructuredValueDictEntry =
|
||||
protobuf::MapPair<std::string, StructuredValue>;
|
||||
|
||||
// Maps from a Nodedef's name to its corresponding AttrValues, for a given
|
||||
// Graphdef
|
||||
using NodeAttrMap =
|
||||
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher>;
|
||||
|
||||
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
|
||||
using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>;
|
||||
|
||||
// Looks up a SavedConstant's associated tensorproto from the NodeAttrMap and
|
||||
// returns a tensorflow::Constant.
|
||||
Status ConstantFromSavedConstant(
|
||||
ImmediateExecutionContext* ctx,
|
||||
const tensorflow::SavedConstant& saved_constant,
|
||||
const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
|
||||
const std::string& const_op_name = saved_constant.operation();
|
||||
const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
|
||||
if (node_name_and_attrs == node_attr_map.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Unable to find Const operation with name'", const_op_name,
|
||||
"' in SavedModel graphdef");
|
||||
}
|
||||
const AttrValueMap* attrs = node_name_and_attrs->second;
|
||||
const auto& attr_name_and_value = attrs->find("value");
|
||||
if (attr_name_and_value == attrs->end()) {
|
||||
return errors::FailedPrecondition("Unable to find Const operation '",
|
||||
const_op_name, "'s value attribute");
|
||||
}
|
||||
const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
|
||||
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
|
||||
}
|
||||
|
||||
// Finds the "signatures" object in the object graph, and fills a mapping of
|
||||
// each signature's name to the corresponding function's node in the object
|
||||
// graph.
|
||||
Status GetSignaturesMap(const SavedObjectGraph& saved_objects,
|
||||
gtl::FlatMap<std::string, int>* signatures_map) {
|
||||
if (saved_objects.nodes().empty()) {
|
||||
return errors::FailedPrecondition("Saved Object Graph was empty.");
|
||||
}
|
||||
const SavedObject& root = saved_objects.nodes(0);
|
||||
const SavedObject* signatures = nullptr;
|
||||
for (const auto& child : root.children()) {
|
||||
if (child.local_name() == "signatures") {
|
||||
if (child.node_id() >= saved_objects.nodes().size()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Signature object had child node id ", child.node_id(),
|
||||
" which exceeds the size of the set of nodes");
|
||||
}
|
||||
signatures = &saved_objects.nodes(child.node_id());
|
||||
}
|
||||
}
|
||||
|
||||
// Some basic sanity checks that this object is actually our "signatures" map
|
||||
if (signatures == nullptr) {
|
||||
// This is where the "signatures" attribute is always set:
|
||||
// https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109
|
||||
return errors::FailedPrecondition(
|
||||
"SavedObjectGraph's root object must have a child 'signatures' object");
|
||||
}
|
||||
if (signatures->kind_case() != SavedObject::kUserObject) {
|
||||
return errors::FailedPrecondition(
|
||||
"Signatures must be a SavedObject of type UserObject.");
|
||||
}
|
||||
if (signatures->user_object().identifier() != "signature_map") {
|
||||
// This is where the string comes from:
|
||||
// https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220
|
||||
return errors::FailedPrecondition(
|
||||
"Signatures SavedObject must have identifier 'signature_map'.");
|
||||
}
|
||||
|
||||
for (const auto& child : signatures->children()) {
|
||||
(*signatures_map)[child.local_name()] = child.node_id();
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
// Perform some basic sanity checks on SavedConcreteFunction's input and
|
||||
// output signatures with respect to the corresponding FunctionDef's input
|
||||
// and output args.
|
||||
@ -99,6 +183,21 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) {
|
||||
// We only allow loading functions that have an annotated input signature,
|
||||
// which means there is 1:1 correspondence between tf.function
|
||||
// <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
|
||||
// the same restriction that MLIR has:
|
||||
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
|
||||
if (saved_function.concrete_functions_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
"Only tf.functions annotated with an input signature are supported "
|
||||
"by SavedModelAPI. This means that there should only be a single "
|
||||
"ConcreteFunction per tf.function");
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
|
||||
@ -255,21 +354,18 @@ absl::optional<int> FindNodeAtPath(StringPiece path,
|
||||
return node_id;
|
||||
}
|
||||
|
||||
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>
|
||||
NodeToAttrMap(const tensorflow::GraphDef& graphdef) {
|
||||
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>
|
||||
result;
|
||||
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap(
|
||||
const tensorflow::GraphDef& graphdef) {
|
||||
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> result;
|
||||
for (const tensorflow::NodeDef& node : graphdef.node()) {
|
||||
result[node.name()] = &node.attr();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>
|
||||
gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
|
||||
FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) {
|
||||
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>
|
||||
gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
|
||||
result;
|
||||
for (const FunctionDef& function_def : library.function()) {
|
||||
result[function_def.signature().name()] = &function_def;
|
||||
@ -277,5 +373,154 @@ FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) {
|
||||
return result;
|
||||
}
|
||||
|
||||
Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
|
||||
ImmediateExecutionContext* context,
|
||||
const std::string& directory,
|
||||
PartiallyRevivedObjects* objects) {
|
||||
// This is needed to restore "Constant" nodes by looking up their
|
||||
// "Value" attribute.
|
||||
NodeAttrMap node_attr_map = NodeToAttrMap(metagraph.graph_def());
|
||||
|
||||
// These are needed for creating "Assets", by looking up their filenames.
|
||||
std::vector<AssetFileDef> assets;
|
||||
TF_RETURN_IF_ERROR(GetAssetFileDefs(metagraph, &assets));
|
||||
|
||||
// Signatures are needed for determining whether a function is a
|
||||
// SignatureDefFunction or not.
|
||||
gtl::FlatMap<std::string, int> signatures_map;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetSignaturesMap(metagraph.object_graph_def(), &signatures_map));
|
||||
|
||||
gtl::FlatMap<int, std::string> reversed_signatures_map;
|
||||
reversed_signatures_map.reserve(signatures_map.size());
|
||||
for (const auto& signature_key_and_node : signatures_map) {
|
||||
reversed_signatures_map.emplace(signature_key_and_node.second,
|
||||
signature_key_and_node.first);
|
||||
}
|
||||
|
||||
// FunctionDefs are needed to help construct
|
||||
// TFConcreteFunction/SignatureDefFunctions
|
||||
const FunctionDefMap function_def_map =
|
||||
internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
|
||||
|
||||
// Iterate through all the saved objects, restoring objects (if we can) as we
|
||||
// go. For objects that dependencies on other objects (resources/functions),
|
||||
// we partially initialize "builders" that correspond to their currently known
|
||||
// state, and gradually fill them out in subsequent passes.
|
||||
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
|
||||
const SavedObject& node = metagraph.object_graph_def().nodes(i);
|
||||
if (node.kind_case() == SavedObject::kVariable) {
|
||||
std::unique_ptr<Variable> variable;
|
||||
TF_RETURN_IF_ERROR(
|
||||
LoadSavedVariable(context, node.variable(), &variable));
|
||||
objects->variables[i] = std::move(variable);
|
||||
} else if (node.kind_case() == SavedObject::kConstant) {
|
||||
std::unique_ptr<Constant> constant;
|
||||
TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
|
||||
node_attr_map, &constant));
|
||||
objects->constants[i] = std::move(constant);
|
||||
} else if (node.kind_case() == SavedObject::kAsset) {
|
||||
std::unique_ptr<Asset> asset;
|
||||
TF_RETURN_IF_ERROR(
|
||||
LoadSavedAsset(context, node.asset(), directory, assets, &asset));
|
||||
objects->assets[i] = std::move(asset);
|
||||
} else if (node.kind_case() == SavedObject::kResource) {
|
||||
RestoredResourceRevivalState resource_revival_state;
|
||||
// We'll set the resource's functions in a subsequent pass, once we get
|
||||
// all functions in a partially revived state.
|
||||
resource_revival_state.device = node.resource().device();
|
||||
objects->restored_resources[i] = std::move(resource_revival_state);
|
||||
} else if (node.kind_case() == SavedObject::kFunction) {
|
||||
// Get the SavedFunction node and validate it has a single concrete func.
|
||||
const SavedFunction& saved_function = node.function();
|
||||
TF_RETURN_IF_ERROR(ValidateSingleConcreteFunction(saved_function));
|
||||
|
||||
// Retrieve related function information.
|
||||
const std::string& function_name = saved_function.concrete_functions(0);
|
||||
const FunctionDef* function_def = function_def_map.at(function_name);
|
||||
const SavedConcreteFunction& saved_concrete_func =
|
||||
metagraph.object_graph_def().concrete_functions().at(function_name);
|
||||
const FunctionSpec& function_spec = saved_function.function_spec();
|
||||
|
||||
// Construct either a SignatureDefFunctionBuilder or a
|
||||
// ConcreteFunctionBuilder, depending on whether this node was a child
|
||||
// of the "signatures" attribute from root object.
|
||||
auto reverse_signature_iter = reversed_signatures_map.find(i);
|
||||
if (reverse_signature_iter != reversed_signatures_map.end()) {
|
||||
TFSignatureDefFunctionRevivalState func_revival_state;
|
||||
func_revival_state.node_id = i;
|
||||
func_revival_state.fdef = function_def;
|
||||
func_revival_state.saved_concrete_func = &saved_concrete_func;
|
||||
func_revival_state.signature_key = reverse_signature_iter->second;
|
||||
objects->signature_def_functions[i] = std::move(func_revival_state);
|
||||
} else {
|
||||
TFConcreteFunctionRevivalState func_revival_state;
|
||||
func_revival_state.node_id = i;
|
||||
func_revival_state.fdef = function_def;
|
||||
func_revival_state.saved_concrete_func = &saved_concrete_func;
|
||||
func_revival_state.function_spec = &function_spec;
|
||||
objects->concrete_functions[i] = std::move(func_revival_state);
|
||||
}
|
||||
} else if (node.kind_case() == SavedObject::kBareConcreteFunction) {
|
||||
const SavedBareConcreteFunction& bare_cf = node.bare_concrete_function();
|
||||
|
||||
// Retrieve related function information.
|
||||
const std::string& function_name = bare_cf.concrete_function_name();
|
||||
const FunctionDef* function_def = function_def_map.at(function_name);
|
||||
const SavedConcreteFunction& saved_concrete_func =
|
||||
metagraph.object_graph_def().concrete_functions().at(function_name);
|
||||
|
||||
// Check whether this is a SignatureDefFunction, or not.
|
||||
auto reverse_signature_iter = reversed_signatures_map.find(i);
|
||||
if (reverse_signature_iter != reversed_signatures_map.end()) {
|
||||
TFSignatureDefFunctionRevivalState func_revival_state;
|
||||
func_revival_state.node_id = i;
|
||||
func_revival_state.fdef = function_def;
|
||||
func_revival_state.saved_concrete_func = &saved_concrete_func;
|
||||
func_revival_state.signature_key = reverse_signature_iter->second;
|
||||
objects->signature_def_functions[i] = std::move(func_revival_state);
|
||||
} else {
|
||||
TFConcreteFunctionRevivalState func_revival_state;
|
||||
func_revival_state.node_id = i;
|
||||
func_revival_state.fdef = function_def;
|
||||
func_revival_state.saved_concrete_func = &saved_concrete_func;
|
||||
objects->concrete_functions[i] = std::move(func_revival_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we've partially restored all functions, we can have resources
|
||||
// point to them
|
||||
for (auto& node_and_resource_revival_state : objects->restored_resources) {
|
||||
int node_id = node_and_resource_revival_state.first;
|
||||
const SavedObjectGraph& obj_graph = metagraph.object_graph_def();
|
||||
const SavedObject& node = obj_graph.nodes(node_id);
|
||||
RestoredResourceRevivalState& resource =
|
||||
node_and_resource_revival_state.second;
|
||||
for (const TrackableObjectGraph::TrackableObject::ObjectReference& child :
|
||||
node.children()) {
|
||||
int child_node_id = child.node_id();
|
||||
// Note(bmzhao): The expected functions saved by a resource object are:
|
||||
// "_create_resource", "_initialize", and "_destroy_resource".
|
||||
// https://github.com/tensorflow/tensorflow/blob/ad66f588c1666ade8051feb42811fa27b285271c/tensorflow/python/training/tracking/tracking.py#L277-L281
|
||||
if (child.local_name() == "_create_resource" &&
|
||||
obj_graph.nodes(child.node_id()).kind_case() ==
|
||||
SavedObject::kFunction) {
|
||||
resource.create_resource = &objects->concrete_functions[child_node_id];
|
||||
} else if (child.local_name() == "_initialize" &&
|
||||
obj_graph.nodes(child.node_id()).kind_case() ==
|
||||
SavedObject::kFunction) {
|
||||
resource.initialize = &objects->concrete_functions[child_node_id];
|
||||
} else if (child.local_name() == "_destroy_resource" &&
|
||||
obj_graph.nodes(child.node_id()).kind_case() ==
|
||||
SavedObject::kFunction) {
|
||||
resource.destroy_resource = &objects->concrete_functions[child_node_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
@ -27,10 +27,12 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
@ -84,15 +86,22 @@ absl::optional<int> FindNodeAtPath(StringPiece path,
|
||||
|
||||
// Maps each node in `graphdef` to its corresponding Attribute Map.
|
||||
// Callers must ensure that `graphdef` outlives the returned map.
|
||||
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>
|
||||
NodeToAttrMap(const tensorflow::GraphDef& graphdef);
|
||||
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap(
|
||||
const tensorflow::GraphDef& graphdef);
|
||||
|
||||
// Maps the name of each FunctionDef in `library` to its corresponding
|
||||
// FunctionDef. Callers must ensure `library` outlives the returned map.
|
||||
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>
|
||||
gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
|
||||
FunctionNameToFunctionDefMap(const FunctionDefLibrary& library);
|
||||
|
||||
// Walks through the SavedObjectGraph in metagraph, and restores all nodes
|
||||
// (except "UserDefinedObjects") with their corresponding type in
|
||||
// "PartiallyRevivedObjects".
|
||||
Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
|
||||
ImmediateExecutionContext* context,
|
||||
const std::string& directory,
|
||||
PartiallyRevivedObjects* objects);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
@ -30,6 +29,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
@ -37,7 +39,6 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
|
||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -46,6 +47,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
@ -62,142 +64,15 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
|
||||
using FunctionDefMap =
|
||||
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>;
|
||||
|
||||
// Maps from a Nodedef's name to its corresponding AttrValues, for a given
|
||||
// Graphdef
|
||||
using NodeAttrMap =
|
||||
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>;
|
||||
|
||||
// Maps from Node ID to an "Revived Object" implementing
|
||||
// "TensorHandleConvertible"
|
||||
using RevivedObjectMap =
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>;
|
||||
using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
|
||||
StringPieceHasher>;
|
||||
|
||||
// Maps from a functiondef's name to the corresponding "TFConcreteFunction"
|
||||
using ConcreteFunctionMap =
|
||||
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>;
|
||||
using FlatTensorFunctionMap =
|
||||
gtl::FlatMap<std::string, std::unique_ptr<FlatTensorFunction>>;
|
||||
|
||||
namespace {
|
||||
|
||||
Status ConstantFromSavedConstant(
|
||||
ImmediateExecutionContext* ctx,
|
||||
const tensorflow::SavedConstant& saved_constant,
|
||||
const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
|
||||
const std::string& const_op_name = saved_constant.operation();
|
||||
const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
|
||||
if (node_name_and_attrs == node_attr_map.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Unable to find Const operation with name'", const_op_name,
|
||||
"' in SavedModel graphdef");
|
||||
}
|
||||
const AttrValueMap* attrs = node_name_and_attrs->second;
|
||||
const auto& attr_name_and_value = attrs->find("value");
|
||||
if (attr_name_and_value == attrs->end()) {
|
||||
return errors::FailedPrecondition("Unable to find Const operation '",
|
||||
const_op_name, "'s value attribute");
|
||||
}
|
||||
const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
|
||||
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
|
||||
}
|
||||
|
||||
// Restores all non-function objects in the SavedModel's object graph.
|
||||
// This function walks through the metagraph's saved object graph, and
|
||||
// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and
|
||||
// SavedResources. These are returned via the `out` parameter.
|
||||
Status ReviveObjects(
|
||||
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
|
||||
const std::string& directory,
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
|
||||
revived_objects) {
|
||||
// This is needed to restore "Constant" nodes by looking up their
|
||||
// "Value" attribute.
|
||||
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
|
||||
|
||||
// These are needed for creating "Assets", by looking up their filenames.
|
||||
std::vector<AssetFileDef> assets;
|
||||
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(metagraph, &assets));
|
||||
|
||||
// Iterate through all the saved objects, restoring objects as we go.
|
||||
// We don't recreate functions until all other objects have been created.
|
||||
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
|
||||
const SavedObject& node = metagraph.object_graph_def().nodes(i);
|
||||
if (node.kind_case() == SavedObject::kVariable) {
|
||||
std::unique_ptr<Variable> variable;
|
||||
TF_RETURN_IF_ERROR(
|
||||
internal::LoadSavedVariable(context, node.variable(), &variable));
|
||||
(*revived_objects)[i] = std::move(variable);
|
||||
} else if (node.kind_case() == SavedObject::kConstant) {
|
||||
std::unique_ptr<Constant> constant;
|
||||
TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
|
||||
node_attr_map, &constant));
|
||||
(*revived_objects)[i] = std::move(constant);
|
||||
} else if (node.kind_case() == SavedObject::kAsset) {
|
||||
std::unique_ptr<Asset> asset;
|
||||
TF_RETURN_IF_ERROR(internal::LoadSavedAsset(context, node.asset(),
|
||||
directory, assets, &asset));
|
||||
(*revived_objects)[i] = std::move(asset);
|
||||
} else if (node.kind_case() == SavedObject::kResource) {
|
||||
// TODO(bmzhao): Figure out how resource loading works and implement it
|
||||
return errors::Unimplemented(
|
||||
"SavedResource loading is not implemented yet");
|
||||
}
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status ReviveFunctions(const MetaGraphDef& metagraph,
|
||||
const RevivedObjectMap& revived_objects,
|
||||
ImmediateExecutionContext* context,
|
||||
ConcreteFunctionMap* restored_functions) {
|
||||
const FunctionDefMap function_def_map =
|
||||
internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
|
||||
|
||||
// Iterate through all objects, only examining functions.
|
||||
for (const SavedObject& node : metagraph.object_graph_def().nodes()) {
|
||||
if (node.kind_case() == SavedObject::kBareConcreteFunction) {
|
||||
const std::string& function_name =
|
||||
node.bare_concrete_function().concrete_function_name();
|
||||
|
||||
const SavedConcreteFunction& saved_concrete_function =
|
||||
metagraph.object_graph_def().concrete_functions().at(function_name);
|
||||
|
||||
const FunctionDef* function_def = function_def_map.at(function_name);
|
||||
std::unique_ptr<TFConcreteFunction> concrete_function;
|
||||
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
|
||||
saved_concrete_function, function_def, revived_objects, context,
|
||||
&concrete_function));
|
||||
(*restored_functions)[function_name] = std::move(concrete_function);
|
||||
} else if (node.kind_case() == SavedObject::kFunction) {
|
||||
// We only allow loading functions that have an annotated input signature,
|
||||
// which means there is 1:1 correspondence between tf.function
|
||||
// <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
|
||||
// the same restriction that MLIR has:
|
||||
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
|
||||
const SavedFunction& saved_function = node.function();
|
||||
if (saved_function.concrete_functions_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
"Only tf.functions annotated with an input signature are supported "
|
||||
"by SavedModelAPI. This means that there should only be a single "
|
||||
"ConcreteFunction per tf.function");
|
||||
}
|
||||
const std::string& function_name = saved_function.concrete_functions(0);
|
||||
const SavedConcreteFunction& saved_concrete_function =
|
||||
metagraph.object_graph_def().concrete_functions().at(function_name);
|
||||
|
||||
const FunctionDef* function_def = function_def_map.at(function_name);
|
||||
|
||||
std::unique_ptr<TFConcreteFunction> concrete_function;
|
||||
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
|
||||
saved_concrete_function, function_def, revived_objects, context,
|
||||
&concrete_function));
|
||||
(*restored_functions)[function_name] = std::move(concrete_function);
|
||||
}
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
const TrackableObjectGraph::TrackableObject::SerializedTensor*
|
||||
FindSerializedTensorInTrackable(
|
||||
@ -234,7 +109,7 @@ FindSerializedTensorInTrackable(
|
||||
// overridden "restore" method:
|
||||
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
|
||||
Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
|
||||
const RevivedObjectMap& revived_objects,
|
||||
const RevivedObjects& revived_objects,
|
||||
const std::string& directory,
|
||||
ImmediateExecutionContext* context) {
|
||||
// TODO(bmzhao): Batch up all the restores into a single restore op per
|
||||
@ -254,8 +129,7 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Variable* variable =
|
||||
down_cast<Variable*>(revived_objects.at(node).get());
|
||||
Variable* variable = revived_objects.variables.at(node).get();
|
||||
|
||||
// Restore the tensor's value from the checkpoint
|
||||
const TrackableObjectGraph::TrackableObject::SerializedTensor*
|
||||
@ -289,6 +163,14 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status InitializeAllResources(const RevivedObjects& revived) {
|
||||
for (const auto& node_and_resource : revived.restored_resources) {
|
||||
const RestoredResource& resource = node_and_resource.second;
|
||||
TF_RETURN_IF_ERROR(resource.Initialize());
|
||||
}
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status TFSavedModelAPI::GetFunction(const std::string& function_path,
|
||||
@ -299,20 +181,12 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path,
|
||||
return errors::NotFound("No saved object found at path ", function_path);
|
||||
}
|
||||
|
||||
const SavedObject& object = bundle_.saved_object_graph().nodes(*node);
|
||||
if (object.kind_case() == SavedObject::kBareConcreteFunction) {
|
||||
*function =
|
||||
concrete_functions_
|
||||
.at(object.bare_concrete_function().concrete_function_name())
|
||||
.get();
|
||||
} else if (object.kind_case() == SavedObject::kFunction) {
|
||||
*function =
|
||||
concrete_functions_.at(object.function().concrete_functions(0)).get();
|
||||
} else {
|
||||
return errors::InvalidArgument(function_path,
|
||||
" is not a path to a Function.");
|
||||
auto function_iter = revived_objects_.concrete_functions.find(*node);
|
||||
if (function_iter == revived_objects_.concrete_functions.end()) {
|
||||
return errors::NotFound("No function found at path ", function_path);
|
||||
}
|
||||
|
||||
*function = function_iter->second.get();
|
||||
return Status();
|
||||
}
|
||||
|
||||
@ -325,8 +199,8 @@ Status TFSavedModelAPI::GetSignatureDefFunction(
|
||||
|
||||
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
||||
std::vector<ConcreteFunction*> result;
|
||||
result.reserve(concrete_functions_.size());
|
||||
for (auto& index_and_function : concrete_functions_) {
|
||||
result.reserve(revived_objects_.concrete_functions.size());
|
||||
for (auto& index_and_function : revived_objects_.concrete_functions) {
|
||||
result.push_back(index_and_function.second.get());
|
||||
}
|
||||
return result;
|
||||
@ -340,34 +214,21 @@ Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
|
||||
return errors::NotFound("No saved object found at path ", variable_path);
|
||||
}
|
||||
|
||||
const SavedObject& object = bundle_.saved_object_graph().nodes(*node);
|
||||
|
||||
if (object.kind_case() == SavedObject::kVariable) {
|
||||
auto iter = revived_objects_.find(*node);
|
||||
if (iter == revived_objects_.end()) {
|
||||
return errors::Internal("Variable ", variable_path,
|
||||
" was not properly revived.");
|
||||
}
|
||||
*variable = static_cast<Variable*>(iter->second.get());
|
||||
return Status();
|
||||
auto variables_iter = revived_objects_.variables.find(*node);
|
||||
if (variables_iter == revived_objects_.variables.end()) {
|
||||
return errors::NotFound("No variable found at path ", variable_path);
|
||||
}
|
||||
|
||||
*variable = nullptr;
|
||||
return errors::InvalidArgument(
|
||||
variable_path, " is not a path to a Variable (kind=", object.kind_case(),
|
||||
")");
|
||||
*variable = variables_iter->second.get();
|
||||
return Status();
|
||||
}
|
||||
|
||||
TFSavedModelAPI::TFSavedModelAPI(
|
||||
const std::string& directory, SavedModelV2Bundle bundle,
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||
revived_objects,
|
||||
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
|
||||
concrete_functions)
|
||||
TFSavedModelAPI::TFSavedModelAPI(const std::string& directory,
|
||||
SavedModelV2Bundle bundle,
|
||||
RevivedObjects revived_objects)
|
||||
: directory_(directory),
|
||||
bundle_(std::move(bundle)),
|
||||
revived_objects_(std::move(revived_objects)),
|
||||
concrete_functions_(std::move(concrete_functions)) {}
|
||||
revived_objects_(std::move(revived_objects)) {}
|
||||
|
||||
Status TFSavedModelAPI::Load(
|
||||
const std::string& directory,
|
||||
@ -388,28 +249,25 @@ Status TFSavedModelAPI::Load(
|
||||
// This occurs in python here:
|
||||
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
|
||||
|
||||
RevivedObjectMap revived_objects;
|
||||
TF_RETURN_IF_ERROR(ReviveObjects(bundle.meta_graph_def(), context, directory,
|
||||
&revived_objects));
|
||||
// Step 1: For each node in the graph, we should initialize an object of the
|
||||
// corresponding type. For objects that depend on the initialization of other
|
||||
// objects (like functions which capture resources), we will initialize them
|
||||
// in step 2.
|
||||
PartiallyRevivedObjects partially_revived_objects;
|
||||
TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
|
||||
bundle.meta_graph_def(), context, directory, &partially_revived_objects));
|
||||
|
||||
// TODO(bmzhao): When we later add support for loading resources, we need to
|
||||
// handle the case where materializing a function's captures requires invoking
|
||||
// other functions. This occurs when retrieving the resource handle for a
|
||||
// TrackableResource:
|
||||
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
|
||||
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
|
||||
// This requires restoring functions in a topological sort order by capture
|
||||
// dependencies.
|
||||
ConcreteFunctionMap function_map;
|
||||
TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects,
|
||||
context, &function_map));
|
||||
RevivedObjects revived_objects;
|
||||
TF_RETURN_IF_ERROR(partially_revived_objects.Build(
|
||||
context, bundle.saved_object_graph(), &revived_objects));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
RestoreCheckpoint(&bundle, revived_objects, directory, context));
|
||||
|
||||
TF_RETURN_IF_ERROR(InitializeAllResources(revived_objects));
|
||||
|
||||
out->reset(new TFSavedModelAPI(directory, std::move(bundle),
|
||||
std::move(revived_objects),
|
||||
std::move(function_map)));
|
||||
std::move(revived_objects)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
@ -72,19 +73,12 @@ class TFSavedModelAPI : public SavedModelAPI {
|
||||
Status GetVariable(const std::string& variable_path, Variable** variable);
|
||||
|
||||
private:
|
||||
TFSavedModelAPI(
|
||||
const std::string& directory, SavedModelV2Bundle bundle,
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||
revived_objects,
|
||||
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
|
||||
concrete_functions);
|
||||
TFSavedModelAPI(const std::string& directory, SavedModelV2Bundle bundle,
|
||||
RevivedObjects revived_objects);
|
||||
|
||||
std::string directory_;
|
||||
SavedModelV2Bundle bundle_;
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||
revived_objects_;
|
||||
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
|
||||
concrete_functions_;
|
||||
RevivedObjects revived_objects_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
@ -196,6 +197,142 @@ TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsStaticHashtableSavedModel) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("StaticHashTableModule");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TF_ConcreteFunction* lookup_fn =
|
||||
TF_GetSavedModelConcreteFunction(saved_model, "lookup", status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// Note(bmzhao): Based on static_hashtable_asset.txt, we expect the following
|
||||
// mapping:
|
||||
// "foo" -> 0
|
||||
// "bar" -> 1
|
||||
// "baz" -> 2
|
||||
// "wombat" -> 3
|
||||
// all other strings -> -1
|
||||
|
||||
// Call lookup function with input "foo", expecting an output of 0
|
||||
{
|
||||
std::vector<TFE_TensorHandle*> lookup_fn_inputs;
|
||||
TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("foo"));
|
||||
lookup_fn_inputs.push_back(input_foo);
|
||||
|
||||
TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
|
||||
lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
|
||||
// inputs + outputs a function has.
|
||||
TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
|
||||
TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(TF_NumDims(result), 0);
|
||||
tensorflow::int64* output_value =
|
||||
static_cast<tensorflow::int64*>(TF_TensorData(result));
|
||||
EXPECT_EQ(*output_value, 0);
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(input_foo);
|
||||
TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
|
||||
TFE_DeleteOp(lookup_op);
|
||||
}
|
||||
|
||||
// Call lookup function with input "baz", expecting an output of 2
|
||||
{
|
||||
std::vector<TFE_TensorHandle*> lookup_fn_inputs;
|
||||
TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("baz"));
|
||||
lookup_fn_inputs.push_back(input_foo);
|
||||
|
||||
TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
|
||||
lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
|
||||
// inputs + outputs a function has.
|
||||
TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
|
||||
TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(TF_NumDims(result), 0);
|
||||
tensorflow::int64* output_value =
|
||||
static_cast<tensorflow::int64*>(TF_TensorData(result));
|
||||
EXPECT_EQ(*output_value, 2);
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(input_foo);
|
||||
TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
|
||||
TFE_DeleteOp(lookup_op);
|
||||
}
|
||||
|
||||
// Call lookup function w/input "NON-EXISTENT-KEY", expecting an output of -1
|
||||
{
|
||||
std::vector<TFE_TensorHandle*> lookup_fn_inputs;
|
||||
TFE_TensorHandle* input_foo =
|
||||
TestScalarTensorHandle(ctx, tstring("NON-EXISTENT-KEY"));
|
||||
lookup_fn_inputs.push_back(input_foo);
|
||||
|
||||
TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
|
||||
lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
|
||||
// inputs + outputs a function has.
|
||||
TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
|
||||
TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(TF_NumDims(result), 0);
|
||||
tensorflow::int64* output_value =
|
||||
static_cast<tensorflow::int64*>(TF_TensorData(result));
|
||||
EXPECT_EQ(*output_value, -1);
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(input_foo);
|
||||
TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
|
||||
TFE_DeleteOp(lookup_op);
|
||||
}
|
||||
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
|
@ -213,6 +213,7 @@ py_binary(
|
||||
srcs = ["testdata/generate_saved_models.py"],
|
||||
data = [
|
||||
":saved_model_asset_data",
|
||||
":saved_model_static_hashtable_asset_data",
|
||||
],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
@ -220,6 +221,7 @@ py_binary(
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:lookup_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
@ -243,6 +245,7 @@ filegroup(
|
||||
"testdata/half_plus_two_v2/**",
|
||||
"testdata/x_plus_y_v2_debuginfo/**",
|
||||
"testdata/CyclicModule/**",
|
||||
"testdata/StaticHashTableModule/**",
|
||||
"testdata/VarsAndArithmeticObjectGraph/**",
|
||||
"testdata/fuzz_generated/**",
|
||||
]),
|
||||
@ -260,6 +263,13 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "saved_model_static_hashtable_asset_data",
|
||||
srcs = [
|
||||
"testdata/static_hashtable_asset.txt",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
glob([
|
||||
"testdata/half_plus_two_pbtxt/**",
|
||||
|
4
tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt
vendored
Normal file
4
tensorflow/cc/saved_model/testdata/StaticHashTableModule/assets/static_hashtable_asset.txt
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
foo
|
||||
bar
|
||||
baz
|
||||
wombat
|
BIN
tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/StaticHashTableModule/saved_model.pb
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.data-00000-of-00001
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/StaticHashTableModule/variables/variables.index
vendored
Normal file
Binary file not shown.
@ -30,6 +30,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import save_options
|
||||
@ -82,10 +83,31 @@ class AssetModule(module.Module):
|
||||
return io_ops.read_file(self.asset)
|
||||
|
||||
|
||||
class StaticHashTableModule(module.Module):
|
||||
"""A module with an Asset, StaticHashTable, and a lookup function."""
|
||||
|
||||
def __init__(self):
|
||||
self.asset = tracking.Asset(
|
||||
test.test_src_dir_path(
|
||||
"cc/saved_model/testdata/static_hashtable_asset.txt"))
|
||||
self.table = lookup_ops.StaticHashTable(
|
||||
lookup_ops.TextFileInitializer(self.asset, dtypes.string,
|
||||
lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||
dtypes.int64,
|
||||
lookup_ops.TextFileIndex.LINE_NUMBER),
|
||||
-1)
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec(shape=None, dtype=dtypes.string)])
|
||||
def lookup(self, word):
|
||||
return self.table.lookup(word)
|
||||
|
||||
|
||||
MODULE_CTORS = {
|
||||
"VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph,
|
||||
"CyclicModule": CyclicModule,
|
||||
"AssetModule": AssetModule,
|
||||
"StaticHashTableModule": StaticHashTableModule,
|
||||
}
|
||||
|
||||
|
||||
|
4
tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt
vendored
Normal file
4
tensorflow/cc/saved_model/testdata/static_hashtable_asset.txt
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
foo
|
||||
bar
|
||||
baz
|
||||
wombat
|
Loading…
Reference in New Issue
Block a user