diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 5c51e26f925..2ded784882b 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -155,6 +155,7 @@ cc_library( "saved_model_api_type.h", ], deps = [ + "//tensorflow/c:conversion_macros", "//tensorflow/c/experimental/saved_model/core:saved_model_api", ], ) diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index 629610dbe29..9614e507646 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx, if (!status->status.ok()) { return nullptr; } - return new TF_SavedModel{std::move(result)}; + return tensorflow::wrap(result.release()); } TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, @@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, if (!status->status.ok()) { return nullptr; } - return new TF_SavedModel{std::move(result)}; + return tensorflow::wrap(result.release()); } -void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; } +void TF_DeleteSavedModel(TF_SavedModel* model) { + delete tensorflow::unwrap(model); +} TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, const char* function_path, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; tensorflow::Status get_function_status = - model->saved_model->GetFunction(function_path, &result); + tensorflow::unwrap(model)->GetFunction(function_path, &result); status->status.Update(get_function_status); if (!get_function_status.ok()) { return nullptr; @@ -82,7 +84,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( TF_SavedModel* model, const char* signature_def_key, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; tensorflow::Status get_function_status = - model->saved_model->GetSignatureDefFunction(signature_def_key, &result); + tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key, + &result); status->status.Update(get_function_status); if (!get_function_status.ok()) { return nullptr; @@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( } TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) { - return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()}; + return new TF_ConcreteFunctionList{ + tensorflow::unwrap(model)->ListFunctions()}; } } // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h b/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h index 9e2d1117463..380c3703426 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h @@ -18,13 +18,18 @@ limitations under the License. #include <memory> +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" // Internal structures used by the SavedModel C API. These are likely to change // and should not be depended on. -struct TF_SavedModel { - std::unique_ptr<tensorflow::SavedModelAPI> saved_model; -}; +typedef struct TF_SavedModel TF_SavedModel; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel) + +} // namespace tensorflow #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_