Initial checkin of SavedModelAPI C++ header-only API. Tensor, TensorHandle and ConreteFunction::Run will be added in a subsequent change. See RFC https://github.com/tensorflow/community/pull/207.
PiperOrigin-RevId: 309259741 Change-Id: Ic88c3d2b2de83d59774b305b8e0126d792624d78
This commit is contained in:
parent
bac40c0346
commit
3a61ea880c
@ -66,7 +66,7 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
|
||||
|
||||
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
|
||||
char* function_path,
|
||||
const char* function_path,
|
||||
TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
@ -79,7 +79,7 @@ TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, char* signature_def_key, TF_Status* status) {
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
|
||||
|
@ -80,7 +80,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);
|
||||
// "conceptually" bound to `model`. Once `model` is deleted, all
|
||||
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
|
||||
TF_SavedModel* model, char* function_path, TF_Status* status);
|
||||
TF_SavedModel* model, const char* function_path, TF_Status* status);
|
||||
|
||||
// Retrieve a function from the TF SavedModel via a SignatureDef key.
|
||||
//
|
||||
@ -94,7 +94,7 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
|
||||
// TF_ConcreteFunction instance. Once `model` is deleted, all
|
||||
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, char* signature_def_key, TF_Status* status);
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status);
|
||||
|
||||
// Returns a list of all ConcreteFunctions stored in this SavedModel.
|
||||
// The lifetime of the returned list is bound to `model`.
|
||||
|
52
tensorflow/cc/experimental/base/public/BUILD
Normal file
52
tensorflow/cc/experimental/base/public/BUILD
Normal file
@ -0,0 +1,52 @@
|
||||
# Experimental C++ APIs for TensorFlow.
|
||||
# New TF C++ APIs under the tensorflow::cc namespace aim to guarantee ABI stability.
|
||||
# Users are expected to compile against public c++ headers, and link against
|
||||
# libtensorflow (https://www.tensorflow.org/install/lang_c).
|
||||
# We aim to achieve ABI stability in new C++ APIs by only using types
|
||||
# on the API surface that:
|
||||
# 1. Have a header-only implementation
|
||||
# 2. Are std:: types
|
||||
# 3. Wrap an opaque C type
|
||||
|
||||
package(
|
||||
# This is intentionally public
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime",
|
||||
hdrs = [
|
||||
"runtime.h",
|
||||
],
|
||||
deps = [
|
||||
":status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime_builder",
|
||||
hdrs = [
|
||||
"runtime_builder.h",
|
||||
],
|
||||
deps = [
|
||||
":runtime",
|
||||
":status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "status",
|
||||
hdrs = [
|
||||
"status.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
],
|
||||
)
|
68
tensorflow/cc/experimental/base/public/runtime.h
Normal file
68
tensorflow/cc/experimental/base/public/runtime.h
Normal file
@ -0,0 +1,68 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
|
||||
// resources, threadpools, etc. Clients are expected to construct a Runtime
|
||||
// object through tensorflow::cc::RuntimeBuilder::Build, after setting any
|
||||
// relevant configuration options. Many Tensorflow functions take a reference to
|
||||
// the runtime as an argument (eg: tensorflow::cc::SavedModelAPI::Load), and
|
||||
// may have different implementations depending on the runtime. For many of
|
||||
// these Runtime-attached objects (such as tensorflow::cc::TensorHandle), the
|
||||
// Runtime must outlive these objects.
|
||||
class Runtime {
|
||||
public:
|
||||
// Runtime is movable, but not copyable.
|
||||
Runtime(Runtime&&) = default;
|
||||
Runtime& operator=(Runtime&&) = default;
|
||||
|
||||
private:
|
||||
friend class RuntimeBuilder;
|
||||
friend class SavedModelAPI;
|
||||
|
||||
// Wraps a TFE_Context. Takes ownership of ctx.
|
||||
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
|
||||
|
||||
// Deletes the currently wrapped TFE_Context, swaps it with ctx,
|
||||
// and takes ownership of ctx.
|
||||
void Reset(TFE_Context* ctx) { ctx_.reset(ctx); }
|
||||
|
||||
// Returns the TFE_Context that this object wraps. This object
|
||||
// retains ownership of the pointer.
|
||||
TFE_Context* GetTFEContext() const { return ctx_.get(); }
|
||||
|
||||
// Runtime is not copyable
|
||||
Runtime(const Runtime&) = delete;
|
||||
Runtime& operator=(const Runtime&) = delete;
|
||||
|
||||
struct TFEContextDeleter {
|
||||
void operator()(TFE_Context* p) const { TFE_DeleteContext(p); }
|
||||
};
|
||||
std::unique_ptr<TFE_Context, TFEContextDeleter> ctx_;
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
84
tensorflow/cc/experimental/base/public/runtime_builder.h
Normal file
84
tensorflow/cc/experimental/base/public/runtime_builder.h
Normal file
@ -0,0 +1,84 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
|
||||
// Use this to set configuration options, like threadpool size, etc.
|
||||
class RuntimeBuilder {
|
||||
public:
|
||||
RuntimeBuilder() : options_(TFE_NewContextOptions()) {}
|
||||
|
||||
// If `use_tfrt` is true, we will use the new Tensorflow Runtime
|
||||
// (https://blog.tensorflow.org/2020/04/tfrt-new-tensorflow-runtime.html) as
|
||||
// our runtime implementation.
|
||||
RuntimeBuilder& SetUseTFRT(bool use_tfrt);
|
||||
|
||||
// Build a Tensorflow Runtime.
|
||||
//
|
||||
// Params:
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// unique_ptr<tensorflow::cc::Runtime>.
|
||||
std::unique_ptr<Runtime> Build(Status* status);
|
||||
|
||||
// RuntimeBuilder is movable, but not copyable.
|
||||
RuntimeBuilder(RuntimeBuilder&&) = default;
|
||||
RuntimeBuilder& operator=(RuntimeBuilder&&) = default;
|
||||
|
||||
private:
|
||||
// RuntimeBuilder is not copyable
|
||||
RuntimeBuilder(const RuntimeBuilder&) = delete;
|
||||
RuntimeBuilder& operator=(const RuntimeBuilder&) = delete;
|
||||
|
||||
struct TFEContextOptionsDeleter {
|
||||
void operator()(TFE_ContextOptions* p) const {
|
||||
TFE_DeleteContextOptions(p);
|
||||
}
|
||||
};
|
||||
std::unique_ptr<TFE_ContextOptions, TFEContextOptionsDeleter> options_;
|
||||
};
|
||||
|
||||
inline RuntimeBuilder& RuntimeBuilder::SetUseTFRT(bool use_tfrt) {
|
||||
TFE_ContextOptionsSetTfrt(options_.get(), use_tfrt);
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
|
||||
TFE_Context* result = TFE_NewContext(options_.get(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
// We can't use std::make_unique here because of its interaction with a
|
||||
// private constructor: https://abseil.io/tips/134
|
||||
return std::unique_ptr<Runtime>(new Runtime(result));
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
93
tensorflow/cc/experimental/base/public/status.h
Normal file
93
tensorflow/cc/experimental/base/public/status.h
Normal file
@ -0,0 +1,93 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// Status is a wrapper around an error code and an optional error message.
|
||||
// The set of error codes are defined here:
|
||||
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/c/tf_status.h#L39-L60
|
||||
// Many Tensorflow APIs return a Status, or take a Status as an out parameter.
|
||||
// Clients should check for status.ok() after calling these APIs, and either
|
||||
// handle or propagate the error appropriately.
|
||||
// TODO(bmzhao): Add a detailed code example before moving out of experimental.
|
||||
class Status {
|
||||
public:
|
||||
// Create a success status
|
||||
Status() : status_(TF_NewStatus()) {}
|
||||
|
||||
// Return the status code
|
||||
TF_Code code() const;
|
||||
|
||||
// Returns the error message in Status.
|
||||
std::string message() const;
|
||||
|
||||
// Returns the error message in Status.
|
||||
bool ok() const;
|
||||
|
||||
// Record <code, msg> in Status. Any previous information is lost.
|
||||
// A common use is to clear a status: SetStatus(TF_OK, "");
|
||||
void SetStatus(TF_Code code, const std::string& msg);
|
||||
|
||||
// Status is movable, but not copyable.
|
||||
Status(Status&&) = default;
|
||||
Status& operator=(Status&&) = default;
|
||||
|
||||
private:
|
||||
friend class RuntimeBuilder;
|
||||
friend class Runtime;
|
||||
friend class SavedModelAPI;
|
||||
|
||||
// Wraps a TF_Status*, and takes ownership of it.
|
||||
explicit Status(TF_Status* status) : status_(status) {}
|
||||
|
||||
// Status is not copyable
|
||||
Status(const Status&) = delete;
|
||||
Status& operator=(const Status&) = delete;
|
||||
|
||||
// Returns the TF_Status that this object wraps. This object
|
||||
// retains ownership of the pointer.
|
||||
TF_Status* GetTFStatus() const { return status_.get(); }
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* p) const { TF_DeleteStatus(p); }
|
||||
};
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> status_;
|
||||
};
|
||||
|
||||
inline TF_Code Status::code() const { return TF_GetCode(status_.get()); }
|
||||
|
||||
inline std::string Status::message() const {
|
||||
return std::string(TF_Message(status_.get()));
|
||||
}
|
||||
|
||||
inline bool Status::ok() const { return code() == TF_OK; }
|
||||
|
||||
inline void Status::SetStatus(TF_Code code, const std::string& msg) {
|
||||
TF_SetStatus(status_.get(), code, msg.c_str());
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
58
tensorflow/cc/saved_model/experimental/public/BUILD
Normal file
58
tensorflow/cc/saved_model/experimental/public/BUILD
Normal file
@ -0,0 +1,58 @@
|
||||
# Experimental C++ SavedModel Header Only APIs. See RFC
|
||||
# https://github.com/tensorflow/community/pull/207
|
||||
|
||||
package(
|
||||
# This is intentionally public
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function",
|
||||
hdrs = [
|
||||
"concrete_function.h",
|
||||
],
|
||||
deps = [
|
||||
":function_metadata",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_list",
|
||||
hdrs = [
|
||||
"concrete_function_list.h",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function_list",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata",
|
||||
hdrs = [
|
||||
"function_metadata.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/saved_model/public:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api",
|
||||
hdrs = [
|
||||
"saved_model_api.h",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list",
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
||||
"//tensorflow/cc/experimental/base/public:runtime",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
],
|
||||
)
|
@ -0,0 +1,59 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
|
||||
class ConcreteFunction final {
|
||||
public:
|
||||
// TODO(bmzhao): Adding ConcreteFunction::Run in subsequent CL, since
|
||||
// it depends on tensorflow::cc::Tensor and tensorflow::cc::TensorHandle
|
||||
|
||||
// Returns FunctionMetadata associated with this ConcreteFunction.
|
||||
const FunctionMetadata* GetFunctionMetadata();
|
||||
|
||||
private:
|
||||
friend class SavedModelAPI;
|
||||
friend class ConcreteFunctionList;
|
||||
|
||||
// TODO(bmzhao): Consider adding a macro for wrapping/unwrapping
|
||||
// when moving out of experimental.
|
||||
static ConcreteFunction* wrap(TF_ConcreteFunction* p) {
|
||||
return reinterpret_cast<ConcreteFunction*>(p);
|
||||
}
|
||||
static TF_ConcreteFunction* unwrap(ConcreteFunction* p) {
|
||||
return reinterpret_cast<TF_ConcreteFunction*>(p);
|
||||
}
|
||||
};
|
||||
|
||||
inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
|
||||
return FunctionMetadata::wrap(TF_ConcreteFunctionGetMetadata(unwrap(this)));
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
@ -0,0 +1,61 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunctionList helps convert an opaque pointer to an array of
|
||||
// ConcreteFunction pointers to a std::vector.
|
||||
class ConcreteFunctionList {
|
||||
public:
|
||||
// Converts this object to a std::vector<ConcreteFunction*>
|
||||
std::vector<ConcreteFunction*> ToVector();
|
||||
|
||||
private:
|
||||
friend class SavedModelAPI;
|
||||
// Wraps a TF_ConcreteFunctionList. Takes ownership of list.
|
||||
explicit ConcreteFunctionList(TF_ConcreteFunctionList* list) : list_(list) {}
|
||||
|
||||
struct TFConcreteFunctionListDeleter {
|
||||
void operator()(TF_ConcreteFunctionList* p) const {
|
||||
TF_DeleteConcreteFunctionList(p);
|
||||
}
|
||||
};
|
||||
std::unique_ptr<TF_ConcreteFunctionList, TFConcreteFunctionListDeleter> list_;
|
||||
};
|
||||
|
||||
inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
|
||||
int size = TF_ConcreteFunctionListSize(list_.get());
|
||||
std::vector<ConcreteFunction*> result;
|
||||
result.reserve(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
result.push_back(
|
||||
ConcreteFunction::wrap(TF_ConcreteFunctionListGet(list_.get(), i)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
@ -0,0 +1,45 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// FunctionMetadata stores additional function information, including
|
||||
// optional signaturedef feeds/fetches (for TF1-based ConcreteFunctions),
|
||||
// a valid function path (for TF2-based ConcreteFunctions), and
|
||||
// the types + number of inputs and outputs.
|
||||
class FunctionMetadata final {
|
||||
// TODO(bmzhao): Add getters here as necessary.
|
||||
private:
|
||||
friend class ConcreteFunction;
|
||||
static FunctionMetadata* wrap(TF_FunctionMetadata* p) {
|
||||
return reinterpret_cast<FunctionMetadata*>(p);
|
||||
}
|
||||
static TF_FunctionMetadata* unwrap(FunctionMetadata* p) {
|
||||
return reinterpret_cast<TF_FunctionMetadata*>(p);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
160
tensorflow/cc/saved_model/experimental/public/saved_model_api.h
Normal file
160
tensorflow/cc/saved_model/experimental/public/saved_model_api.h
Normal file
@ -0,0 +1,160 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
|
||||
// SavedModelAPI offers a way to load Tensorflow Saved Models
|
||||
// (https://www.tensorflow.org/guide/saved_model) and execute saved
|
||||
// tf.functions or legacy SignatureDefs in a TF2-idiomatic fashion.
|
||||
// See RFC 207
|
||||
// (https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md)
|
||||
// TODO(bmzhao): Add an e2e example here, once ConcreteFunction::Run is added.
|
||||
class SavedModelAPI {
|
||||
public:
|
||||
// Load a SavedModel from `dirname`.
|
||||
//
|
||||
// Params:
|
||||
// saved_model_path - A directory filepath that the SavedModel is at.
|
||||
// runtime - A runtime used to load SavedModelAPI. `runtime` must outlive the
|
||||
// returned TF_SavedModel pointer.
|
||||
// tags - Optional set of tags. If tags = nullptr, we expect the SavedModel
|
||||
// to contain a single Metagraph (as for those exported from TF2's
|
||||
// `tf.saved_model.save`). If tags != nullptr, we load the metagraph
|
||||
// matching the tags:
|
||||
// https://github.com/tensorflow/tensorflow/blob/428cdeda09aef81e958eeb274b83d27ad635b57b/tensorflow/core/protobuf/meta_graph.proto#L50-L56
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr.
|
||||
static std::unique_ptr<SavedModelAPI> Load(
|
||||
const std::string& saved_model_path, const Runtime& runtime,
|
||||
Status* status, const std::unordered_set<std::string>* tags = nullptr);
|
||||
|
||||
// Retrieve a function from the TF2 SavedModel via function path.
|
||||
//
|
||||
// Params:
|
||||
// function_path - A string containing the path from the root saved python
|
||||
// object to a tf.function method.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
|
||||
// is bound to SavedModelAPI it was loaded from.
|
||||
ConcreteFunction* GetConcreteFunction(const std::string& function_path,
|
||||
Status* status);
|
||||
|
||||
// Retrieve a function from the TF SavedModel via a SignatureDef key.
|
||||
//
|
||||
// Params:
|
||||
// signature_def_key - String key of SignatureDef map of a SavedModel:
|
||||
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer
|
||||
// is bound to SavedModelAPI it was loaded from.
|
||||
ConcreteFunction* GetSignatureDefFunction(const std::string& function_path,
|
||||
Status* status);
|
||||
|
||||
// Lists all Conrete Functions available from the SavedModel.
|
||||
std::vector<ConcreteFunction*> ListFunctions();
|
||||
|
||||
// SavedModelAPI is movable, but not copyable.
|
||||
SavedModelAPI(SavedModelAPI&&) = default;
|
||||
SavedModelAPI& operator=(SavedModelAPI&&) = default;
|
||||
|
||||
private:
|
||||
SavedModelAPI(const SavedModelAPI&) = delete;
|
||||
SavedModelAPI& operator=(const SavedModelAPI&) = delete;
|
||||
|
||||
explicit SavedModelAPI(TF_SavedModel* model) : saved_model_(model) {}
|
||||
struct TFSavedModelDeleter {
|
||||
void operator()(TF_SavedModel* p) const { TF_DeleteSavedModel(p); }
|
||||
};
|
||||
std::unique_ptr<TF_SavedModel, TFSavedModelDeleter> saved_model_;
|
||||
};
|
||||
|
||||
inline std::unique_ptr<SavedModelAPI> SavedModelAPI::Load(
|
||||
const std::string& saved_model_path, const Runtime& runtime, Status* status,
|
||||
const std::unordered_set<std::string>* tags) {
|
||||
TF_SavedModel* saved_model = nullptr;
|
||||
|
||||
if (tags == nullptr) {
|
||||
saved_model =
|
||||
TF_LoadSavedModel(saved_model_path.c_str(), runtime.GetTFEContext(),
|
||||
status->GetTFStatus());
|
||||
} else {
|
||||
std::vector<const char*> tags_vector;
|
||||
tags_vector.reserve(tags->size());
|
||||
for (const std::string& tag : *tags) {
|
||||
tags_vector.push_back(tag.c_str());
|
||||
}
|
||||
saved_model = TF_LoadSavedModelWithTags(
|
||||
saved_model_path.c_str(), runtime.GetTFEContext(), tags_vector.data(),
|
||||
tags_vector.size(), status->GetTFStatus());
|
||||
}
|
||||
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// We can't use std::make_unique here because of its interaction with a
|
||||
// private constructor: https://abseil.io/tips/134
|
||||
return std::unique_ptr<SavedModelAPI>(new SavedModelAPI(saved_model));
|
||||
}
|
||||
|
||||
inline ConcreteFunction* SavedModelAPI::GetConcreteFunction(
|
||||
const std::string& function_path, Status* status) {
|
||||
TF_ConcreteFunction* function = TF_GetSavedModelConcreteFunction(
|
||||
saved_model_.get(), function_path.c_str(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return ConcreteFunction::wrap(function);
|
||||
}
|
||||
|
||||
inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction(
|
||||
const std::string& function_path, Status* status) {
|
||||
TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction(
|
||||
saved_model_.get(), function_path.c_str(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return ConcreteFunction::wrap(function);
|
||||
}
|
||||
|
||||
inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
|
||||
ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get()));
|
||||
return list.ToVector();
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
22
tensorflow/cc/saved_model/experimental/tests/BUILD
Normal file
22
tensorflow/cc/saved_model/experimental/tests/BUILD
Normal file
@ -0,0 +1,22 @@
|
||||
# Tests for the C++ header-only SavedModelAPI.
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_api_test",
|
||||
srcs = [
|
||||
"saved_model_api_test.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/cc/experimental/base/public:runtime",
|
||||
"//tensorflow/cc/experimental/base/public:runtime_builder",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/saved_model/experimental/public:saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
@ -0,0 +1,97 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
|
||||
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
|
||||
kTestData, saved_model_dir);
|
||||
}
|
||||
|
||||
// This value parameterized test allows us to test both TFRT
|
||||
// and non TFRT runtimes.
|
||||
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
|
||||
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unordered_set<std::string> tags = {"serve"};
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
// like loading a ConcreteFunction. This test at least checks that the
|
||||
// C API builds and can be minimally run.
|
||||
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
|
||||
}
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
// like loading a ConcreteFunction. This test at least checks that the
|
||||
// C API builds and can be minimally run.
|
||||
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
|
||||
CPPSavedModelAPITest, ::testing::Bool());
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user