Experimental implementation of saved model C API. This supports loading constants, unpartitioned variables, and tf.functions annotated with an input signature from a TF 2 saved model. See RFC 207: https://github.com/tensorflow/community/pull/207
Future CLs will flesh out some of the missing pieces (specifying a tag, loading models with resources/assets, batching tensor restores per device, etc). PiperOrigin-RevId: 321262527 Change-Id: I76dbe4617acd7bdbe5f093e6e22b328842a65780
This commit is contained in:
parent
0aa0c33c58
commit
bfeba4d9ec
@ -113,8 +113,23 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":concrete_function",
|
":concrete_function",
|
||||||
":saved_model_api",
|
":saved_model_api",
|
||||||
|
":saved_model_utils",
|
||||||
|
"//tensorflow/c:tensor_interface",
|
||||||
"//tensorflow/c/eager:immediate_execution_context",
|
"//tensorflow/c/eager:immediate_execution_context",
|
||||||
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||||
|
"//tensorflow/cc/saved_model:bundle_v2",
|
||||||
|
"//tensorflow/cc/saved_model:constants",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -15,47 +15,360 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
|
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/strings/str_split.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||||
|
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||||
|
#include "tensorflow/cc/saved_model/constants.h"
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
#include "tensorflow/core/platform/path.h"
|
||||||
|
#include "tensorflow/core/platform/stringpiece.h"
|
||||||
|
#include "tensorflow/core/platform/tstring.h"
|
||||||
|
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/saved_model.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
|
||||||
|
using FunctionDefMap =
|
||||||
|
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
|
||||||
|
StringPieceHasher>;
|
||||||
|
|
||||||
|
// Maps from a Nodedef's name to its corresponding AttrValues, for a given
|
||||||
|
// Graphdef
|
||||||
|
using NodeAttrMap =
|
||||||
|
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>;
|
||||||
|
|
||||||
|
// Maps from Node ID to an "Revived Object" implementing
|
||||||
|
// "TensorHandleConvertible"
|
||||||
|
using RevivedObjectMap =
|
||||||
|
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>;
|
||||||
|
|
||||||
|
// Maps from a functiondef's name to the corresponding "TFConcreteFunction"
|
||||||
|
using ConcreteFunctionMap =
|
||||||
|
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
Status ConstantFromSavedConstant(
|
||||||
|
ImmediateExecutionContext* ctx,
|
||||||
|
const tensorflow::SavedConstant& saved_constant,
|
||||||
|
const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
|
||||||
|
const std::string& const_op_name = saved_constant.operation();
|
||||||
|
const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
|
||||||
|
if (node_name_and_attrs == node_attr_map.end()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Unable to find Const operation with name'", const_op_name,
|
||||||
|
"' in SavedModel graphdef");
|
||||||
|
}
|
||||||
|
const AttrValueMap* attrs = node_name_and_attrs->second;
|
||||||
|
const auto& attr_name_and_value = attrs->find("value");
|
||||||
|
if (attr_name_and_value == attrs->end()) {
|
||||||
|
return errors::FailedPrecondition("Unable to find Const operation '",
|
||||||
|
const_op_name, "'s value attribute");
|
||||||
|
}
|
||||||
|
const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
|
||||||
|
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restores all non-function objects in the SavedModel's object graph.
|
||||||
|
// This function walks through the metagraph's saved object graph, and
|
||||||
|
// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and
|
||||||
|
// SavedResources. These are returned via the `out` parameter.
|
||||||
|
Status ReviveObjects(
|
||||||
|
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
|
||||||
|
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
|
||||||
|
revived_objects) {
|
||||||
|
// This is needed to restore "Constant" nodes by looking up their
|
||||||
|
// "Value" attribute.
|
||||||
|
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
|
||||||
|
|
||||||
|
// Iterate through all the saved objects, restoring objects as we go.
|
||||||
|
// We don't recreate functions until all other objects have been created.
|
||||||
|
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
|
||||||
|
const SavedObject& node = metagraph.object_graph_def().nodes(i);
|
||||||
|
if (node.kind_case() == SavedObject::kVariable) {
|
||||||
|
std::unique_ptr<Variable> variable;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
internal::LoadSavedVariable(context, node.variable(), &variable));
|
||||||
|
(*revived_objects)[i] = std::move(variable);
|
||||||
|
} else if (node.kind_case() == SavedObject::kConstant) {
|
||||||
|
std::unique_ptr<Constant> constant;
|
||||||
|
TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
|
||||||
|
node_attr_map, &constant));
|
||||||
|
(*revived_objects)[i] = std::move(constant);
|
||||||
|
} else if (node.kind_case() == SavedObject::kAsset) {
|
||||||
|
// TODO(bmzhao): Implement Asset C++ class. This should be just recreating
|
||||||
|
// the full path to the asset file:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396
|
||||||
|
// and storing it as a string tensor:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325
|
||||||
|
return errors::Unimplemented("SavedAsset loading is not implemented yet");
|
||||||
|
} else if (node.kind_case() == SavedObject::kResource) {
|
||||||
|
// TODO(bmzhao): Figure out how resource loading works and implement it
|
||||||
|
return errors::Unimplemented(
|
||||||
|
"SavedResource loading is not implemented yet");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ReviveFunctions(const MetaGraphDef& metagraph,
|
||||||
|
const RevivedObjectMap& revived_objects,
|
||||||
|
ImmediateExecutionContext* context,
|
||||||
|
ConcreteFunctionMap* restored_functions) {
|
||||||
|
const FunctionDefMap function_def_map =
|
||||||
|
internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
|
||||||
|
|
||||||
|
// Iterate through all objects, only examining functions.
|
||||||
|
for (const SavedObject& node : metagraph.object_graph_def().nodes()) {
|
||||||
|
if (node.kind_case() == SavedObject::kBareConcreteFunction) {
|
||||||
|
const std::string& function_name =
|
||||||
|
node.bare_concrete_function().concrete_function_name();
|
||||||
|
|
||||||
|
const SavedConcreteFunction& saved_concrete_function =
|
||||||
|
metagraph.object_graph_def().concrete_functions().at(function_name);
|
||||||
|
|
||||||
|
const FunctionDef* function_def = function_def_map.at(function_name);
|
||||||
|
std::unique_ptr<TFConcreteFunction> concrete_function;
|
||||||
|
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
|
||||||
|
saved_concrete_function, function_def, revived_objects, context,
|
||||||
|
&concrete_function));
|
||||||
|
(*restored_functions)[function_name] = std::move(concrete_function);
|
||||||
|
} else if (node.kind_case() == SavedObject::kFunction) {
|
||||||
|
// We only allow loading functions that have an annotated input signature,
|
||||||
|
// which means there is 1:1 correspondence between tf.function
|
||||||
|
// <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
|
||||||
|
// the same restriction that MLIR has:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
|
||||||
|
const SavedFunction& saved_function = node.function();
|
||||||
|
if (saved_function.concrete_functions_size() != 1) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Only tf.functions annotated with an input signature are supported "
|
||||||
|
"by SavedModelAPI. This means that there should only be a single "
|
||||||
|
"ConcreteFunction per tf.function");
|
||||||
|
}
|
||||||
|
const std::string& function_name = saved_function.concrete_functions(0);
|
||||||
|
const SavedConcreteFunction& saved_concrete_function =
|
||||||
|
metagraph.object_graph_def().concrete_functions().at(function_name);
|
||||||
|
|
||||||
|
const FunctionDef* function_def = function_def_map.at(function_name);
|
||||||
|
|
||||||
|
std::unique_ptr<TFConcreteFunction> concrete_function;
|
||||||
|
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
|
||||||
|
saved_concrete_function, function_def, revived_objects, context,
|
||||||
|
&concrete_function));
|
||||||
|
(*restored_functions)[function_name] = std::move(concrete_function);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
const TrackableObjectGraph::TrackableObject::SerializedTensor*
|
||||||
|
FindSerializedTensorInTrackable(
|
||||||
|
const TrackableObjectGraph::TrackableObject& trackable_object,
|
||||||
|
absl::string_view name) {
|
||||||
|
for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
|
||||||
|
if (maybe_serialized_tensor.name() == name) {
|
||||||
|
return &maybe_serialized_tensor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function reads the Checkpoint embedded in the SavedModel, and calls the
|
||||||
|
// appropriate Restore ops on each of the variables.
|
||||||
|
// Note(bmzhao): Conceptually, objects that contain checkpointable state
|
||||||
|
// implement the "_gather_saveables_for_checkpoint" method
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983
|
||||||
|
// which returns a dict of string key -> EITHER:
|
||||||
|
// 1. python callable (taking a checkpoint key) returning SaveableObject OR
|
||||||
|
// 2. variable (partitioned/resource/reference or otherwise)
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58.
|
||||||
|
// The string key becomes the "name" attribute of the SerializedTensor proto
|
||||||
|
// in the TrackableObjectGraph,
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26
|
||||||
|
// And the checkpoint_key is a globally unique string derived from this name:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241
|
||||||
|
// SaveableObjects model the information needed to pass to the SaveV2/RestoreV2
|
||||||
|
// ops via their SaveSpec members
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21,
|
||||||
|
// which contain the "real" checkpoint keys into the TensorBundle SSTable.
|
||||||
|
// They also contain the logic needed to take the restored tensors from
|
||||||
|
// RestoreV2 and load them back into the "object" they came from via their
|
||||||
|
// overridden "restore" method:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
|
||||||
|
Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
|
||||||
|
const RevivedObjectMap& revived_objects,
|
||||||
|
const std::string& directory,
|
||||||
|
ImmediateExecutionContext* context) {
|
||||||
|
// TODO(bmzhao): Batch up all the restores into a single restore op per
|
||||||
|
// device, following logic in MultiDeviceSaver.
|
||||||
|
TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore(
|
||||||
|
[&revived_objects, &directory, context, bundle](
|
||||||
|
int node, const TrackableObjectGraph::TrackableObject& trackable) {
|
||||||
|
if (bundle->saved_object_graph().nodes(node).kind_case() !=
|
||||||
|
SavedObject::kVariable) {
|
||||||
|
// TODO(bmzhao): This requires using the newly added Save/Restore
|
||||||
|
// functions from
|
||||||
|
// https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
|
||||||
|
return errors::Unimplemented(
|
||||||
|
"Restoring non-variable objects has not been implemented yet. ");
|
||||||
|
}
|
||||||
|
|
||||||
|
Variable* variable =
|
||||||
|
down_cast<Variable*>(revived_objects.at(node).get());
|
||||||
|
|
||||||
|
// Restore the tensor's value from the checkpoint
|
||||||
|
const TrackableObjectGraph::TrackableObject::SerializedTensor*
|
||||||
|
attribute =
|
||||||
|
FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE");
|
||||||
|
if (attribute == nullptr) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Could not find SerializedTensor with name VARIABLE_VALUE for "
|
||||||
|
"saved variable");
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string& checkpoint_key = attribute->checkpoint_key();
|
||||||
|
std::string variables_path_prefix =
|
||||||
|
io::JoinPath(directory, kSavedModelVariablesDirectory,
|
||||||
|
kSavedModelVariablesFilename);
|
||||||
|
ImmediateTensorHandlePtr restored_output;
|
||||||
|
TF_RETURN_IF_ERROR(internal::SingleRestore(
|
||||||
|
context, variables_path_prefix, checkpoint_key, variable->dtype(),
|
||||||
|
&restored_output));
|
||||||
|
|
||||||
|
// Assign the restored tensor's value to the variable
|
||||||
|
return variable->Assign(restored_output.get());
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
Status TFSavedModelAPI::GetFunction(const std::string& function_path,
|
Status TFSavedModelAPI::GetFunction(const std::string& function_path,
|
||||||
ConcreteFunction** function) {
|
ConcreteFunction** function) {
|
||||||
// TODO(bmzhao): Add support for retrieving a function.
|
const SavedObject* object =
|
||||||
return errors::Unimplemented(
|
internal::FindNodeAtPath(function_path, bundle_.saved_object_graph());
|
||||||
"Retrieving functions is unimplemented currently");
|
if (object == nullptr) {
|
||||||
|
return errors::NotFound("No saved object found at path ", function_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (object->kind_case() == SavedObject::kBareConcreteFunction) {
|
||||||
|
*function =
|
||||||
|
concrete_functions_
|
||||||
|
.at(object->bare_concrete_function().concrete_function_name())
|
||||||
|
.get();
|
||||||
|
} else if (object->kind_case() == SavedObject::kFunction) {
|
||||||
|
*function =
|
||||||
|
concrete_functions_.at(object->function().concrete_functions(0)).get();
|
||||||
|
} else {
|
||||||
|
return errors::InvalidArgument(function_path,
|
||||||
|
" is not a path to a Function.");
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TFSavedModelAPI::GetSignatureDefFunction(
|
Status TFSavedModelAPI::GetSignatureDefFunction(
|
||||||
const std::string& signature_def_key, ConcreteFunction** function) {
|
const std::string& signature_def_key, ConcreteFunction** function) {
|
||||||
// TODO(bmzhao): Add support for retrieving a signaturedef function.
|
// TODO(bmzhao): Add support for retrieving a signaturedef function.
|
||||||
return errors::Unimplemented(
|
return errors::Unimplemented(
|
||||||
"Retrieving functions is unimplemented currently");
|
"Retrieving SignatureDef functions is unimplemented currently");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
||||||
std::vector<ConcreteFunction*> result;
|
std::vector<ConcreteFunction*> result;
|
||||||
result.reserve(functions_.size());
|
result.reserve(concrete_functions_.size());
|
||||||
for (ConcreteFunction& function : functions_) {
|
for (auto& index_and_function : concrete_functions_) {
|
||||||
result.push_back(&function);
|
result.push_back(index_and_function.second.get());
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TFSavedModelAPI::TFSavedModelAPI(
|
||||||
|
const std::string& directory, SavedModelV2Bundle bundle,
|
||||||
|
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||||
|
revived_objects,
|
||||||
|
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
|
||||||
|
concrete_functions)
|
||||||
|
: directory_(directory),
|
||||||
|
bundle_(std::move(bundle)),
|
||||||
|
revived_objects_(std::move(revived_objects)),
|
||||||
|
concrete_functions_(std::move(concrete_functions)) {}
|
||||||
|
|
||||||
Status TFSavedModelAPI::Load(
|
Status TFSavedModelAPI::Load(
|
||||||
const std::string& directory,
|
const std::string& directory,
|
||||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||||
ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) {
|
ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) {
|
||||||
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
|
// TODO(bmzhao): Add support for loading a TF1 SavedModel.
|
||||||
return errors::Unimplemented(
|
if (tags) {
|
||||||
"TFSavedModelAPIImpl loading is unimplemented currently");
|
return errors::Unimplemented(
|
||||||
|
"Loading saved models with explicit tags will be supported in the "
|
||||||
|
"future");
|
||||||
|
}
|
||||||
|
|
||||||
|
SavedModelV2Bundle bundle;
|
||||||
|
TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle));
|
||||||
|
|
||||||
|
// TODO(bmzhao): Mangle loaded function names so that different
|
||||||
|
// models loaded in the same runtime Context don't clobber eachother.
|
||||||
|
// This occurs in python here:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
|
||||||
|
|
||||||
|
RevivedObjectMap revived_objects;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ReviveObjects(bundle.meta_graph_def(), context, &revived_objects));
|
||||||
|
|
||||||
|
// TODO(bmzhao): When we later add support for loading resources, we need to
|
||||||
|
// handle the case where materializing a function's captures requires invoking
|
||||||
|
// other functions. This occurs when retrieving the resource handle for a
|
||||||
|
// TrackableResource:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
|
||||||
|
// This requires restoring functions in a topological sort order by capture
|
||||||
|
// dependencies.
|
||||||
|
ConcreteFunctionMap function_map;
|
||||||
|
TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects,
|
||||||
|
context, &function_map));
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
RestoreCheckpoint(&bundle, revived_objects, directory, context));
|
||||||
|
|
||||||
|
out->reset(new TFSavedModelAPI(directory, std::move(bundle),
|
||||||
|
std::move(revived_objects),
|
||||||
|
std::move(function_map)));
|
||||||
|
return Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -16,14 +16,19 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
||||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||||
|
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -63,8 +68,19 @@ class TFSavedModelAPI : public SavedModelAPI {
|
|||||||
~TFSavedModelAPI() override = default;
|
~TFSavedModelAPI() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TFSavedModelAPI() = default;
|
TFSavedModelAPI(
|
||||||
std::vector<ConcreteFunction> functions_;
|
const std::string& directory, SavedModelV2Bundle bundle,
|
||||||
|
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||||
|
revived_objects,
|
||||||
|
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
|
||||||
|
concrete_functions);
|
||||||
|
|
||||||
|
std::string directory_;
|
||||||
|
SavedModelV2Bundle bundle_;
|
||||||
|
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||||
|
revived_objects_;
|
||||||
|
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
|
||||||
|
concrete_functions_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -16,10 +16,15 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/platform/stringpiece.h"
|
#include "tensorflow/core/platform/stringpiece.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
@ -92,12 +97,51 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
|||||||
TF_SavedModel* saved_model =
|
TF_SavedModel* saved_model =
|
||||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||||
|
|
||||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
TF_ConcreteFunction* compute_fn =
|
||||||
// like loading a ConcreteFunction. This test at least checks that the
|
TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
|
||||||
// C API builds and can be minimally run.
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
|
|
||||||
|
|
||||||
|
TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp(compute_fn, status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
const TF_TensorHandleList* captures =
|
||||||
|
TF_ConcreteFunctionGetCaptures(compute_fn);
|
||||||
|
|
||||||
|
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
|
||||||
|
// inputs + outputs a function has.
|
||||||
|
std::vector<TFE_TensorHandle*> compute_fn_inputs;
|
||||||
|
TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
|
||||||
|
TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
|
||||||
|
compute_fn_inputs.reserve(2 + TF_TensorHandleListSize(captures));
|
||||||
|
compute_fn_inputs.push_back(input_a);
|
||||||
|
compute_fn_inputs.push_back(input_b);
|
||||||
|
for (int i = 0; i < TF_TensorHandleListSize(captures); ++i) {
|
||||||
|
compute_fn_inputs.push_back(TF_TensorHandleListGet(captures, i));
|
||||||
|
}
|
||||||
|
TFE_OpAddInputList(compute_fn_op, compute_fn_inputs.data(),
|
||||||
|
compute_fn_inputs.size(), status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
TFE_TensorHandle* compute_fn_outputs[1] = {nullptr};
|
||||||
|
int num_retvals = 1;
|
||||||
|
|
||||||
|
TFE_Execute(compute_fn_op, &compute_fn_outputs[0], &num_retvals, status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
EXPECT_EQ(TF_NumDims(result), 0);
|
||||||
|
float output_value = *static_cast<float*>(TF_TensorData(result));
|
||||||
|
// (1 + 2) * (2 + 1) / 3 + 5 should be 8
|
||||||
|
EXPECT_FLOAT_EQ(output_value, 8.0);
|
||||||
|
|
||||||
|
TF_DeleteTensor(result);
|
||||||
|
TFE_DeleteTensorHandle(compute_fn_outputs[0]);
|
||||||
|
TFE_DeleteTensorHandle(input_a);
|
||||||
|
TFE_DeleteTensorHandle(input_b);
|
||||||
|
TFE_DeleteOp(compute_fn_op);
|
||||||
TF_DeleteSavedModel(saved_model);
|
TF_DeleteSavedModel(saved_model);
|
||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
TFE_DeleteContext(ctx);
|
TFE_DeleteContext(ctx);
|
||||||
|
@ -86,11 +86,7 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
|
|||||||
std::unique_ptr<SavedModelAPI> model =
|
std::unique_ptr<SavedModelAPI> model =
|
||||||
SavedModelAPI::Load(model_dir, *runtime, &status);
|
SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||||
|
|
||||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
EXPECT_EQ(status.code(), TF_OK) << status.message();
|
||||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
|
||||||
// like loading a ConcreteFunction. This test at least checks that the
|
|
||||||
// C API builds and can be minimally run.
|
|
||||||
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED) << status.message();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
|
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
|
||||||
|
Loading…
Reference in New Issue
Block a user