From bfeba4d9ec9acc70453cf2c5b6dddb483934f4b3 Mon Sep 17 00:00:00 2001 From: Brian Zhao Date: Tue, 14 Jul 2020 16:48:14 -0700 Subject: [PATCH] Experimental implementation of saved model C API. This supports loading constants, unpartitioned variables, and tf.functions annotated with an input signature from a TF 2 saved model. See RFC 207: https://github.com/tensorflow/community/pull/207 Future CLs will flesh out some of the missing pieces (specifying a tag, loading models with resources/assets, batching tensor restores per device, etc). PiperOrigin-RevId: 321262527 Change-Id: I76dbe4617acd7bdbe5f093e6e22b328842a65780 --- .../c/experimental/saved_model/core/BUILD | 15 + .../saved_model/core/tf_saved_model_api.cc | 333 +++++++++++++++++- .../saved_model/core/tf_saved_model_api.h | 20 +- .../internal/saved_model_api_test.cc | 54 ++- .../tests/saved_model_api_test.cc | 6 +- 5 files changed, 406 insertions(+), 22 deletions(-) diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 5931e229e28..38bdbee1fdc 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -113,8 +113,23 @@ cc_library( deps = [ ":concrete_function", ":saved_model_api", + ":saved_model_utils", + "//tensorflow/c:tensor_interface", "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core/ops:restore_ops", + "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", + "//tensorflow/c/experimental/saved_model/core/revived_types:variable", + "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index 225ba1db9f4..c22f8d86174 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -15,47 +15,360 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h" +#include #include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" +#include "tensorflow/cc/saved_model/bundle_v2.h" +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" namespace tensorflow { +// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary +using FunctionDefMap = + std::unordered_map; + +// Maps from a Nodedef's name to its corresponding AttrValues, for a given +// Graphdef +using NodeAttrMap = + std::unordered_map; + +// Maps from Node ID to an "Revived Object" implementing +// "TensorHandleConvertible" +using RevivedObjectMap = + std::unordered_map>; + +// Maps from a functiondef's name to the corresponding "TFConcreteFunction" +using ConcreteFunctionMap = + std::unordered_map>; + +namespace { + +Status ConstantFromSavedConstant( + ImmediateExecutionContext* ctx, + const tensorflow::SavedConstant& saved_constant, + const NodeAttrMap& node_attr_map, std::unique_ptr* output) { + const std::string& const_op_name = saved_constant.operation(); + const auto& node_name_and_attrs = node_attr_map.find(const_op_name); + if (node_name_and_attrs == node_attr_map.end()) { + return errors::FailedPrecondition( + "Unable to find Const operation with name'", const_op_name, + "' in SavedModel graphdef"); + } + const AttrValueMap* attrs = node_name_and_attrs->second; + const auto& attr_name_and_value = attrs->find("value"); + if (attr_name_and_value == attrs->end()) { + return errors::FailedPrecondition("Unable to find Const operation '", + const_op_name, "'s value attribute"); + } + const TensorProto& tensor_proto = attr_name_and_value->second.tensor(); + return internal::TensorProtoToConstant(ctx, tensor_proto, output); +} + +// Restores all non-function objects in the SavedModel's object graph. +// This function walks through the metagraph's saved object graph, and +// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and +// SavedResources. These are returned via the `out` parameter. +Status ReviveObjects( + const MetaGraphDef& metagraph, ImmediateExecutionContext* context, + std::unordered_map>* + revived_objects) { + // This is needed to restore "Constant" nodes by looking up their + // "Value" attribute. + NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def()); + + // Iterate through all the saved objects, restoring objects as we go. + // We don't recreate functions until all other objects have been created. + for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) { + const SavedObject& node = metagraph.object_graph_def().nodes(i); + if (node.kind_case() == SavedObject::kVariable) { + std::unique_ptr variable; + TF_RETURN_IF_ERROR( + internal::LoadSavedVariable(context, node.variable(), &variable)); + (*revived_objects)[i] = std::move(variable); + } else if (node.kind_case() == SavedObject::kConstant) { + std::unique_ptr constant; + TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(), + node_attr_map, &constant)); + (*revived_objects)[i] = std::move(constant); + } else if (node.kind_case() == SavedObject::kAsset) { + // TODO(bmzhao): Implement Asset C++ class. This should be just recreating + // the full path to the asset file: + // https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396 + // and storing it as a string tensor: + // https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325 + return errors::Unimplemented("SavedAsset loading is not implemented yet"); + } else if (node.kind_case() == SavedObject::kResource) { + // TODO(bmzhao): Figure out how resource loading works and implement it + return errors::Unimplemented( + "SavedResource loading is not implemented yet"); + } + } + return Status(); +} + +Status ReviveFunctions(const MetaGraphDef& metagraph, + const RevivedObjectMap& revived_objects, + ImmediateExecutionContext* context, + ConcreteFunctionMap* restored_functions) { + const FunctionDefMap function_def_map = + internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library()); + + // Iterate through all objects, only examining functions. + for (const SavedObject& node : metagraph.object_graph_def().nodes()) { + if (node.kind_case() == SavedObject::kBareConcreteFunction) { + const std::string& function_name = + node.bare_concrete_function().concrete_function_name(); + + const SavedConcreteFunction& saved_concrete_function = + metagraph.object_graph_def().concrete_functions().at(function_name); + + const FunctionDef* function_def = function_def_map.at(function_name); + std::unique_ptr concrete_function; + TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction( + saved_concrete_function, function_def, revived_objects, context, + &concrete_function)); + (*restored_functions)[function_name] = std::move(concrete_function); + } else if (node.kind_case() == SavedObject::kFunction) { + // We only allow loading functions that have an annotated input signature, + // which means there is 1:1 correspondence between tf.function + // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is + // the same restriction that MLIR has: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707 + const SavedFunction& saved_function = node.function(); + if (saved_function.concrete_functions_size() != 1) { + return errors::FailedPrecondition( + "Only tf.functions annotated with an input signature are supported " + "by SavedModelAPI. This means that there should only be a single " + "ConcreteFunction per tf.function"); + } + const std::string& function_name = saved_function.concrete_functions(0); + const SavedConcreteFunction& saved_concrete_function = + metagraph.object_graph_def().concrete_functions().at(function_name); + + const FunctionDef* function_def = function_def_map.at(function_name); + + std::unique_ptr concrete_function; + TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction( + saved_concrete_function, function_def, revived_objects, context, + &concrete_function)); + (*restored_functions)[function_name] = std::move(concrete_function); + } + } + return Status(); +} + +const TrackableObjectGraph::TrackableObject::SerializedTensor* +FindSerializedTensorInTrackable( + const TrackableObjectGraph::TrackableObject& trackable_object, + absl::string_view name) { + for (const auto& maybe_serialized_tensor : trackable_object.attributes()) { + if (maybe_serialized_tensor.name() == name) { + return &maybe_serialized_tensor; + } + } + return nullptr; +} + +// This function reads the Checkpoint embedded in the SavedModel, and calls the +// appropriate Restore ops on each of the variables. +// Note(bmzhao): Conceptually, objects that contain checkpointable state +// implement the "_gather_saveables_for_checkpoint" method +// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983 +// which returns a dict of string key -> EITHER: +// 1. python callable (taking a checkpoint key) returning SaveableObject OR +// 2. variable (partitioned/resource/reference or otherwise) +// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58. +// The string key becomes the "name" attribute of the SerializedTensor proto +// in the TrackableObjectGraph, +// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26 +// And the checkpoint_key is a globally unique string derived from this name: +// https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241 +// SaveableObjects model the information needed to pass to the SaveV2/RestoreV2 +// ops via their SaveSpec members +// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21, +// which contain the "real" checkpoint keys into the TensorBundle SSTable. +// They also contain the logic needed to take the restored tensors from +// RestoreV2 and load them back into the "object" they came from via their +// overridden "restore" method: +// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85 +Status RestoreCheckpoint(SavedModelV2Bundle* bundle, + const RevivedObjectMap& revived_objects, + const std::string& directory, + ImmediateExecutionContext* context) { + // TODO(bmzhao): Batch up all the restores into a single restore op per + // device, following logic in MultiDeviceSaver. + TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore( + [&revived_objects, &directory, context, bundle]( + int node, const TrackableObjectGraph::TrackableObject& trackable) { + if (bundle->saved_object_graph().nodes(node).kind_case() != + SavedObject::kVariable) { + // TODO(bmzhao): This requires using the newly added Save/Restore + // functions from + // https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c + return errors::Unimplemented( + "Restoring non-variable objects has not been implemented yet. "); + } + + Variable* variable = + down_cast(revived_objects.at(node).get()); + + // Restore the tensor's value from the checkpoint + const TrackableObjectGraph::TrackableObject::SerializedTensor* + attribute = + FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE"); + if (attribute == nullptr) { + return errors::FailedPrecondition( + "Could not find SerializedTensor with name VARIABLE_VALUE for " + "saved variable"); + } + + const std::string& checkpoint_key = attribute->checkpoint_key(); + std::string variables_path_prefix = + io::JoinPath(directory, kSavedModelVariablesDirectory, + kSavedModelVariablesFilename); + ImmediateTensorHandlePtr restored_output; + TF_RETURN_IF_ERROR(internal::SingleRestore( + context, variables_path_prefix, checkpoint_key, variable->dtype(), + &restored_output)); + + // Assign the restored tensor's value to the variable + return variable->Assign(restored_output.get()); + })); + + return Status(); +} + +} // namespace + Status TFSavedModelAPI::GetFunction(const std::string& function_path, ConcreteFunction** function) { - // TODO(bmzhao): Add support for retrieving a function. - return errors::Unimplemented( - "Retrieving functions is unimplemented currently"); + const SavedObject* object = + internal::FindNodeAtPath(function_path, bundle_.saved_object_graph()); + if (object == nullptr) { + return errors::NotFound("No saved object found at path ", function_path); + } + + if (object->kind_case() == SavedObject::kBareConcreteFunction) { + *function = + concrete_functions_ + .at(object->bare_concrete_function().concrete_function_name()) + .get(); + } else if (object->kind_case() == SavedObject::kFunction) { + *function = + concrete_functions_.at(object->function().concrete_functions(0)).get(); + } else { + return errors::InvalidArgument(function_path, + " is not a path to a Function."); + } + + return Status(); } Status TFSavedModelAPI::GetSignatureDefFunction( const std::string& signature_def_key, ConcreteFunction** function) { // TODO(bmzhao): Add support for retrieving a signaturedef function. return errors::Unimplemented( - "Retrieving functions is unimplemented currently"); + "Retrieving SignatureDef functions is unimplemented currently"); } std::vector TFSavedModelAPI::ListFunctions() { std::vector result; - result.reserve(functions_.size()); - for (ConcreteFunction& function : functions_) { - result.push_back(&function); + result.reserve(concrete_functions_.size()); + for (auto& index_and_function : concrete_functions_) { + result.push_back(index_and_function.second.get()); } return result; } +TFSavedModelAPI::TFSavedModelAPI( + const std::string& directory, SavedModelV2Bundle bundle, + std::unordered_map> + revived_objects, + std::unordered_map> + concrete_functions) + : directory_(directory), + bundle_(std::move(bundle)), + revived_objects_(std::move(revived_objects)), + concrete_functions_(std::move(concrete_functions)) {} + Status TFSavedModelAPI::Load( const std::string& directory, const absl::optional>& tags, ImmediateExecutionContext* context, std::unique_ptr* out) { - // TODO(bmzhao): Add support for loading a TFSavedModelImpl. - return errors::Unimplemented( - "TFSavedModelAPIImpl loading is unimplemented currently"); + // TODO(bmzhao): Add support for loading a TF1 SavedModel. + if (tags) { + return errors::Unimplemented( + "Loading saved models with explicit tags will be supported in the " + "future"); + } + + SavedModelV2Bundle bundle; + TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle)); + + // TODO(bmzhao): Mangle loaded function names so that different + // models loaded in the same runtime Context don't clobber eachother. + // This occurs in python here: + // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454 + + RevivedObjectMap revived_objects; + TF_RETURN_IF_ERROR( + ReviveObjects(bundle.meta_graph_def(), context, &revived_objects)); + + // TODO(bmzhao): When we later add support for loading resources, we need to + // handle the case where materializing a function's captures requires invoking + // other functions. This occurs when retrieving the resource handle for a + // TrackableResource: + // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240 + // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233 + // This requires restoring functions in a topological sort order by capture + // dependencies. + ConcreteFunctionMap function_map; + TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects, + context, &function_map)); + + TF_RETURN_IF_ERROR( + RestoreCheckpoint(&bundle, revived_objects, directory, context)); + + out->reset(new TFSavedModelAPI(directory, std::move(bundle), + std::move(revived_objects), + std::move(function_map))); + return Status(); } } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h index cc631a9f3ae..fc8e738e86f 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -16,14 +16,19 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_ +#include #include +#include #include #include #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" +#include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { @@ -63,8 +68,19 @@ class TFSavedModelAPI : public SavedModelAPI { ~TFSavedModelAPI() override = default; private: - TFSavedModelAPI() = default; - std::vector functions_; + TFSavedModelAPI( + const std::string& directory, SavedModelV2Bundle bundle, + std::unordered_map> + revived_objects, + std::unordered_map> + concrete_functions); + + std::string directory_; + SavedModelV2Bundle bundle_; + std::unordered_map> + revived_objects_; + std::unordered_map> + concrete_functions_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index aa0b00ab847..3d490fe7e08 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -16,10 +16,15 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" #include +#include #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" @@ -92,12 +97,51 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) { TF_SavedModel* saved_model = TF_LoadSavedModel(model_dir.c_str(), ctx, status); - // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. - // That unblocks writing other tests that require a TF_SavedModel*, - // like loading a ConcreteFunction. This test at least checks that the - // C API builds and can be minimally run. - EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TF_ConcreteFunction* compute_fn = + TF_GetSavedModelConcreteFunction(saved_model, "compute", status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp(compute_fn, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + const TF_TensorHandleList* captures = + TF_ConcreteFunctionGetCaptures(compute_fn); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + std::vector compute_fn_inputs; + TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f); + TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f); + compute_fn_inputs.reserve(2 + TF_TensorHandleListSize(captures)); + compute_fn_inputs.push_back(input_a); + compute_fn_inputs.push_back(input_b); + for (int i = 0; i < TF_TensorHandleListSize(captures); ++i) { + compute_fn_inputs.push_back(TF_TensorHandleListGet(captures, i)); + } + TFE_OpAddInputList(compute_fn_op, compute_fn_inputs.data(), + compute_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_TensorHandle* compute_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(compute_fn_op, &compute_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + float output_value = *static_cast(TF_TensorData(result)); + // (1 + 2) * (2 + 1) / 3 + 5 should be 8 + EXPECT_FLOAT_EQ(output_value, 8.0); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(compute_fn_outputs[0]); + TFE_DeleteTensorHandle(input_a); + TFE_DeleteTensorHandle(input_b); + TFE_DeleteOp(compute_fn_op); TF_DeleteSavedModel(saved_model); TF_DeleteStatus(status); TFE_DeleteContext(ctx); diff --git a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc index ad80b74f1d5..cf5f742538e 100644 --- a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc +++ b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc @@ -86,11 +86,7 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModel) { std::unique_ptr model = SavedModelAPI::Load(model_dir, *runtime, &status); - // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. - // That unblocks writing other tests that require a TF_SavedModel*, - // like loading a ConcreteFunction. This test at least checks that the - // C API builds and can be minimally run. - EXPECT_EQ(status.code(), TF_UNIMPLEMENTED) << status.message(); + EXPECT_EQ(status.code(), TF_OK) << status.message(); } INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,