Refcount the EagerContext.
Currently, the only place we increase the refcount is when creating remote tensor handle destructors (each destructor releases its reference when it runs). Unfortunately, since tensor handle and context lifetimes are tied to Python objects, it did not seem possible to enforce that the context is deleted only after all tensor handles have been deleted. This also means that if you have live remote tensors, you will not be able to delete the EagerContext. This is mostly fine, since we don't support deletion of eager contexts in any case, and we only run into this issue during Python teardown. PiperOrigin-RevId: 246178464
This commit is contained in:
parent
750ceb3938
commit
eee7e2f0c5
@ -677,7 +677,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
|||||||
|
|
||||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||||
|
|
||||||
LOG_AND_RETURN_IF_ERROR(ctx->context.StoreCollectiveOpsServer(
|
LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
|
||||||
std::move(server), grpc_server->worker_env()->device_mgr,
|
std::move(server), grpc_server->worker_env()->device_mgr,
|
||||||
grpc_server->worker_env()->collective_executor_mgr));
|
grpc_server->worker_env()->collective_executor_mgr));
|
||||||
|
|
||||||
|
@ -228,7 +228,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
|
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
|
||||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||||
remote_workers, rendezvous_id, keep_alive_secs, server_def,
|
remote_workers, rendezvous_id, keep_alive_secs, server_def,
|
||||||
remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));
|
remote_eager_workers.get(), ctx->context->Async(), &remote_contexts));
|
||||||
|
|
||||||
tensorflow::RemoteRendezvous* r =
|
tensorflow::RemoteRendezvous* r =
|
||||||
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
|
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
|
||||||
@ -247,7 +247,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||||||
|
|
||||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||||
|
|
||||||
return ctx->context.InitializeRemote(
|
return ctx->context->InitializeRemote(
|
||||||
std::move(server), std::move(remote_eager_workers),
|
std::move(server), std::move(remote_eager_workers),
|
||||||
std::move(remote_device_mgr), remote_contexts, r, device_mgr,
|
std::move(remote_device_mgr), remote_contexts, r, device_mgr,
|
||||||
keep_alive_secs);
|
keep_alive_secs);
|
||||||
@ -350,7 +350,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
|
|||||||
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
|
||||||
unsigned char enable,
|
unsigned char enable,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
status->status = ctx->context.SetAsyncForThread(enable);
|
status->status = ctx->context->SetAsyncForThread(enable);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
||||||
@ -390,15 +390,15 @@ void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
|
|||||||
|
|
||||||
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||||
TF_DeviceList* list = new TF_DeviceList;
|
TF_DeviceList* list = new TF_DeviceList;
|
||||||
ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response);
|
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
|
||||||
if (ctx->context.remote_device_mgr()) {
|
if (ctx->context->remote_device_mgr()) {
|
||||||
ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response);
|
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
|
||||||
}
|
}
|
||||||
return list;
|
return list;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) {
|
void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) {
|
||||||
status->status = ctx->context.ClearCaches();
|
status->status = ctx->context->ClearCaches();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set server_def on the context, possibly updating it.
|
// Set server_def on the context, possibly updating it.
|
||||||
@ -424,7 +424,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
|||||||
|
|
||||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||||
ctx->context.SetThreadLocalDevicePlacementPolicy(
|
ctx->context->SetThreadLocalDevicePlacementPolicy(
|
||||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -434,19 +434,19 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
|||||||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||||
TFE_Context* ctx) {
|
TFE_Context* ctx) {
|
||||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||||
ctx->context.GetDevicePlacementPolicy());
|
ctx->context->GetDevicePlacementPolicy());
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
|
void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
|
||||||
status->status = ctx->context.AsyncWait();
|
status->status = ctx->context->AsyncWait();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
|
void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
|
||||||
status->status = ctx->context.GetStatus();
|
status->status = ctx->context->GetStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextAsyncClearError(TFE_Context* ctx) {
|
void TFE_ContextAsyncClearError(TFE_Context* ctx) {
|
||||||
ctx->context.ClearAsyncError();
|
ctx->context->ClearAsyncError();
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||||
@ -606,7 +606,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
|||||||
return new TFE_Op(ctx, name, false, types,
|
return new TFE_Op(ctx, name, false, types,
|
||||||
new TFE_OpInferenceContext(op_def));
|
new TFE_OpInferenceContext(op_def));
|
||||||
}
|
}
|
||||||
if (!ctx->context.FindFunctionByName(name)) {
|
if (!ctx->context->FindFunctionByName(name)) {
|
||||||
status->status = tensorflow::errors::NotFound(
|
status->status = tensorflow::errors::NotFound(
|
||||||
"'", name,
|
"'", name,
|
||||||
"' is neither a type of a primitive operation nor a name "
|
"' is neither a type of a primitive operation nor a name "
|
||||||
@ -904,7 +904,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
|||||||
const char* device_name,
|
const char* device_name,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::TensorHandle* handle;
|
tensorflow::TensorHandle* handle;
|
||||||
status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context,
|
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
|
||||||
device_name, &handle);
|
device_name, &handle);
|
||||||
if (status->status.ok()) {
|
if (status->status.ok()) {
|
||||||
return new TFE_TensorHandle(handle);
|
return new TFE_TensorHandle(handle);
|
||||||
@ -921,26 +921,26 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
|
|||||||
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
|
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
status->status = ctx->context.AddFunctionDef(function_def);
|
status->status = ctx->context->AddFunctionDef(function_def);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
|
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
status->status = ctx->context.AddFunctionDef(function->fdef);
|
status->status = ctx->context->AddFunctionDef(function->fdef);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||||
return ctx->context.FindFunctionDef(name) != nullptr;
|
return ctx->context->FindFunctionDef(name) != nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||||
ctx->context.SetShouldStoreGraphs(true);
|
ctx->context->SetShouldStoreGraphs(true);
|
||||||
ctx->context.SetShouldStoreStepStats(true);
|
ctx->context->SetShouldStoreStepStats(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||||
ctx->context.SetShouldStoreGraphs(false);
|
ctx->context->SetShouldStoreGraphs(false);
|
||||||
ctx->context.SetShouldStoreStepStats(false);
|
ctx->context->SetShouldStoreStepStats(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
@ -969,9 +969,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
|||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
TFE_ContextAsyncWait(ctx, status);
|
TFE_ContextAsyncWait(ctx, status);
|
||||||
if (!status->status.ok()) return;
|
if (!status->status.ok()) return;
|
||||||
tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
|
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
|
||||||
status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
|
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
|
||||||
ctx->context.ClearRunMetadata();
|
ctx->context->ClearRunMetadata();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -987,9 +987,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
|
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
|
||||||
|
|
||||||
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
|
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||||
|
@ -63,7 +63,7 @@ TFE_ProfilerContext* TFE_NewProfilerContext() {
|
|||||||
|
|
||||||
void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context,
|
void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context,
|
||||||
TFE_Context* eager_context) {
|
TFE_Context* eager_context) {
|
||||||
profiler_context->profiler_context.eager_context = &eager_context->context;
|
profiler_context->profiler_context.eager_context = eager_context->context;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) {
|
void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) {
|
||||||
@ -77,11 +77,11 @@ void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||||
ctx->context.SetShouldStoreGraphs(true);
|
ctx->context->SetShouldStoreGraphs(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||||
ctx->context.SetShouldStoreGraphs(false);
|
ctx->context->SetShouldStoreGraphs(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TFE_ProfilerClientStartTracing(const char* service_addr,
|
bool TFE_ProfilerClientStartTracing(const char* service_addr,
|
||||||
|
@ -62,13 +62,16 @@ struct TFE_Context {
|
|||||||
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
|
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
|
||||||
tensorflow::Rendezvous* rendezvous,
|
tensorflow::Rendezvous* rendezvous,
|
||||||
const tensorflow::CustomKernelCreator* custom_kernel_creator)
|
const tensorflow::CustomKernelCreator* custom_kernel_creator)
|
||||||
: context(opts,
|
: context(new tensorflow::EagerContext(
|
||||||
|
opts,
|
||||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||||
default_policy),
|
default_policy),
|
||||||
async, device_mgr, device_mgr_owned, rendezvous,
|
async, device_mgr, device_mgr_owned, rendezvous,
|
||||||
custom_kernel_creator) {}
|
custom_kernel_creator)) {}
|
||||||
|
|
||||||
tensorflow::EagerContext context;
|
~TFE_Context() { context->Unref(); }
|
||||||
|
|
||||||
|
tensorflow::EagerContext* context;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TFE_TensorHandle {
|
struct TFE_TensorHandle {
|
||||||
@ -108,7 +111,7 @@ struct TFE_Op {
|
|||||||
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
|
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
|
||||||
const tensorflow::AttrTypeMap* t,
|
const tensorflow::AttrTypeMap* t,
|
||||||
TFE_OpInferenceContext* inference_ctx)
|
TFE_OpInferenceContext* inference_ctx)
|
||||||
: operation(&ctx->context, op, is_function, t),
|
: operation(ctx->context, op, is_function, t),
|
||||||
inference_ctx(inference_ctx) {}
|
inference_ctx(inference_ctx) {}
|
||||||
|
|
||||||
tensorflow::EagerOperation operation;
|
tensorflow::EagerOperation operation;
|
||||||
|
@ -14,10 +14,11 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
|
||||||
|
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||||
#include "tensorflow/core/framework/function.pb.h"
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
@ -297,6 +298,61 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
|||||||
TestRemoteExecuteSilentCopies(true);
|
TestRemoteExecuteSilentCopies(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
|
||||||
|
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||||
|
|
||||||
|
// This server def has the task index set to 0.
|
||||||
|
string serialized = server_def.SerializeAsString();
|
||||||
|
|
||||||
|
server_def.set_task_index(1);
|
||||||
|
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server->Start().ok());
|
||||||
|
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||||
|
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
|
||||||
|
TFE_DEVICE_PLACEMENT_EXPLICIT);
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||||
|
const char remote_device_name[] =
|
||||||
|
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||||
|
auto* h0_task1 =
|
||||||
|
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
TFE_DeleteTensorHandle(h0_task0);
|
||||||
|
|
||||||
|
TFE_ContextAsyncWait(ctx, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
|
||||||
|
// Delete tensors after context is deleted.
|
||||||
|
TFE_DeleteTensorHandle(h0_task1);
|
||||||
|
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
|
||||||
|
// TODO(nareshmodi): Figure out how to correctly shut the server down.
|
||||||
|
worker_server.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) {
|
||||||
|
TestRemoteExecuteDeleteTensorAfterContext(false);
|
||||||
|
}
|
||||||
|
TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) {
|
||||||
|
TestRemoteExecuteDeleteTensorAfterContext(true);
|
||||||
|
}
|
||||||
|
|
||||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||||
const std::vector<float>& expected_values) {
|
const std::vector<float>& expected_values) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
@ -76,7 +76,7 @@ class RunMetadataListener {
|
|||||||
virtual void BeforeClearRunMetadata() = 0;
|
virtual void BeforeClearRunMetadata() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class EagerContext {
|
class EagerContext : public core::RefCounted {
|
||||||
public:
|
public:
|
||||||
// TODO: remove this constructor once we migrate all callers to the next one.
|
// TODO: remove this constructor once we migrate all callers to the next one.
|
||||||
EagerContext(const SessionOptions& opts,
|
EagerContext(const SessionOptions& opts,
|
||||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/lib/random/random.h"
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
@ -683,7 +684,9 @@ Status EagerLocalExecute(EagerOperation* op,
|
|||||||
std::function<void()> GetRemoteTensorDestructor(
|
std::function<void()> GetRemoteTensorDestructor(
|
||||||
EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
|
EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
|
||||||
uint64 op_id, int output_num) {
|
uint64 op_id, int output_num) {
|
||||||
|
ctx->Ref();
|
||||||
return [ctx, eager_client, context_id, op_id, output_num]() {
|
return [ctx, eager_client, context_id, op_id, output_num]() {
|
||||||
|
auto cleanup = gtl::MakeCleanup([ctx]() { ctx->Unref(); });
|
||||||
if (!ctx->HasActiveRemoteContext(context_id)) {
|
if (!ctx->HasActiveRemoteContext(context_id)) {
|
||||||
// This means that this tensor was pointing to a remote device, which
|
// This means that this tensor was pointing to a remote device, which
|
||||||
// has been changed out from under us. Simply return since there is
|
// has been changed out from under us. Simply return since there is
|
||||||
|
@ -104,10 +104,10 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
|
|||||||
// Initialize remote tensor communication based on worker session.
|
// Initialize remote tensor communication based on worker session.
|
||||||
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext(
|
tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
|
||||||
SessionOptions(),
|
SessionOptions(),
|
||||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||||
request->async(), device_mgr, false, r, nullptr));
|
request->async(), device_mgr, false, r, nullptr);
|
||||||
|
|
||||||
std::vector<DeviceAttributes> device_attributes;
|
std::vector<DeviceAttributes> device_attributes;
|
||||||
device_mgr->ListDeviceAttributes(&device_attributes);
|
device_mgr->ListDeviceAttributes(&device_attributes);
|
||||||
@ -122,9 +122,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
|
|||||||
do {
|
do {
|
||||||
context_id = random::New64();
|
context_id = random::New64();
|
||||||
} while (contexts_.find(context_id) != contexts_.end());
|
} while (contexts_.find(context_id) != contexts_.end());
|
||||||
contexts_.emplace(
|
contexts_.emplace(context_id,
|
||||||
context_id,
|
new ServerContext(ctx, request->keep_alive_secs(), env_));
|
||||||
new ServerContext(std::move(ctx), request->keep_alive_secs(), env_));
|
|
||||||
}
|
}
|
||||||
response->set_context_id(context_id);
|
response->set_context_id(context_id);
|
||||||
|
|
||||||
|
@ -104,9 +104,9 @@ class EagerServiceImpl {
|
|||||||
// and the EagerContext).
|
// and the EagerContext).
|
||||||
class ServerContext : public core::RefCounted {
|
class ServerContext : public core::RefCounted {
|
||||||
public:
|
public:
|
||||||
explicit ServerContext(std::unique_ptr<tensorflow::EagerContext> ctx,
|
explicit ServerContext(tensorflow::EagerContext* ctx,
|
||||||
int64 destroy_after_secs, const WorkerEnv* env)
|
int64 destroy_after_secs, const WorkerEnv* env)
|
||||||
: ctx_(std::move(ctx)), env_(env) {
|
: ctx_(ctx), env_(env) {
|
||||||
destroy_after_micros_ =
|
destroy_after_micros_ =
|
||||||
destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros;
|
destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros;
|
||||||
RecordAccess();
|
RecordAccess();
|
||||||
@ -115,9 +115,11 @@ class EagerServiceImpl {
|
|||||||
for (const auto& entry : tensors_) {
|
for (const auto& entry : tensors_) {
|
||||||
entry.second->Unref();
|
entry.second->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::EagerContext* Context() const { return ctx_.get(); }
|
tensorflow::EagerContext* Context() const { return ctx_; }
|
||||||
|
|
||||||
void AddOperationOutputs(
|
void AddOperationOutputs(
|
||||||
const gtl::ArraySlice<tensorflow::TensorHandle*>& handles,
|
const gtl::ArraySlice<tensorflow::TensorHandle*>& handles,
|
||||||
@ -179,7 +181,7 @@ class EagerServiceImpl {
|
|||||||
RemoteTensorHandleInternalEquals>;
|
RemoteTensorHandleInternalEquals>;
|
||||||
|
|
||||||
// The context for this execution.
|
// The context for this execution.
|
||||||
std::unique_ptr<tensorflow::EagerContext> ctx_;
|
tensorflow::EagerContext* ctx_;
|
||||||
|
|
||||||
// The state related to the context for this execution.
|
// The state related to the context for this execution.
|
||||||
mutex tensors_mu_;
|
mutex tensors_mu_;
|
||||||
|
@ -22,7 +22,9 @@ namespace tflite {
|
|||||||
namespace flex {
|
namespace flex {
|
||||||
DelegateData::DelegateData() {}
|
DelegateData::DelegateData() {}
|
||||||
|
|
||||||
DelegateData::~DelegateData() {}
|
DelegateData::~DelegateData() {
|
||||||
|
if (eager_context_) eager_context_->Unref();
|
||||||
|
}
|
||||||
|
|
||||||
tensorflow::Status DelegateData::Prepare(
|
tensorflow::Status DelegateData::Prepare(
|
||||||
const tensorflow::SessionOptions& session_options) {
|
const tensorflow::SessionOptions& session_options) {
|
||||||
@ -40,10 +42,10 @@ tensorflow::Status DelegateData::Prepare(
|
|||||||
// Note that Rendezvous is ref-counted so it will be automatically deleted.
|
// Note that Rendezvous is ref-counted so it will be automatically deleted.
|
||||||
tensorflow::Rendezvous* rendezvous =
|
tensorflow::Rendezvous* rendezvous =
|
||||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||||
eager_context_.reset(new tensorflow::EagerContext(
|
eager_context_ = new tensorflow::EagerContext(
|
||||||
session_options,
|
session_options,
|
||||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||||
/*async=*/false, std::move(device_mgr), rendezvous));
|
/*async=*/false, std::move(device_mgr), rendezvous);
|
||||||
return tensorflow::Status();
|
return tensorflow::Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ class DelegateData {
|
|||||||
|
|
||||||
// The EagerContext that is required for execution of Flex Ops.
|
// The EagerContext that is required for execution of Flex Ops.
|
||||||
// Note: The context is lazily created after the first call to |Prepare()|.
|
// Note: The context is lazily created after the first call to |Prepare()|.
|
||||||
tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
|
tensorflow::EagerContext* GetEagerContext() { return eager_context_; }
|
||||||
|
|
||||||
// Map from TF Lite tensor index to TensorFlow tensor for a given context.
|
// Map from TF Lite tensor index to TensorFlow tensor for a given context.
|
||||||
BufferMap* GetBufferMap(const TfLiteContext* context) {
|
BufferMap* GetBufferMap(const TfLiteContext* context) {
|
||||||
@ -48,7 +48,7 @@ class DelegateData {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// Will be null until Prepare() is called and completes successfully.
|
// Will be null until Prepare() is called and completes successfully.
|
||||||
std::unique_ptr<tensorflow::EagerContext> eager_context_;
|
tensorflow::EagerContext* eager_context_ = nullptr;
|
||||||
// TODO(b/112439500): Clean up stale BufferMap instances after adding the
|
// TODO(b/112439500): Clean up stale BufferMap instances after adding the
|
||||||
// necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.
|
// necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.
|
||||||
std::unordered_map<const TfLiteContext*, BufferMap> buffer_map_;
|
std::unordered_map<const TfLiteContext*, BufferMap> buffer_map_;
|
||||||
|
Loading…
Reference in New Issue
Block a user