Initial checkin of SavedModelAPI C++ header-only API. Tensor, TensorHandle and ConreteFunction::Run will be added in a subsequent change. See RFC https://github.com/tensorflow/community/pull/207.

PiperOrigin-RevId: 309259741
Change-Id: Ic88c3d2b2de83d59774b305b8e0126d792624d78
This commit is contained in:
Brian Zhao 2020-04-30 11:01:52 -07:00 committed by TensorFlower Gardener
parent bac40c0346
commit 3a61ea880c
13 changed files with 803 additions and 4 deletions

View File

@ -66,7 +66,7 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
char* function_path,
const char* function_path,
TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
@ -79,7 +79,7 @@ TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
}
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, char* signature_def_key, TF_Status* status) {
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);

View File

@ -80,7 +80,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);
// "conceptually" bound to `model`. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
TF_SavedModel* model, char* function_path, TF_Status* status);
TF_SavedModel* model, const char* function_path, TF_Status* status);
// Retrieve a function from the TF SavedModel via a SignatureDef key.
//
@ -94,7 +94,7 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
// TF_ConcreteFunction instance. Once `model` is deleted, all
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, char* signature_def_key, TF_Status* status);
TF_SavedModel* model, const char* signature_def_key, TF_Status* status);
// Returns a list of all ConcreteFunctions stored in this SavedModel.
// The lifetime of the returned list is bound to `model`.

View File

@ -0,0 +1,52 @@
# Experimental C++ APIs for TensorFlow.
# New TF C++ APIs under the tensorflow::cc namespace aim to guarantee ABI stability.
# Users are expected to compile against public c++ headers, and link against
# libtensorflow (https://www.tensorflow.org/install/lang_c).
# We aim to achieve ABI stability in new C++ APIs by only using types
# on the API surface that:
# 1. Have a header-only implementation
# 2. Are std:: types
# 3. Wrap an opaque C type
package(
# This is intentionally public
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "runtime",
hdrs = [
"runtime.h",
],
deps = [
":status",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
],
)
cc_library(
name = "runtime_builder",
hdrs = [
"runtime_builder.h",
],
deps = [
":runtime",
":status",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
],
)
cc_library(
name = "status",
hdrs = [
"status.h",
],
deps = [
"//tensorflow/c:tf_status",
],
)

View File

