Removing LoadSavedModelAPI from AbstractContextInterface. SavedModel is conceptually layered on top of the runtime, as it uses the context to eagerly execute ops (like reloading resources, restoring tensors, etc). This fixes what otherwise would be a circular dependency once we implement SavedModelAPI (context -> SavedModelAPI -> context).
PiperOrigin-RevId: 316169246 Change-Id: I2349803c185f771a65e836273f326bd81622c42b
This commit is contained in:
parent
a4c8a190f8
commit
08420ed0b6
tensorflow
c
eager
experimental/saved_model
core/common_runtime/eager
@ -202,7 +202,6 @@ cc_library(
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -21,8 +21,8 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
@ -84,11 +84,10 @@ class AbstractContextInterface {
|
||||
// Create an operation to perform op execution
|
||||
virtual AbstractOperationInterface* CreateOperation() = 0;
|
||||
|
||||
// Load a SavedModelAPI object from the given directory and tags
|
||||
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
tensorflow::Status* status) = 0;
|
||||
// Returns whether the runtime is backed by TFRT or the legacy TF Eager
|
||||
// Runtime. This is necessary to decouple runtime-dependent
|
||||
// code that is layered on top of the runtime.
|
||||
virtual bool UsesTFRT() = 0;
|
||||
|
||||
// List attributes of available devices
|
||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||
|
@ -57,6 +57,7 @@ cc_library(
|
||||
":concrete_function",
|
||||
":saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -51,7 +52,7 @@ std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
|
||||
Status TFSavedModelAPIImpl::Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out) {
|
||||
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out) {
|
||||
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
|
||||
return errors::Unimplemented(
|
||||
"TFSavedModelAPIImpl loading is unimplemented currently");
|
||||
|
@ -23,14 +23,13 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TFSavedModelAPIImpl : public SavedModelAPI {
|
||||
public:
|
||||
TFSavedModelAPIImpl() = default;
|
||||
|
||||
Status GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) override;
|
||||
|
||||
@ -40,13 +39,14 @@ class TFSavedModelAPIImpl : public SavedModelAPI {
|
||||
static Status Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out);
|
||||
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out);
|
||||
|
||||
std::vector<ConcreteFunction*> ListFunctions() override;
|
||||
|
||||
~TFSavedModelAPIImpl() override = default;
|
||||
|
||||
private:
|
||||
TFSavedModelAPIImpl() = default;
|
||||
std::vector<ConcreteFunction> functions_;
|
||||
};
|
||||
|
||||
|
@ -144,7 +144,9 @@ cc_library(
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
@ -22,11 +22,15 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
extern "C" {
|
||||
@ -34,10 +38,21 @@ extern "C" {
|
||||
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
std::string saved_model_dir(dirname);
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result;
|
||||
|
||||
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"TFRT SavedModel implementation will be added in the future");
|
||||
} else {
|
||||
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
|
||||
status->status = tensorflow::TFSavedModelAPIImpl::Load(
|
||||
dirname, absl::nullopt,
|
||||
tensorflow::down_cast<tensorflow::EagerContext*>(
|
||||
tensorflow::unwrap(ctx)),
|
||||
&saved_model);
|
||||
result = std::move(saved_model);
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
|
||||
&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -54,9 +69,20 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
tagset.insert(std::string(tags[i]));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
|
||||
&status->status);
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result;
|
||||
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"TFRT SavedModel implementation will be added in the future");
|
||||
} else {
|
||||
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
|
||||
status->status = tensorflow::TFSavedModelAPIImpl::Load(
|
||||
dirname, tagset,
|
||||
tensorflow::down_cast<tensorflow::EagerContext*>(
|
||||
tensorflow::unwrap(ctx)),
|
||||
&saved_model);
|
||||
result = std::move(saved_model);
|
||||
}
|
||||
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -77,8 +77,6 @@ tf_cuda_library(
|
||||
"//tensorflow/c/eager:context_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
|
||||
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
] + select({
|
||||
|
@ -32,8 +32,6 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
|
||||
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
|
||||
#include "tensorflow/core/common_runtime/colocation_graph.h"
|
||||
@ -192,19 +190,6 @@ AbstractTensorInterface* EagerContext::CreateTensor(
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<SavedModelAPI> EagerContext::LoadSavedModelAPI(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
tensorflow::Status* status) {
|
||||
auto result = std::make_unique<TFSavedModelAPIImpl>();
|
||||
auto load_status = TFSavedModelAPIImpl::Load(directory, tags, result.get());
|
||||
if (!load_status.ok()) {
|
||||
status->Update(load_status);
|
||||
result.reset();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
|
||||
const ConfigProto* config, int graph_def_version,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
@ -599,6 +584,8 @@ std::vector<const FunctionDef*> EagerContext::ListRegisteredFunctions() {
|
||||
|
||||
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
|
||||
|
||||
bool EagerContext::UsesTFRT() { return false; }
|
||||
|
||||
void EagerContext::ListDevices(
|
||||
std::vector<tensorflow::DeviceAttributes>* devices) {
|
||||
local_device_mgr()->ListDeviceAttributes(devices);
|
||||
|
@ -34,7 +34,6 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
@ -186,14 +185,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
Status* status) override;
|
||||
AbstractOperationInterface* CreateOperation() override;
|
||||
|
||||
// Loads a SavedModelAPI from `directory`, with a metagraphdef fitting
|
||||
// the optional "tags". On success status->ok() will be true, and the
|
||||
// returned pointer is non-null. On failure, `status` will be set to
|
||||
// an appropriate error, and nullptr is returned.
|
||||
std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
tensorflow::Status* status) override;
|
||||
bool UsesTFRT() override;
|
||||
|
||||
void ListDevices(std::vector<DeviceAttributes>* devices) override;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user