diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index cce1b27d9ad..629610dbe29 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -66,7 +66,7 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; } TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, - char* function_path, + const char* function_path, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; tensorflow::Status get_function_status = @@ -79,7 +79,7 @@ TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, } TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( - TF_SavedModel* model, char* signature_def_key, TF_Status* status) { + TF_SavedModel* model, const char* signature_def_key, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; tensorflow::Status get_function_status = model->saved_model->GetSignatureDefFunction(signature_def_key, &result); diff --git a/tensorflow/c/experimental/saved_model/public/saved_model_api.h b/tensorflow/c/experimental/saved_model/public/saved_model_api.h index ad381937e3c..875167bec63 100644 --- a/tensorflow/c/experimental/saved_model/public/saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/saved_model_api.h @@ -80,7 +80,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model); // "conceptually" bound to `model`. Once `model` is deleted, all // `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted. TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction( - TF_SavedModel* model, char* function_path, TF_Status* status); + TF_SavedModel* model, const char* function_path, TF_Status* status); // Retrieve a function from the TF SavedModel via a SignatureDef key. // @@ -94,7 +94,7 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction( // TF_ConcreteFunction instance. Once `model` is deleted, all // `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted. TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( - TF_SavedModel* model, char* signature_def_key, TF_Status* status); + TF_SavedModel* model, const char* signature_def_key, TF_Status* status); // Returns a list of all ConcreteFunctions stored in this SavedModel. // The lifetime of the returned list is bound to `model`. diff --git a/tensorflow/cc/experimental/base/public/BUILD b/tensorflow/cc/experimental/base/public/BUILD new file mode 100644 index 00000000000..a543a347c2e --- /dev/null +++ b/tensorflow/cc/experimental/base/public/BUILD @@ -0,0 +1,52 @@ +# Experimental C++ APIs for TensorFlow. +# New TF C++ APIs under the tensorflow::cc namespace aim to guarantee ABI stability. +# Users are expected to compile against public c++ headers, and link against +# libtensorflow (https://www.tensorflow.org/install/lang_c). +# We aim to achieve ABI stability in new C++ APIs by only using types +# on the API surface that: +# 1. Have a header-only implementation +# 2. Are std:: types +# 3. Wrap an opaque C type + +package( + # This is intentionally public + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "runtime", + hdrs = [ + "runtime.h", + ], + deps = [ + ":status", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + ], +) + +cc_library( + name = "runtime_builder", + hdrs = [ + "runtime_builder.h", + ], + deps = [ + ":runtime", + ":status", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + ], +) + +cc_library( + name = "status", + hdrs = [ + "status.h", + ], + deps = [ + "//tensorflow/c:tf_status", + ], +) diff --git a/tensorflow/cc/experimental/base/public/runtime.h b/tensorflow/cc/experimental/base/public/runtime.h new file mode 100644 index 00000000000..47fd8869647 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/runtime.h @@ -0,0 +1,68 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ + +#include + +#include "tensorflow/c/eager/c_api_experimental.h" + +namespace tensorflow { +namespace cc { + +// Runtime represents an opaque instance of a Tensorflow runtime, with its own +// resources, threadpools, etc. Clients are expected to construct a Runtime +// object through tensorflow::cc::RuntimeBuilder::Build, after setting any +// relevant configuration options. Many Tensorflow functions take a reference to +// the runtime as an argument (eg: tensorflow::cc::SavedModelAPI::Load), and +// may have different implementations depending on the runtime. For many of +// these Runtime-attached objects (such as tensorflow::cc::TensorHandle), the +// Runtime must outlive these objects. +class Runtime { + public: + // Runtime is movable, but not copyable. + Runtime(Runtime&&) = default; + Runtime& operator=(Runtime&&) = default; + + private: + friend class RuntimeBuilder; + friend class SavedModelAPI; + + // Wraps a TFE_Context. Takes ownership of ctx. + explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {} + + // Deletes the currently wrapped TFE_Context, swaps it with ctx, + // and takes ownership of ctx. + void Reset(TFE_Context* ctx) { ctx_.reset(ctx); } + + // Returns the TFE_Context that this object wraps. This object + // retains ownership of the pointer. + TFE_Context* GetTFEContext() const { return ctx_.get(); } + + // Runtime is not copyable + Runtime(const Runtime&) = delete; + Runtime& operator=(const Runtime&) = delete; + + struct TFEContextDeleter { + void operator()(TFE_Context* p) const { TFE_DeleteContext(p); } + }; + std::unique_ptr ctx_; +}; + +} // namespace cc +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ diff --git a/tensorflow/cc/experimental/base/public/runtime_builder.h b/tensorflow/cc/experimental/base/public/runtime_builder.h new file mode 100644 index 00000000000..ed3c93ae135 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/runtime_builder.h @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" + +namespace tensorflow { +namespace cc { + +// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime. +// Use this to set configuration options, like threadpool size, etc. +class RuntimeBuilder { + public: + RuntimeBuilder() : options_(TFE_NewContextOptions()) {} + + // If `use_tfrt` is true, we will use the new Tensorflow Runtime + // (https://blog.tensorflow.org/2020/04/tfrt-new-tensorflow-runtime.html) as + // our runtime implementation. + RuntimeBuilder& SetUseTFRT(bool use_tfrt); + + // Build a Tensorflow Runtime. + // + // Params: + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // unique_ptr. + std::unique_ptr Build(Status* status); + + // RuntimeBuilder is movable, but not copyable. + RuntimeBuilder(RuntimeBuilder&&) = default; + RuntimeBuilder& operator=(RuntimeBuilder&&) = default; + + private: + // RuntimeBuilder is not copyable + RuntimeBuilder(const RuntimeBuilder&) = delete; + RuntimeBuilder& operator=(const RuntimeBuilder&) = delete; + + struct TFEContextOptionsDeleter { + void operator()(TFE_ContextOptions* p) const { + TFE_DeleteContextOptions(p); + } + }; + std::unique_ptr options_; +}; + +inline RuntimeBuilder& RuntimeBuilder::SetUseTFRT(bool use_tfrt) { + TFE_ContextOptionsSetTfrt(options_.get(), use_tfrt); + return *this; +} + +inline std::unique_ptr RuntimeBuilder::Build(Status* status) { + TFE_Context* result = TFE_NewContext(options_.get(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + // We can't use std::make_unique here because of its interaction with a + // private constructor: https://abseil.io/tips/134 + return std::unique_ptr(new Runtime(result)); +} + +} // namespace cc +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ diff --git a/tensorflow/cc/experimental/base/public/status.h b/tensorflow/cc/experimental/base/public/status.h new file mode 100644 index 00000000000..f91f2caccd8 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/status.h @@ -0,0 +1,93 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ + +#include +#include + +#include "tensorflow/c/tf_status.h" + +namespace tensorflow { +namespace cc { + +// Status is a wrapper around an error code and an optional error message. +// The set of error codes are defined here: +// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/c/tf_status.h#L39-L60 +// Many Tensorflow APIs return a Status, or take a Status as an out parameter. +// Clients should check for status.ok() after calling these APIs, and either +// handle or propagate the error appropriately. +// TODO(bmzhao): Add a detailed code example before moving out of experimental. +class Status { + public: + // Create a success status + Status() : status_(TF_NewStatus()) {} + + // Return the status code + TF_Code code() const; + + // Returns the error message in Status. + std::string message() const; + + // Returns the error message in Status. + bool ok() const; + + // Record in Status. Any previous information is lost. + // A common use is to clear a status: SetStatus(TF_OK, ""); + void SetStatus(TF_Code code, const std::string& msg); + + // Status is movable, but not copyable. + Status(Status&&) = default; + Status& operator=(Status&&) = default; + + private: + friend class RuntimeBuilder; + friend class Runtime; + friend class SavedModelAPI; + + // Wraps a TF_Status*, and takes ownership of it. + explicit Status(TF_Status* status) : status_(status) {} + + // Status is not copyable + Status(const Status&) = delete; + Status& operator=(const Status&) = delete; + + // Returns the TF_Status that this object wraps. This object + // retains ownership of the pointer. + TF_Status* GetTFStatus() const { return status_.get(); } + + struct TFStatusDeleter { + void operator()(TF_Status* p) const { TF_DeleteStatus(p); } + }; + std::unique_ptr status_; +}; + +inline TF_Code Status::code() const { return TF_GetCode(status_.get()); } + +inline std::string Status::message() const { + return std::string(TF_Message(status_.get())); +} + +inline bool Status::ok() const { return code() == TF_OK; } + +inline void Status::SetStatus(TF_Code code, const std::string& msg) { + TF_SetStatus(status_.get(), code, msg.c_str()); +} + +} // namespace cc +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/BUILD b/tensorflow/cc/saved_model/experimental/public/BUILD new file mode 100644 index 00000000000..3e9a671a61f --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/BUILD @@ -0,0 +1,58 @@ +# Experimental C++ SavedModel Header Only APIs. See RFC +# https://github.com/tensorflow/community/pull/207 + +package( + # This is intentionally public + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "concrete_function", + hdrs = [ + "concrete_function.h", + ], + deps = [ + ":function_metadata", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/experimental/saved_model/public:concrete_function", + "//tensorflow/cc/experimental/base/public:status", + ], +) + +cc_library( + name = "concrete_function_list", + hdrs = [ + "concrete_function_list.h", + ], + deps = [ + ":concrete_function", + "//tensorflow/c/experimental/saved_model/public:concrete_function_list", + ], +) + +cc_library( + name = "function_metadata", + hdrs = [ + "function_metadata.h", + ], + deps = [ + "//tensorflow/c/experimental/saved_model/public:function_metadata", + ], +) + +cc_library( + name = "saved_model_api", + hdrs = [ + "saved_model_api.h", + ], + deps = [ + ":concrete_function", + ":concrete_function_list", + "//tensorflow/c/experimental/saved_model/public:saved_model_api", + "//tensorflow/cc/experimental/base/public:runtime", + "//tensorflow/cc/experimental/base/public:status", + ], +) diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function.h b/tensorflow/cc/saved_model/experimental/public/concrete_function.h new file mode 100644 index 00000000000..f57ba052f1a --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/concrete_function.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h" + +namespace tensorflow { +namespace cc { + +// ConcreteFunction is an executable "function" loaded from a SavedModelAPI. +class ConcreteFunction final { + public: + // TODO(bmzhao): Adding ConcreteFunction::Run in subsequent CL, since + // it depends on tensorflow::cc::Tensor and tensorflow::cc::TensorHandle + + // Returns FunctionMetadata associated with this ConcreteFunction. + const FunctionMetadata* GetFunctionMetadata(); + + private: + friend class SavedModelAPI; + friend class ConcreteFunctionList; + + // TODO(bmzhao): Consider adding a macro for wrapping/unwrapping + // when moving out of experimental. + static ConcreteFunction* wrap(TF_ConcreteFunction* p) { + return reinterpret_cast(p); + } + static TF_ConcreteFunction* unwrap(ConcreteFunction* p) { + return reinterpret_cast(p); + } +}; + +inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() { + return FunctionMetadata::wrap(TF_ConcreteFunctionGetMetadata(unwrap(this))); +} + +} // namespace cc +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h new file mode 100644 index 00000000000..bab95278eac --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" + +namespace tensorflow { +namespace cc { + +// ConcreteFunctionList helps convert an opaque pointer to an array of +// ConcreteFunction pointers to a std::vector. +class ConcreteFunctionList { + public: + // Converts this object to a std::vector + std::vector ToVector(); + + private: + friend class SavedModelAPI; + // Wraps a TF_ConcreteFunctionList. Takes ownership of list. + explicit ConcreteFunctionList(TF_ConcreteFunctionList* list) : list_(list) {} + + struct TFConcreteFunctionListDeleter { + void operator()(TF_ConcreteFunctionList* p) const { + TF_DeleteConcreteFunctionList(p); + } + }; + std::unique_ptr list_; +}; + +inline std::vector ConcreteFunctionList::ToVector() { + int size = TF_ConcreteFunctionListSize(list_.get()); + std::vector result; + result.reserve(size); + for (int i = 0; i < size; ++i) { + result.push_back( + ConcreteFunction::wrap(TF_ConcreteFunctionListGet(list_.get(), i))); + } + return result; +} + +} // namespace cc +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/function_metadata.h b/tensorflow/cc/saved_model/experimental/public/function_metadata.h new file mode 100644 index 00000000000..c3dcc45af0e --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/function_metadata.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" + +namespace tensorflow { +namespace cc { + +// FunctionMetadata stores additional function information, including +// optional signaturedef feeds/fetches (for TF1-based ConcreteFunctions), +// a valid function path (for TF2-based ConcreteFunctions), and +// the types + number of inputs and outputs. +class FunctionMetadata final { + // TODO(bmzhao): Add getters here as necessary. + private: + friend class ConcreteFunction; + static FunctionMetadata* wrap(TF_FunctionMetadata* p) { + return reinterpret_cast(p); + } + static TF_FunctionMetadata* unwrap(FunctionMetadata* p) { + return reinterpret_cast(p); + } +}; + +} // namespace cc +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h new file mode 100644 index 00000000000..814479de213 --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h @@ -0,0 +1,160 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h" + +namespace tensorflow { +namespace cc { + +// SavedModelAPI offers a way to load Tensorflow Saved Models +// (https://www.tensorflow.org/guide/saved_model) and execute saved +// tf.functions or legacy SignatureDefs in a TF2-idiomatic fashion. +// See RFC 207 +// (https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md) +// TODO(bmzhao): Add an e2e example here, once ConcreteFunction::Run is added. +class SavedModelAPI { + public: + // Load a SavedModel from `dirname`. + // + // Params: + // saved_model_path - A directory filepath that the SavedModel is at. + // runtime - A runtime used to load SavedModelAPI. `runtime` must outlive the + // returned TF_SavedModel pointer. + // tags - Optional set of tags. If tags = nullptr, we expect the SavedModel + // to contain a single Metagraph (as for those exported from TF2's + // `tf.saved_model.save`). If tags != nullptr, we load the metagraph + // matching the tags: + // https://github.com/tensorflow/tensorflow/blob/428cdeda09aef81e958eeb274b83d27ad635b57b/tensorflow/core/protobuf/meta_graph.proto#L50-L56 + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. + static std::unique_ptr Load( + const std::string& saved_model_path, const Runtime& runtime, + Status* status, const std::unordered_set* tags = nullptr); + + // Retrieve a function from the TF2 SavedModel via function path. + // + // Params: + // function_path - A string containing the path from the root saved python + // object to a tf.function method. + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer + // is bound to SavedModelAPI it was loaded from. + ConcreteFunction* GetConcreteFunction(const std::string& function_path, + Status* status); + + // Retrieve a function from the TF SavedModel via a SignatureDef key. + // + // Params: + // signature_def_key - String key of SignatureDef map of a SavedModel: + // https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89 + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer + // is bound to SavedModelAPI it was loaded from. + ConcreteFunction* GetSignatureDefFunction(const std::string& function_path, + Status* status); + + // Lists all Conrete Functions available from the SavedModel. + std::vector ListFunctions(); + + // SavedModelAPI is movable, but not copyable. + SavedModelAPI(SavedModelAPI&&) = default; + SavedModelAPI& operator=(SavedModelAPI&&) = default; + + private: + SavedModelAPI(const SavedModelAPI&) = delete; + SavedModelAPI& operator=(const SavedModelAPI&) = delete; + + explicit SavedModelAPI(TF_SavedModel* model) : saved_model_(model) {} + struct TFSavedModelDeleter { + void operator()(TF_SavedModel* p) const { TF_DeleteSavedModel(p); } + }; + std::unique_ptr saved_model_; +}; + +inline std::unique_ptr SavedModelAPI::Load( + const std::string& saved_model_path, const Runtime& runtime, Status* status, + const std::unordered_set* tags) { + TF_SavedModel* saved_model = nullptr; + + if (tags == nullptr) { + saved_model = + TF_LoadSavedModel(saved_model_path.c_str(), runtime.GetTFEContext(), + status->GetTFStatus()); + } else { + std::vector tags_vector; + tags_vector.reserve(tags->size()); + for (const std::string& tag : *tags) { + tags_vector.push_back(tag.c_str()); + } + saved_model = TF_LoadSavedModelWithTags( + saved_model_path.c_str(), runtime.GetTFEContext(), tags_vector.data(), + tags_vector.size(), status->GetTFStatus()); + } + + if (!status->ok()) { + return nullptr; + } + + // We can't use std::make_unique here because of its interaction with a + // private constructor: https://abseil.io/tips/134 + return std::unique_ptr(new SavedModelAPI(saved_model)); +} + +inline ConcreteFunction* SavedModelAPI::GetConcreteFunction( + const std::string& function_path, Status* status) { + TF_ConcreteFunction* function = TF_GetSavedModelConcreteFunction( + saved_model_.get(), function_path.c_str(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + return ConcreteFunction::wrap(function); +} + +inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction( + const std::string& function_path, Status* status) { + TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction( + saved_model_.get(), function_path.c_str(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + return ConcreteFunction::wrap(function); +} + +inline std::vector SavedModelAPI::ListFunctions() { + ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get())); + return list.ToVector(); +} + +} // namespace cc +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ diff --git a/tensorflow/cc/saved_model/experimental/tests/BUILD b/tensorflow/cc/saved_model/experimental/tests/BUILD new file mode 100644 index 00000000000..f24bcfdee2a --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/tests/BUILD @@ -0,0 +1,22 @@ +# Tests for the C++ header-only SavedModelAPI. +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +tf_cc_test( + name = "saved_model_api_test", + srcs = [ + "saved_model_api_test.cc", + ], + deps = [ + "//tensorflow/cc/experimental/base/public:runtime", + "//tensorflow/cc/experimental/base/public:runtime_builder", + "//tensorflow/cc/experimental/base/public:status", + "//tensorflow/cc/saved_model/experimental/public:saved_model_api", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc new file mode 100644 index 00000000000..155c58604bf --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h" + +#include +#include +#include + +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/runtime_builder.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +constexpr char kTestData[] = "cc/saved_model/testdata"; + +std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { + return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(), + kTestData, saved_model_dir); +} + +// This value parameterized test allows us to test both TFRT +// and non TFRT runtimes. +// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests +class CPPSavedModelAPITest : public ::testing::TestWithParam {}; + +TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) { + cc::Status status; + cc::RuntimeBuilder builder; + bool use_tfrt = GetParam(); + if (use_tfrt) { + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + builder.SetUseTFRT(use_tfrt); + std::unique_ptr runtime = builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); + std::unordered_set tags = {"serve"}; + std::unique_ptr model = + cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags); + + // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. + // That unblocks writing other tests that require a TF_SavedModel*, + // like loading a ConcreteFunction. This test at least checks that the + // C API builds and can be minimally run. + EXPECT_EQ(status.code(), TF_UNIMPLEMENTED); +} + +TEST_P(CPPSavedModelAPITest, LoadsSavedModel) { + cc::Status status; + cc::RuntimeBuilder builder; + bool use_tfrt = GetParam(); + if (use_tfrt) { + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + builder.SetUseTFRT(use_tfrt); + std::unique_ptr runtime = builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); + std::unique_ptr model = + cc::SavedModelAPI::Load(model_dir, *runtime, &status); + + // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. + // That unblocks writing other tests that require a TF_SavedModel*, + // like loading a ConcreteFunction. This test at least checks that the + // C API builds and can be minimally run. + EXPECT_EQ(status.code(), TF_UNIMPLEMENTED); +} + +INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests, + CPPSavedModelAPITest, ::testing::Bool()); + +} // namespace + +} // namespace tensorflow