@ -0,0 +1,68 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
#include <memory>
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace cc {
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
// resources, threadpools, etc. Clients are expected to construct a Runtime
// object through tensorflow::cc::RuntimeBuilder::Build, after setting any
// relevant configuration options. Many Tensorflow functions take a reference to
// the runtime as an argument (eg: tensorflow::cc::SavedModelAPI::Load), and
// may have different implementations depending on the runtime. For many of
// these Runtime-attached objects (such as tensorflow::cc::TensorHandle), the
// Runtime must outlive these objects.
class Runtime {
public:
// Runtime is movable, but not copyable.
Runtime(Runtime&&) = default;
Runtime& operator=(Runtime&&) = default;
private:
friend class RuntimeBuilder;
friend class SavedModelAPI;
// Wraps a TFE_Context. Takes ownership of ctx.
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
// Deletes the currently wrapped TFE_Context, swaps it with ctx,
// and takes ownership of ctx.
void Reset(TFE_Context* ctx) { ctx_.reset(ctx); }
// Returns the TFE_Context that this object wraps. This object
// retains ownership of the pointer.
TFE_Context* GetTFEContext() const { return ctx_.get(); }
// Runtime is not copyable
Runtime(const Runtime&) = delete;
Runtime& operator=(const Runtime&) = delete;
struct TFEContextDeleter {
void operator()(TFE_Context* p) const { TFE_DeleteContext(p); }
};
std::unique_ptr<TFE_Context, TFEContextDeleter> ctx_;
};
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_

View File

@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
#include <memory>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow {
namespace cc {
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
// Use this to set configuration options, like threadpool size, etc.
class RuntimeBuilder {
public:
RuntimeBuilder() : options_(TFE_NewContextOptions()) {}
// If `use_tfrt` is true, we will use the new Tensorflow Runtime
// (https://blog.tensorflow.org/2020/04/tfrt-new-tensorflow-runtime.html) as
// our runtime implementation.
RuntimeBuilder& SetUseTFRT(bool use_tfrt);
// Build a Tensorflow Runtime.
//
// Params:
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// unique_ptr<tensorflow::cc::Runtime>.
std::unique_ptr<Runtime> Build(Status* status);
// RuntimeBuilder is movable, but not copyable.
RuntimeBuilder(RuntimeBuilder&&) = default;
RuntimeBuilder& operator=(RuntimeBuilder&&) = default;
private:
// RuntimeBuilder is not copyable
RuntimeBuilder(const RuntimeBuilder&) = delete;
RuntimeBuilder& operator=(const RuntimeBuilder&) = delete;
struct TFEContextOptionsDeleter {
void operator()(TFE_ContextOptions* p) const {
TFE_DeleteContextOptions(p);
}
};
std::unique_ptr<TFE_ContextOptions, TFEContextOptionsDeleter> options_;
};
inline RuntimeBuilder& RuntimeBuilder::SetUseTFRT(bool use_tfrt) {
TFE_ContextOptionsSetTfrt(options_.get(), use_tfrt);
return *this;
}
inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
TFE_Context* result = TFE_NewContext(options_.get(), status->GetTFStatus());
if (!status->ok()) {
return nullptr;
}
// We can't use std::make_unique here because of its interaction with a
// private constructor: https://abseil.io/tips/134
return std::unique_ptr<Runtime>(new Runtime(result));
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_

View File

@ -0,0 +1,93 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
#include <memory>
#include <string>
#include "tensorflow/c/tf_status.h"
namespace tensorflow {
namespace cc {
// Status is a wrapper around an error code and an optional error message.
// The set of error codes are defined here:
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/c/tf_status.h#L39-L60
// Many Tensorflow APIs return a Status, or take a Status as an out parameter.
// Clients should check for status.ok() after calling these APIs, and either
// handle or propagate the error appropriately.
// TODO(bmzhao): Add a detailed code example before moving out of experimental.
class Status {
public:
// Create a success status
Status() : status_(TF_NewStatus()) {}
// Return the status code
TF_Code code() const;
// Returns the error message in Status.
std::string message() const;
// Returns the error message in Status.
bool ok() const;
// Record <code, msg> in Status. Any previous information is lost.
// A common use is to clear a status: SetStatus(TF_OK, "");
void SetStatus(TF_Code code, const std::string& msg);
// Status is movable, but not copyable.
Status(Status&&) = default;
Status& operator=(Status&&) = default;
private:
friend class RuntimeBuilder;
friend class Runtime;
friend class SavedModelAPI;
// Wraps a TF_Status*, and takes ownership of it.
explicit Status(TF_Status* status) : status_(status) {}
// Status is not copyable
Status(const Status&) = delete;
Status& operator=(const Status&) = delete;
// Returns the TF_Status that this object wraps. This object
// retains ownership of the pointer.
TF_Status* GetTFStatus() const { return status_.get(); }
struct TFStatusDeleter {
void operator()(TF_Status* p) const { TF_DeleteStatus(p); }
};
std::unique_ptr<TF_Status, TFStatusDeleter> status_;
};
inline TF_Code Status::code() const { return TF_GetCode(status_.get()); }
inline std::string Status::message() const {
return std::string(TF_Message(status_.get()));
}
inline bool Status::ok() const { return code() == TF_OK; }
inline void Status::SetStatus(TF_Code code, const std::string& msg) {
TF_SetStatus(status_.get(), code, msg.c_str());
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_

View File

@ -0,0 +1,58 @@
# Experimental C++ SavedModel Header Only APIs. See RFC
# https://github.com/tensorflow/community/pull/207
package(
# This is intentionally public
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "concrete_function",
hdrs = [
"concrete_function.h",
],
deps = [
":function_metadata",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/cc/experimental/base/public:status",
],
)
cc_library(
name = "concrete_function_list",
hdrs = [
"concrete_function_list.h",
],
deps = [
":concrete_function",
"//tensorflow/c/experimental/saved_model/public:concrete_function_list",
],
)
cc_library(
name = "function_metadata",
hdrs = [
"function_metadata.h",
],
deps = [
"//tensorflow/c/experimental/saved_model/public:function_metadata",
],
)
cc_library(
name = "saved_model_api",
hdrs = [
"saved_model_api.h",
],
deps = [
":concrete_function",
":concrete_function_list",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:status",
],
)

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
namespace tensorflow {
namespace cc {
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
class ConcreteFunction final {
public:
// TODO(bmzhao): Adding ConcreteFunction::Run in subsequent CL, since
// it depends on tensorflow::cc::Tensor and tensorflow::cc::TensorHandle
// Returns FunctionMetadata associated with this ConcreteFunction.
const FunctionMetadata* GetFunctionMetadata();
private:
friend class SavedModelAPI;
friend class ConcreteFunctionList;
// TODO(bmzhao): Consider adding a macro for wrapping/unwrapping
// when moving out of experimental.
static ConcreteFunction* wrap(TF_ConcreteFunction* p) {
return reinterpret_cast<ConcreteFunction*>(p);
}
static TF_ConcreteFunction* unwrap(ConcreteFunction* p) {
return reinterpret_cast<TF_ConcreteFunction*>(p);
}
};
inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
return FunctionMetadata::wrap(TF_ConcreteFunctionGetMetadata(unwrap(this)));
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_

View File

@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
#include <vector>
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
namespace tensorflow {
namespace cc {
// ConcreteFunctionList helps convert an opaque pointer to an array of
// ConcreteFunction pointers to a std::vector.
class ConcreteFunctionList {
public:
// Converts this object to a std::vector<ConcreteFunction*>
std::vector<ConcreteFunction*> ToVector();
private:
friend class SavedModelAPI;
// Wraps a TF_ConcreteFunctionList. Takes ownership of list.
explicit ConcreteFunctionList(TF_ConcreteFunctionList* list) : list_(list) {}
struct TFConcreteFunctionListDeleter {
void operator()(TF_ConcreteFunctionList* p) const {
TF_DeleteConcreteFunctionList(p);
}
};
std::unique_ptr<TF_ConcreteFunctionList, TFConcreteFunctionListDeleter> list_;
};
inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
int size = TF_ConcreteFunctionListSize(list_.get());
std::vector<ConcreteFunction*> result;
result.reserve(size);
for (int i = 0; i < size; ++i) {
result.push_back(
ConcreteFunction::wrap(TF_ConcreteFunctionListGet(list_.get(), i)));
}
return result;
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
#include <memory>
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
namespace tensorflow {
namespace cc {
// FunctionMetadata stores additional function information, including
// optional signaturedef feeds/fetches (for TF1-based ConcreteFunctions),
// a valid function path (for TF2-based ConcreteFunctions), and
// the types + number of inputs and outputs.
class FunctionMetadata final {
// TODO(bmzhao): Add getters here as necessary.
private:
friend class ConcreteFunction;
static FunctionMetadata* wrap(TF_FunctionMetadata* p) {
return reinterpret_cast<FunctionMetadata*>(p);
}
static TF_FunctionMetadata* unwrap(FunctionMetadata* p) {
return reinterpret_cast<TF_FunctionMetadata*>(p);
}
};
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_

View File

@ -0,0 +1,160 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
namespace tensorflow {
namespace cc {
// SavedModelAPI offers a way to load Tensorflow Saved Models
// (https://www.tensorflow.org/guide/saved_model) and execute saved
// tf.functions or legacy SignatureDefs in a TF2-idiomatic fashion.
// See RFC 207
// (https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md)
// TODO(bmzhao): Add an e2e example here, once ConcreteFunction::Run is added.
class SavedModelAPI {
public:
// Load a SavedModel from `dirname`.
//
// Params:
// saved_model_path - A directory filepath that the SavedModel is at.
// runtime - A runtime used to load SavedModelAPI. `runtime` must outlive the
// returned TF_SavedModel pointer.
// tags - Optional set of tags. If tags = nullptr, we expect the SavedModel
// to contain a single Metagraph (as for those exported from TF2's
// `tf.saved_model.save`). If tags != nullptr, we load the metagraph
// matching the tags:
// https://github.com/tensorflow/tensorflow/blob/428cdeda09aef81e958eeb274b83d27ad635b57b/tensorflow/core/protobuf/meta_graph.proto#L50-L56
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr.
static std::unique_ptr<SavedModelAPI> Load(
const std::string& saved_model_path, const Runtime& runtime,
Status* status, const std::unordered_set<std::string>* tags = nullptr);
// Retrieve a function from the TF2 SavedModel via function path.
//
// Params:
// function_path - A string containing the path from the root saved python
// object to a tf.function method.
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
// is bound to SavedModelAPI it was loaded from.
ConcreteFunction* GetConcreteFunction(const std::string& function_path,
Status* status);
// Retrieve a function from the TF SavedModel via a SignatureDef key.
//
// Params:
// signature_def_key - String key of SignatureDef map of a SavedModel:
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
// status - Set to OK on success and an appropriate error on failure.
// Returns:
// If status is not OK, returns nullptr. Otherwise, returns a
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
// is bound to SavedModelAPI it was loaded from.
ConcreteFunction* GetSignatureDefFunction(const std::string& function_path,
Status* status);
// Lists all Conrete Functions available from the SavedModel.
std::vector<ConcreteFunction*> ListFunctions();
// SavedModelAPI is movable, but not copyable.
SavedModelAPI(SavedModelAPI&&) = default;
SavedModelAPI& operator=(SavedModelAPI&&) = default;
private:
SavedModelAPI(const SavedModelAPI&) = delete;
SavedModelAPI& operator=(const SavedModelAPI&) = delete;
explicit SavedModelAPI(TF_SavedModel* model) : saved_model_(model) {}
struct TFSavedModelDeleter {
void operator()(TF_SavedModel* p) const { TF_DeleteSavedModel(p); }
};
std::unique_ptr<TF_SavedModel, TFSavedModelDeleter> saved_model_;
};
inline std::unique_ptr<SavedModelAPI> SavedModelAPI::Load(
const std::string& saved_model_path, const Runtime& runtime, Status* status,
const std::unordered_set<std::string>* tags) {
TF_SavedModel* saved_model = nullptr;
if (tags == nullptr) {
saved_model =
TF_LoadSavedModel(saved_model_path.c_str(), runtime.GetTFEContext(),
status->GetTFStatus());
} else {
std::vector<const char*> tags_vector;
tags_vector.reserve(tags->size());
for (const std::string& tag : *tags) {
tags_vector.push_back(tag.c_str());
}
saved_model = TF_LoadSavedModelWithTags(
saved_model_path.c_str(), runtime.GetTFEContext(), tags_vector.data(),
tags_vector.size(), status->GetTFStatus());
}
if (!status->ok()) {
return nullptr;
}
// We can't use std::make_unique here because of its interaction with a
// private constructor: https://abseil.io/tips/134
return std::unique_ptr<SavedModelAPI>(new SavedModelAPI(saved_model));
}
inline ConcreteFunction* SavedModelAPI::GetConcreteFunction(
const std::string& function_path, Status* status) {
TF_ConcreteFunction* function = TF_GetSavedModelConcreteFunction(
saved_model_.get(), function_path.c_str(), status->GetTFStatus());
if (!status->ok()) {
return nullptr;
}
return ConcreteFunction::wrap(function);
}
inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction(
const std::string& function_path, Status* status) {
TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction(
saved_model_.get(), function_path.c_str(), status->GetTFStatus());
if (!status->ok()) {
return nullptr;
}
return ConcreteFunction::wrap(function);
}
inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get()));
return list.ToVector();
}
} // namespace cc
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_

View File

@ -0,0 +1,22 @@
# Tests for the C++ header-only SavedModelAPI.
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
)
tf_cc_test(
name = "saved_model_api_test",
srcs = [
"saved_model_api_test.cc",
],
deps = [
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:runtime_builder",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/saved_model/experimental/public:saved_model_api",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"
#include <memory>
#include <string>
#include <unordered_set>
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
constexpr char kTestData[] = "cc/saved_model/testdata";
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
kTestData, saved_model_dir);
}
// This value parameterized test allows us to test both TFRT
// and non TFRT runtimes.
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
cc::Status status;
cc::RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unordered_set<std::string> tags = {"serve"};
std::unique_ptr<cc::SavedModelAPI> model =
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
}
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
cc::Status status;
cc::RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unique_ptr<cc::SavedModelAPI> model =
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
CPPSavedModelAPITest, ::testing::Bool());
} // namespace
} // namespace tensorflow