Experimental implementation of the SavedModel C API. This supports loading constants, unpartitioned variables, and tf.functions annotated with an input signature from a TF2 SavedModel. See RFC 207: https://github.com/tensorflow/community/pull/207

Future CLs will flesh out some of the missing pieces (specifying a tag, loading models with resources/assets, batching tensor restores per device, etc.).

PiperOrigin-RevId: 321262527
Change-Id: I76dbe4617acd7bdbe5f093e6e22b328842a65780
Brian Zhao 2020-07-14 16:48:14 -07:00 committed by TensorFlower Gardener
parent 0aa0c33c58
commit bfeba4d9ec
5 changed files with 406 additions and 22 deletions
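
For orientation, here is a minimal sketch of how the new C API is exercised, following the C API test added in this change. The model path and the "compute" function name are placeholders taken from that test, and error checking is elided; this is a usage sketch, not part of the commit.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/tf_status.h"

// Sketch only: load a TF2 SavedModel and look up a tf.function by name.
void LoadAndCallSketch(void) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // Load the model from disk into the eager context (no explicit tags yet).
  TF_SavedModel* model =
      TF_LoadSavedModel("/path/to/saved_model", ctx, status);

  // Retrieve a ConcreteFunction by its path in the saved object graph, then
  // obtain a call op; inputs (plus the function's captures) are added with
  // TFE_OpAddInputList and the op is run with TFE_Execute, as in the test.
  TF_ConcreteFunction* fn =
      TF_GetSavedModelConcreteFunction(model, "compute", status);
  TFE_Op* call_op = TF_ConcreteFunctionGetCallOp(fn, status);
  (void)call_op;

  TF_DeleteSavedModel(model);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}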

@@ -113,8 +113,23 @@ cc_library(
    deps = [
        ":concrete_function",
        ":saved_model_api",
        ":saved_model_utils",
        "//tensorflow/c:tensor_interface",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
        "//tensorflow/c/experimental/saved_model/core/revived_types:variable",
        "//tensorflow/cc/saved_model:bundle_v2",
        "//tensorflow/cc/saved_model:constants",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime/eager:tensor_handle",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)

@@ -15,47 +15,360 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h" #include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
#include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
namespace tensorflow {
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
using FunctionDefMap =
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
StringPieceHasher>;
// Maps from a Nodedef's name to its corresponding AttrValues, for a given
// Graphdef
using NodeAttrMap =
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>;
// Maps from Node ID to a "Revived Object" implementing
// "TensorHandleConvertible"
using RevivedObjectMap =
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>;
// Maps from a functiondef's name to the corresponding "TFConcreteFunction"
using ConcreteFunctionMap =
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>;
namespace {
Status ConstantFromSavedConstant(
ImmediateExecutionContext* ctx,
const tensorflow::SavedConstant& saved_constant,
const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
const std::string& const_op_name = saved_constant.operation();
const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
if (node_name_and_attrs == node_attr_map.end()) {
return errors::FailedPrecondition(
"Unable to find Const operation with name'", const_op_name,
"' in SavedModel graphdef");
}
const AttrValueMap* attrs = node_name_and_attrs->second;
const auto& attr_name_and_value = attrs->find("value");
if (attr_name_and_value == attrs->end()) {
return errors::FailedPrecondition("Unable to find Const operation '",
const_op_name, "'s value attribute");
}
const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
}
// Restores all non-function objects in the SavedModel's object graph.
// This function walks through the metagraph's saved object graph, and
// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and
// SavedResources. These are returned via the `out` parameter.
Status ReviveObjects(
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
revived_objects) {
// This is needed to restore "Constant" nodes by looking up their
// "Value" attribute.
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
// Iterate through all the saved objects, restoring objects as we go.
// We don't recreate functions until all other objects have been created.
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
const SavedObject& node = metagraph.object_graph_def().nodes(i);
if (node.kind_case() == SavedObject::kVariable) {
std::unique_ptr<Variable> variable;
TF_RETURN_IF_ERROR(
internal::LoadSavedVariable(context, node.variable(), &variable));
(*revived_objects)[i] = std::move(variable);
} else if (node.kind_case() == SavedObject::kConstant) {
std::unique_ptr<Constant> constant;
TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
node_attr_map, &constant));
(*revived_objects)[i] = std::move(constant);
} else if (node.kind_case() == SavedObject::kAsset) {
// TODO(bmzhao): Implement Asset C++ class. This should be just recreating
// the full path to the asset file:
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396
// and storing it as a string tensor:
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325
return errors::Unimplemented("SavedAsset loading is not implemented yet");
} else if (node.kind_case() == SavedObject::kResource) {
// TODO(bmzhao): Figure out how resource loading works and implement it
return errors::Unimplemented(
"SavedResource loading is not implemented yet");
}
}
return Status();
}
Status ReviveFunctions(const MetaGraphDef& metagraph,
const RevivedObjectMap& revived_objects,
ImmediateExecutionContext* context,
ConcreteFunctionMap* restored_functions) {
const FunctionDefMap function_def_map =
internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
// Iterate through all objects, only examining functions.
for (const SavedObject& node : metagraph.object_graph_def().nodes()) {
if (node.kind_case() == SavedObject::kBareConcreteFunction) {
const std::string& function_name =
node.bare_concrete_function().concrete_function_name();
const SavedConcreteFunction& saved_concrete_function =
metagraph.object_graph_def().concrete_functions().at(function_name);
const FunctionDef* function_def = function_def_map.at(function_name);
std::unique_ptr<TFConcreteFunction> concrete_function;
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
saved_concrete_function, function_def, revived_objects, context,
&concrete_function));
(*restored_functions)[function_name] = std::move(concrete_function);
} else if (node.kind_case() == SavedObject::kFunction) {
// We only allow loading functions that have an annotated input signature,
// which means there is 1:1 correspondence between tf.function
// <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
// the same restriction that MLIR has:
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
const SavedFunction& saved_function = node.function();
if (saved_function.concrete_functions_size() != 1) {
return errors::FailedPrecondition(
"Only tf.functions annotated with an input signature are supported "
"by SavedModelAPI. This means that there should only be a single "
"ConcreteFunction per tf.function");
}
const std::string& function_name = saved_function.concrete_functions(0);
const SavedConcreteFunction& saved_concrete_function =
metagraph.object_graph_def().concrete_functions().at(function_name);
const FunctionDef* function_def = function_def_map.at(function_name);
std::unique_ptr<TFConcreteFunction> concrete_function;
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
saved_concrete_function, function_def, revived_objects, context,
&concrete_function));
(*restored_functions)[function_name] = std::move(concrete_function);
}
}
return Status();
}
const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(
const TrackableObjectGraph::TrackableObject& trackable_object,
absl::string_view name) {
for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
if (maybe_serialized_tensor.name() == name) {
return &maybe_serialized_tensor;
}
}
return nullptr;
}
// This function reads the Checkpoint embedded in the SavedModel, and calls the
// appropriate Restore ops on each of the variables.
// Note(bmzhao): Conceptually, objects that contain checkpointable state
// implement the "_gather_saveables_for_checkpoint" method
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983
// which returns a dict of string key -> EITHER:
// 1. python callable (taking a checkpoint key) returning SaveableObject OR
// 2. variable (partitioned/resource/reference or otherwise)
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58.
// The string key becomes the "name" attribute of the SerializedTensor proto
// in the TrackableObjectGraph,
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26
// And the checkpoint_key is a globally unique string derived from this name:
// https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241
// SaveableObjects model the information needed to pass to the SaveV2/RestoreV2
// ops via their SaveSpec members
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21,
// which contain the "real" checkpoint keys into the TensorBundle SSTable.
// They also contain the logic needed to take the restored tensors from
// RestoreV2 and load them back into the "object" they came from via their
// overridden "restore" method:
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
const RevivedObjectMap& revived_objects,
const std::string& directory,
ImmediateExecutionContext* context) {
// TODO(bmzhao): Batch up all the restores into a single restore op per
// device, following logic in MultiDeviceSaver.
TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore(
[&revived_objects, &directory, context, bundle](
int node, const TrackableObjectGraph::TrackableObject& trackable) {
if (bundle->saved_object_graph().nodes(node).kind_case() !=
SavedObject::kVariable) {
// TODO(bmzhao): This requires using the newly added Save/Restore
// functions from
// https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
return errors::Unimplemented(
"Restoring non-variable objects has not been implemented yet. ");
}
Variable* variable =
down_cast<Variable*>(revived_objects.at(node).get());
// Restore the tensor's value from the checkpoint
const TrackableObjectGraph::TrackableObject::SerializedTensor*
attribute =
FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE");
if (attribute == nullptr) {
return errors::FailedPrecondition(
"Could not find SerializedTensor with name VARIABLE_VALUE for "
"saved variable");
}
const std::string& checkpoint_key = attribute->checkpoint_key();
std::string variables_path_prefix =
io::JoinPath(directory, kSavedModelVariablesDirectory,
kSavedModelVariablesFilename);
ImmediateTensorHandlePtr restored_output;
TF_RETURN_IF_ERROR(internal::SingleRestore(
context, variables_path_prefix, checkpoint_key, variable->dtype(),
&restored_output));
// Assign the restored tensor's value to the variable
return variable->Assign(restored_output.get());
}));
return Status();
}
} // namespace
Status TFSavedModelAPI::GetFunction(const std::string& function_path,
ConcreteFunction** function) {
- // TODO(bmzhao): Add support for retrieving a function.
- return errors::Unimplemented(
- "Retrieving functions is unimplemented currently");
const SavedObject* object =
internal::FindNodeAtPath(function_path, bundle_.saved_object_graph());
if (object == nullptr) {
return errors::NotFound("No saved object found at path ", function_path);
}
if (object->kind_case() == SavedObject::kBareConcreteFunction) {
*function =
concrete_functions_
.at(object->bare_concrete_function().concrete_function_name())
.get();
} else if (object->kind_case() == SavedObject::kFunction) {
*function =
concrete_functions_.at(object->function().concrete_functions(0)).get();
} else {
return errors::InvalidArgument(function_path,
" is not a path to a Function.");
}
return Status();
}
Status TFSavedModelAPI::GetSignatureDefFunction(
const std::string& signature_def_key, ConcreteFunction** function) {
// TODO(bmzhao): Add support for retrieving a signaturedef function.
return errors::Unimplemented(
- "Retrieving functions is unimplemented currently");
"Retrieving SignatureDef functions is unimplemented currently");
}
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
std::vector<ConcreteFunction*> result;
- result.reserve(functions_.size());
- for (ConcreteFunction& function : functions_) {
- result.push_back(&function);
result.reserve(concrete_functions_.size());
for (auto& index_and_function : concrete_functions_) {
result.push_back(index_and_function.second.get());
}
return result;
}
TFSavedModelAPI::TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects,
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions)
: directory_(directory),
bundle_(std::move(bundle)),
revived_objects_(std::move(revived_objects)),
concrete_functions_(std::move(concrete_functions)) {}
Status TFSavedModelAPI::Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) {
- // TODO(bmzhao): Add support for loading a TFSavedModelImpl.
- return errors::Unimplemented(
- "TFSavedModelAPIImpl loading is unimplemented currently");
// TODO(bmzhao): Add support for loading a TF1 SavedModel.
if (tags) {
return errors::Unimplemented(
"Loading saved models with explicit tags will be supported in the "
"future");
}
SavedModelV2Bundle bundle;
TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle));
// TODO(bmzhao): Mangle loaded function names so that different
// models loaded in the same runtime Context don't clobber each other.
// This occurs in python here:
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
RevivedObjectMap revived_objects;
TF_RETURN_IF_ERROR(
ReviveObjects(bundle.meta_graph_def(), context, &revived_objects));
// TODO(bmzhao): When we later add support for loading resources, we need to
// handle the case where materializing a function's captures requires invoking
// other functions. This occurs when retrieving the resource handle for a
// TrackableResource:
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
// This requires restoring functions in a topological sort order by capture
// dependencies.
ConcreteFunctionMap function_map;
TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects,
context, &function_map));
TF_RETURN_IF_ERROR(
RestoreCheckpoint(&bundle, revived_objects, directory, context));
out->reset(new TFSavedModelAPI(directory, std::move(bundle),
std::move(revived_objects),
std::move(function_map)));
return Status();
}
} // namespace tensorflow

@@ -16,14 +16,19 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
@@ -63,8 +68,19 @@ class TFSavedModelAPI : public SavedModelAPI {
~TFSavedModelAPI() override = default;
private:
- TFSavedModelAPI() = default;
- std::vector<ConcreteFunction> functions_;
TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects,
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions);
std::string directory_;
SavedModelV2Bundle bundle_;
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects_;
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions_;
};
} // namespace tensorflow

@@ -16,10 +16,15 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include <string>
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
@@ -92,12 +97,51 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
- // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
- // That unblocks writing other tests that require a TF_SavedModel*,
- // like loading a ConcreteFunction. This test at least checks that the
- // C API builds and can be minimally run.
- EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_ConcreteFunction* compute_fn =
TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp(compute_fn, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const TF_TensorHandleList* captures =
TF_ConcreteFunctionGetCaptures(compute_fn);
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
// inputs + outputs a function has.
std::vector<TFE_TensorHandle*> compute_fn_inputs;
TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
compute_fn_inputs.reserve(2 + TF_TensorHandleListSize(captures));
compute_fn_inputs.push_back(input_a);
compute_fn_inputs.push_back(input_b);
for (int i = 0; i < TF_TensorHandleListSize(captures); ++i) {
compute_fn_inputs.push_back(TF_TensorHandleListGet(captures, i));
}
TFE_OpAddInputList(compute_fn_op, compute_fn_inputs.data(),
compute_fn_inputs.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* compute_fn_outputs[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(compute_fn_op, &compute_fn_outputs[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_NumDims(result), 0);
float output_value = *static_cast<float*>(TF_TensorData(result));
// (1 + 2) * (2 + 1) / 3 + 5 should be 8
EXPECT_FLOAT_EQ(output_value, 8.0);
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(compute_fn_outputs[0]);
TFE_DeleteTensorHandle(input_a);
TFE_DeleteTensorHandle(input_b);
TFE_DeleteOp(compute_fn_op);
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);

@@ -86,11 +86,7 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
std::unique_ptr<SavedModelAPI> model =
SavedModelAPI::Load(model_dir, *runtime, &status);
- // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
- // That unblocks writing other tests that require a TF_SavedModel*,
- // like loading a ConcreteFunction. This test at least checks that the
- // C API builds and can be minimally run.
- EXPECT_EQ(status.code(), TF_UNIMPLEMENTED) << status.message();
EXPECT_EQ(status.code(), TF_OK) << status.message();
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,