Add AbstractContextInterface to abstract runtime

We want to be able to switch between different runtime backends. In
order to do that we introduce the AbstractContextInterface for different
runtimes to implement. This interface handles the creation of key objects
such as Tensors, TensorHandles, and Operations, as well as device management.

PiperOrigin-RevId: 303251247
Change-Id: Ib37c5c7bb3c49d418ad6d6d71fb70f6c2063f569
This commit is contained in:
Gaurav Jain 2020-03-26 20:41:15 -07:00 committed by TensorFlower Gardener
parent 2c3284560d
commit a31ffb72e1
14 changed files with 556 additions and 106 deletions

View File

@ -703,7 +703,8 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
} while (0); } while (0);
// New server created for new server_def. Unused if updating server_def. // New server created for new server_def. Unused if updating server_def.
tensorflow::EagerContext* context = ctx->context; tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::GrpcServer* grpc_server = tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer()); dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) { if (grpc_server == nullptr) {

View File

@ -28,6 +28,8 @@ tf_cuda_library(
"c_api_debug.cc", "c_api_debug.cc",
"c_api_experimental.h", "c_api_experimental.h",
"c_api_internal.h", "c_api_internal.h",
"context_interface.cc",
"context_interface.h",
"operation_interface.cc", "operation_interface.cc",
"operation_interface.h", "operation_interface.h",
"tensor_handle_interface.h", "tensor_handle_interface.h",
@ -95,6 +97,7 @@ filegroup(
srcs = [ srcs = [
"c_api_experimental.h", "c_api_experimental.h",
"c_api_internal.h", "c_api_internal.h",
"context_interface.h",
"dlpack.h", "dlpack.h",
"operation_interface.h", "operation_interface.h",
"tensor_handle_interface.h", "tensor_handle_interface.h",
@ -109,6 +112,7 @@ tf_cuda_library(
name = "c_api_internal", name = "c_api_internal",
srcs = [ srcs = [
"c_api_experimental.h", "c_api_experimental.h",
"context_interface.h",
"operation_interface.h", "operation_interface.h",
"tensor_handle_interface.h", "tensor_handle_interface.h",
], ],

View File

@ -305,7 +305,9 @@ tensorflow::Status CreateRemoteContexts(
server_def.default_session_config()); server_def.default_session_config());
std::vector<bool> filtered_device_mask; std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers( tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(), remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask); &filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(), DCHECK_EQ(filtered_device_mask.size(),
@ -388,7 +390,9 @@ tensorflow::Status UpdateRemoteContexts(
} }
std::vector<bool> filtered_device_mask; std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers( tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(), remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask); &filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count); DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
@ -467,7 +471,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// New server created for new server_def. Unused if updating server_def. // New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server; std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::EagerContext* context = ctx->context; tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::GrpcServer* grpc_server; tensorflow::GrpcServer* grpc_server;
if (reset_context) { if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server)); LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@ -696,14 +701,16 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r = tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get()); new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context{new tensorflow::EagerContext( return new TFE_Context{std::make_unique<tensorflow::ContextInterface>(
new tensorflow::EagerContext(
opts->session_options.options, opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>( static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy), opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy), static_cast<tensorflow::ContextMirroringPolicy>(
opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r, /*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator())}; tensorflow::GetDefaultCustomKernelCreator()))};
} }
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@ -714,20 +721,24 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
tensorflow::Rendezvous* r = tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr); new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context{new tensorflow::EagerContext( return new TFE_Context{std::make_unique<tensorflow::ContextInterface>(
new tensorflow::EagerContext(
opts->session_options.options, opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>( static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy), opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy), static_cast<tensorflow::ContextMirroringPolicy>(
opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr, opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r, /*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator())}; tensorflow::GetDefaultCustomKernelCreator()))};
} }
void TFE_DeleteContext(TFE_Context* ctx) { void TFE_DeleteContext(TFE_Context* ctx) {
// context->RefCountIsOne() should be true here. // context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting. // TODO(iga): Remove EagerContext refcounting.
ctx->context->Unref(); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->Unref();
delete ctx; delete ctx;
} }
@ -739,7 +750,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
} }
void TFE_ContextClearCaches(TFE_Context* ctx) { void TFE_ContextClearCaches(TFE_Context* ctx) {
ctx->context->ClearCachesAndThreadExecutors(); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->ClearCachesAndThreadExecutors();
} }
// Set server_def on the context, possibly updating it. // Set server_def on the context, possibly updating it.
@ -769,8 +782,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
device_filters[i] = tdf.second.device_filters(i); device_filters[i] = tdf.second.device_filters(i);
} }
const string remote_worker = remote_prefix + std::to_string(task_index); const string remote_worker = remote_prefix + std::to_string(task_index);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = status->status =
ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters); context->SetRemoteDeviceFilters(remote_worker, device_filters);
} }
} }
} }
@ -789,11 +804,13 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
"TFE_ContextSetServerDef not supported on mobile"); "TFE_ContextSetServerDef not supported on mobile");
#else // !defined(IS_MOBILE_PLATFORM) #else // !defined(IS_MOBILE_PLATFORM)
tensorflow::ServerDef server_def; tensorflow::ServerDef server_def;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
if (!server_def.ParseFromArray(proto, proto_len)) { if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer"); "Invalid tensorflow.ServerDef protocol buffer");
return; return;
} else if (ctx->context->GetContextId() == } else if (context->GetContextId() ==
tensorflow::EagerContext::kInvalidContextId) { tensorflow::EagerContext::kInvalidContextId) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
"Trying to update a context with invalid context id."); "Trying to update a context with invalid context id.");
@ -817,7 +834,8 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
"TFE_ContextSetServerDef not supported on mobile"); "TFE_ContextSetServerDef not supported on mobile");
return false; return false;
#else // !defined(IS_MOBILE_PLATFORM) #else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context = ctx->context; tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::GrpcServer* grpc_server = tensorflow::GrpcServer* grpc_server =
static_cast<tensorflow::GrpcServer*>(context->GetServer()); static_cast<tensorflow::GrpcServer*>(context->GetServer());
@ -872,13 +890,17 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
#if defined(IS_MOBILE_PLATFORM) #if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK(); status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM) #else // !defined(IS_MOBILE_PLATFORM)
status->status = ctx->context->SyncExecutors(); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->SyncExecutors();
#endif // !IS_MOBILE_PLATFORM #endif // !IS_MOBILE_PLATFORM
} }
void TFE_ContextSetThreadLocalDevicePlacementPolicy( void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context->SetThreadLocalDevicePlacementPolicy( tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy)); static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
} }
@ -887,8 +909,10 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
// safe to call this function from the async EagerExecutor threads. // safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) { TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
return static_cast<TFE_ContextDevicePlacementPolicy>( return static_cast<TFE_ContextDevicePlacementPolicy>(
ctx->context->GetDevicePlacementPolicy()); context->GetDevicePlacementPolicy());
} }
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
@ -1178,7 +1202,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg), void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status) { void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device = nullptr; tensorflow::Device* device = nullptr;
tensorflow::EagerContext* context = ctx->context; tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->FindDeviceFromName(device_name, &device); status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr; tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) { if (!status->status.ok()) {
@ -1248,8 +1273,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) { TF_Status* status) {
std::unique_ptr<TFE_Op> new_op( std::unique_ptr<TFE_Op> new_op(new TFE_Op{ctx->context->CreateOperation()});
new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
status->status = new_op->operation->Reset(op_or_function_name, nullptr); status->status = new_op->operation->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) { if (!status->status.ok()) {
new_op.reset(); new_op.reset();
@ -1497,7 +1521,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TF_Status* status) { TF_Status* status) {
tensorflow::TensorHandle* handle = nullptr; tensorflow::TensorHandle* handle = nullptr;
tensorflow::Device* device; tensorflow::Device* device;
tensorflow::EagerContext* context = ctx->context; tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->FindDeviceFromName(device_name, &device); status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) { if (!status->status.ok()) {
tensorflow::CustomDevice* dev; tensorflow::CustomDevice* dev;
@ -1556,29 +1581,41 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return; return;
} }
status->status = ctx->context->AddFunctionDef(function_def); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->AddFunctionDef(function_def);
} }
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) { TF_Status* status) {
status->status = ctx->context->AddFunctionDef(function->fdef); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->AddFunctionDef(function->fdef);
} }
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) { TF_Status* status) {
status->status = ctx->context->RemoveFunction(name); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->RemoveFunction(name);
} }
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
return ctx->context->FindFunctionDef(name) != nullptr; tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
return context->FindFunctionDef(name) != nullptr;
} }
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetShouldStoreGraphs(true);
} }
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetShouldStoreGraphs(false);
} }
} // extern "C" } // extern "C"
@ -1590,7 +1627,8 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) { TF_Status* status) {
tensorflow::EagerContext* context = ctx->context; tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = context->Executor().WaitForAllPendingNodes(); status->status = context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return; if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*context->MetadataMu()); tensorflow::mutex_lock ml(*context->MetadataMu());
@ -1611,9 +1649,17 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
} }
} // namespace } // namespace
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); } void TFE_ContextStartStep(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->StartStep();
}
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); } void TFE_ContextEndStep(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->EndStep();
}
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) { void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>( auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
@ -1793,6 +1839,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
TF_Status* status) { TF_Status* status) {
auto custom_device = auto custom_device =
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name); std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
status->status = status->status =
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device)); context->RegisterCustomDevice(device_name, std::move(custom_device));
} }

View File

@ -40,11 +40,15 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
} }
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetShouldStoreGraphs(true);
} }
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetShouldStoreGraphs(false);
} }
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell, void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
@ -474,7 +478,9 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
void TFE_ContextSetThreadLocalMirroringPolicy( void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) { TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
ctx->context->SetThreadLocalMirroringPolicy( tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetThreadLocalMirroringPolicy(
static_cast<tensorflow::ContextMirroringPolicy>(policy)); static_cast<tensorflow::ContextMirroringPolicy>(policy));
} }
@ -483,8 +489,9 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
// safe to call this function from the async EagerExecutor threads. // safe to call this function from the async EagerExecutor threads.
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context* ctx) { TFE_Context* ctx) {
return static_cast<TFE_ContextMirroringPolicy>( tensorflow::EagerContext* context =
ctx->context->GetMirroringPolicy()); tensorflow::ContextFromInterface(ctx->context);
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
} }
void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options, void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
@ -537,16 +544,22 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
} }
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
ctx->context->SetExecutorForThread(executor->executor()); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
context->SetExecutorForThread(executor->executor());
} }
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(&ctx->context->Executor()); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
return new TFE_Executor(&context->Executor());
} }
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
auto address_space = tensorflow::DeviceNameUtils::AddressSpace( auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
ctx->context->HostCPU()->parsed_name()); context->HostCPU()->parsed_name());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length()); void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0); str.copy(static_cast<char*>(data), str.length(), 0);
@ -565,7 +578,9 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) { TF_Buffer* buf, TF_Status* status) {
auto* function_def = ctx->context->FindFunctionDef(function_name); tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
auto* function_def = context->FindFunctionDef(function_name);
if (function_def == nullptr) { if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound( status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name); "Unable to find FunctionDef with name: ", function_name);

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
@ -62,7 +63,7 @@ struct TFE_ContextOptions {
}; };
struct TFE_Context { struct TFE_Context {
tensorflow::EagerContext* context; std::unique_ptr<tensorflow::AbstractContextInterface> context;
}; };
struct TFE_TensorHandle { struct TFE_TensorHandle {

View File

@ -0,0 +1,150 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/platform/casts.h"
namespace tensorflow {
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Scalar(
int64 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Scalar(
uint64 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Scalar(
int32 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatScalar(
float value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleScalar(
double value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfScalar(
Eigen::half value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringScalar(
tstring value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface>
ContextInterface::CreateComplex128Scalar(complex128 value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolScalar(
bool value) {
return std::make_unique<TensorInterface>(Tensor(value));
}
// Shaped-tensor factories: each one allocates a Tensor of the dtype named
// by the method and the shape described by `dim_sizes` (one entry per
// dimension). Element contents are whatever Tensor's allocator provides —
// callers are expected to fill the buffer themselves.

std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Tensor(
    absl::Span<const int64> dim_sizes) {
  Tensor buffer(DT_INT64, TensorShape(dim_sizes));
  return std::make_unique<TensorInterface>(std::move(buffer));
}

std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Tensor(
    absl::Span<const int64> dim_sizes) {
  Tensor buffer(DT_UINT64, TensorShape(dim_sizes));
  return std::make_unique<TensorInterface>(std::move(buffer));
}

std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Tensor(
    absl::Span<const int64> dim_sizes) {
  Tensor buffer(DT_INT32, TensorShape(dim_sizes));
  return std::make_unique<TensorInterface>(std::move(buffer));
}

std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatTensor(
    absl::Span<const int64> dim_sizes) {
  Tensor buffer(DT_FLOAT, TensorShape(dim_sizes));
  return std::make_unique<TensorInterface>(std::move(buffer));
}

std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleTensor(
    absl::Span<const int64> dim_sizes) {
  Tensor buffer(DT_DOUBLE, TensorShape(dim_sizes));
  return std::make_unique<TensorInterface>(std::move(buffer));
}

std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfTensor(
    absl::Span<const int64> dim_sizes) {
  Tensor buffer(DT_HALF, TensorShape(dim_sizes));
  return std::make_unique<TensorInterface>(std::move(buffer));
}

std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringTensor(
    absl::Span<const int64> dim_sizes) {
  Tensor buffer(DT_STRING, TensorShape(dim_sizes));
  return std::make_unique<TensorInterface>(std::move(buffer));
}

std::unique_ptr<AbstractTensorInterface>
ContextInterface::CreateComplex128Tensor(absl::Span<const int64> dim_sizes) {
  Tensor buffer(DT_COMPLEX128, TensorShape(dim_sizes));
  return std::make_unique<TensorInterface>(std::move(buffer));
}

std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolTensor(
    absl::Span<const int64> dim_sizes) {
  Tensor buffer(DT_BOOL, TensorShape(dim_sizes));
  return std::make_unique<TensorInterface>(std::move(buffer));
}
// Wraps a concrete tensor in a new TensorHandle owned by this context.
//
// Takes ownership of `t` (the unique_ptr is passed by value and destroyed
// on return). On success `*h` holds the new handle wrapper; on failure the
// error status is returned and `*h` is left untouched.
Status ContextInterface::CreateLocalHandle(
    const std::unique_ptr<AbstractTensorInterface> t,
    std::unique_ptr<AbstractTensorHandleInterface>* h) {
  // down_cast is safe only because this runtime's factories (see the
  // Create*Scalar/Create*Tensor methods above) always produce
  // TensorInterface objects; a tensor from a different backend would be UB.
  Tensor tensor = tensorflow::down_cast<TensorInterface*>(t.get())->Tensor();
  tensorflow::TensorHandle* handle = nullptr;
  // Place the handle on the host CPU device; op_device is null because no
  // operation produced this tensor.
  auto status =
      TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/ctx_->HostCPU(),
                                      /*op_device=*/nullptr, ctx_, &handle);
  if (!status.ok()) {
    return status;
  }
  // NOTE(review): TensorHandleInterface presumably adopts the reference on
  // `handle` produced by CreateLocalHandle — confirm against its ctor.
  *h = std::make_unique<TensorHandleInterface>(handle);
  return status;
}
std::unique_ptr<AbstractOperationInterface>
ContextInterface::CreateOperation() {
return std::make_unique<tensorflow::OperationInterface>(ctx_);
}
// Fills `devices` with the attributes of the devices known to the
// underlying EagerContext. Pure delegation — no filtering or ordering is
// applied here.
void ContextInterface::ListDevices(
    std::vector<tensorflow::DeviceAttributes>* devices) {
  ctx_->ListDevices(devices);
}
} // namespace tensorflow

View File

@ -0,0 +1,157 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#include <memory>
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/tstring.h"
namespace tensorflow {
// Abstract interface to a context.
//
// A context is responsible for creating key objects such as Tensors,
// TensorHandles & Operations.
// Abstract interface to a context.
//
// A context is responsible for creating key objects such as Tensors,
// TensorHandles & Operations. It is the factory boundary behind which
// different runtime backends can be swapped.
class AbstractContextInterface {
 public:
  // Polymorphic base: contexts are destroyed through this interface, so
  // the destructor must be virtual. `= default` is the idiomatic spelling
  // (C++ Core Guidelines C.127) and is behaviorally identical to `{}`.
  virtual ~AbstractContextInterface() = default;

  // Scalar creation functions. Each returns a newly created scalar tensor
  // holding `value`; ownership of the wrapper passes to the caller.
  virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
      int64 value) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
      uint64 value) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
      int32 value) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
      float value) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
      double value) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
      Eigen::half value) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
      tstring value) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
      complex128 value) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
      bool value) = 0;

  // Tensor creation functions. Each returns a tensor of the dtype named by
  // the method with one dimension per entry of `dim_sizes`; element
  // contents are implementation-defined until the caller fills them.
  virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
      absl::Span<const int64> dim_sizes) = 0;
  virtual std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
      absl::Span<const int64> dim_sizes) = 0;

  // Create a handle to wrap and manage a Tensor. Consumes `t` (passed by
  // value, so callers must std::move their unique_ptr in); on success
  // `*handle` owns the result.
  virtual tensorflow::Status CreateLocalHandle(
      const std::unique_ptr<AbstractTensorInterface> t,
      std::unique_ptr<AbstractTensorHandleInterface>* handle) = 0;

  // Create an operation to perform op execution. The caller owns the
  // returned operation.
  virtual std::unique_ptr<AbstractOperationInterface> CreateOperation() = 0;

  // List attributes of available devices.
  virtual void ListDevices(
      std::vector<tensorflow::DeviceAttributes>* devices) = 0;
};
// TODO(gjn): Try to move these all to EagerContext and make it implement
// AbstractContextInterface. Currently, this is not so straightforward because
// of various BUILD file dependencies.
//
// Adapter that exposes an EagerContext through the runtime-agnostic
// AbstractContextInterface. Method implementations live in
// context_interface.cc. The wrapper does NOT own the EagerContext; the
// caller must keep `ctx` alive for the wrapper's lifetime.
class ContextInterface : public AbstractContextInterface {
public:
// Wraps `ctx` without taking ownership.
explicit ContextInterface(EagerContext* ctx) : ctx_(ctx) {}
~ContextInterface() override {}
// Scalar constructors — see AbstractContextInterface for the contract.
std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
int64 value) override;
std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
uint64 value) override;
std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
int32 value) override;
std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
float value) override;
std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
double value) override;
std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
Eigen::half value) override;
std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
tensorflow::tstring value) override;
std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
tensorflow::complex128 value) override;
std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
bool value) override;
// Tensor constructors — allocate an uninitialized tensor of the given
// element type and dimension sizes.
std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) override;
std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
absl::Span<const int64> dim_sizes) override;
// Wraps `t` in a tensor handle.
// NOTE(review): the `const std::unique_ptr<...>` by-value parameter cannot
// be moved-from; this mirrors the base-class declaration, so changing it
// means updating the interface and all implementations together — TODO
// confirm before changing.
tensorflow::Status CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t,
std::unique_ptr<AbstractTensorHandleInterface>* h) override;
std::unique_ptr<AbstractOperationInterface> CreateOperation() override;
void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices) override;
// For runtime specific APIs, provide ability to get the underlying context.
EagerContext* Context() const { return ctx_; }
private:
// Non-owning; must outlive this wrapper.
EagerContext* ctx_;
};
// Recovers the concrete EagerContext held by an abstract context.
// Precondition: `context` was created as a tensorflow::ContextInterface;
// the down_cast is unchecked in opt builds.
inline EagerContext* ContextFromInterface(
    const std::unique_ptr<AbstractContextInterface>& context) {
  auto* wrapper = down_cast<tensorflow::ContextInterface*>(context.get());
  return wrapper->Context();
}
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_

View File

@ -26,8 +26,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
OperationInterface::OperationInterface(TFE_Context* ctx) OperationInterface::OperationInterface(EagerContext* ctx) : operation_(ctx) {}
: operation_(ctx->context) {}
const string& OperationInterface::DeviceName() const { const string& OperationInterface::DeviceName() const {
absl::variant<Device*, CustomDevice*> variant_device = absl::variant<Device*, CustomDevice*> variant_device =

View File

@ -112,7 +112,7 @@ class OpDef;
class OperationInterface : public AbstractOperationInterface { class OperationInterface : public AbstractOperationInterface {
public: public:
explicit OperationInterface(TFE_Context* ctx); explicit OperationInterface(EagerContext* ctx);
~OperationInterface() override{}; ~OperationInterface() override{};
void Clear() override { operation_.Clear(); } void Clear() override { operation_.Clear(); }

View File

@ -381,7 +381,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
->ToTensor(dst); ->ToTensor(dst);
} }
Status TensorInterface::ToTensor(Tensor* dst) const { Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const {
if (tensor_.dtype() == DT_RESOURCE) { if (tensor_.dtype() == DT_RESOURCE) {
if (tensor_.dims() != 0) { if (tensor_.dims() != 0) {
return InvalidArgument( return InvalidArgument(
@ -389,7 +389,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const {
"shape ", "shape ",
tensor_.shape().DebugString()); tensor_.shape().DebugString());
} }
*dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape()); *dst = tensorflow::Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString( if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
string(static_cast<const char*>(Data()), ByteSize()))) { string(static_cast<const char*>(Data()), ByteSize()))) {
return InvalidArgument( return InvalidArgument(
@ -414,7 +414,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const {
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size; const char* limit = input + src_size;
*dst = Tensor(tensor_.dtype(), tensor_.shape()); *dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape());
auto dstarray = dst->flat<tstring>(); auto dstarray = dst->flat<tstring>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) { for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset = tensorflow::uint64 offset =

View File

@ -54,7 +54,7 @@ namespace tensorflow {
class TensorInterface : public AbstractTensorInterface { class TensorInterface : public AbstractTensorInterface {
public: public:
TensorInterface() {} TensorInterface() {}
explicit TensorInterface(Tensor t) : tensor_(std::move(t)) {} explicit TensorInterface(tensorflow::Tensor t) : tensor_(std::move(t)) {}
~TensorInterface() override {} ~TensorInterface() override {}
TF_DataType Type() const override; TF_DataType Type() const override;
@ -66,12 +66,16 @@ class TensorInterface : public AbstractTensorInterface {
bool IsAligned() const override; bool IsAligned() const override;
bool CanMove() const override; bool CanMove() const override;
Status ToTensor(Tensor* dst) const; Status ToTensor(tensorflow::Tensor* dst) const;
Status BitcastFrom(const TensorInterface& from, TF_DataType type, Status BitcastFrom(const TensorInterface& from, TF_DataType type,
const int64_t* new_dims, int num_new_dims); const int64_t* new_dims, int num_new_dims);
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
tensorflow::Tensor Tensor() { return tensor_; }
private: private:
Tensor tensor_; tensorflow::Tensor tensor_;
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -75,7 +75,7 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) { const char* raw_device_name, TF_Status* status) {
std::unique_ptr<TFE_Op> op = ReleaseThreadLocalOp(ctx); std::unique_ptr<TFE_Op> op = ReleaseThreadLocalOp(ctx);
if (!op) { if (!op) {
op.reset(new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)}); op.reset(new TFE_Op{ctx->context->CreateOperation()});
} }
status->status = op->operation->Reset(op_or_function_name, raw_device_name); status->status = op->operation->Reset(op_or_function_name, raw_device_name);
if (!status->status.ok()) { if (!status->status.ok()) {

View File

@ -191,18 +191,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
// Prepare the argument. // Prepare the argument.
PyObject* args = nullptr; PyObject* args = nullptr;
TFE_Context* ctx = nullptr;
std::unique_ptr<EagerExecutor> new_executor = nullptr; std::unique_ptr<EagerExecutor> new_executor = nullptr;
EagerExecutor* old_executor = nullptr; EagerExecutor* old_executor = nullptr;
if (call->eager) { if (call->eager) {
// See FuncRegistry._ctx. // See FuncRegistry._ctx.
ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer( TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
PyObject_GetAttrString(trampoline, "_ctx"), nullptr)); PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
CHECK_NE(ctx, nullptr); CHECK_NE(ctx, nullptr);
TF_RETURN_IF_ERROR(MakeArgTuple(call, ctx->context, &args)); EagerContext* context = ContextFromInterface(ctx->context);
TF_RETURN_IF_ERROR(MakeArgTuple(call, context, &args));
new_executor.reset(new EagerExecutor(call->eager_async)); new_executor.reset(new EagerExecutor(call->eager_async));
old_executor = &ctx->context->Executor(); old_executor = &context->Executor();
ctx->context->SetExecutorForThread(new_executor.get()); context->SetExecutorForThread(new_executor.get());
} else { } else {
TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args)); TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args));
} }
@ -236,8 +236,11 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
} }
if (new_executor != nullptr) { if (new_executor != nullptr) {
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
EagerContext* context = ContextFromInterface(ctx->context);
s.Update(new_executor->WaitForAllPendingNodes()); s.Update(new_executor->WaitForAllPendingNodes());
ctx->context->SetExecutorForThread(old_executor); context->SetExecutorForThread(old_executor);
} }
TF_RETURN_IF_ERROR(s); TF_RETURN_IF_ERROR(s);

View File

@ -17,12 +17,12 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h" #include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
@ -278,29 +278,26 @@ struct Converter {
static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state, static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state,
TFE_TensorHandle** h, const char** error) { TFE_TensorHandle** h, const char** error) {
// TODO(josh11b): Allocator & attributes // TODO(josh11b): Allocator & attributes
// TODO(gjn): Use optimized scalar constructors when possible. std::unique_ptr<AbstractTensorInterface> tensor;
Tensor result(ConverterTraits<T>::kTypeEnum,
TensorShape(state->inferred_shape));
if (state->inferred_shape.empty()) { /* Scalar case */ if (state->inferred_shape.empty()) { /* Scalar case */
T value; T value;
auto scalar = ZeroDimArrayToScalar(obj, state); auto scalar = ZeroDimArrayToScalar(obj, state);
*error = ConverterTraits<T>::ConvertScalar(scalar, &value); *error = ConverterTraits<T>::ConvertScalar(scalar, &value);
Py_DECREF(scalar); Py_DECREF(scalar);
if (*error != nullptr) return errors::InvalidArgument(*error); if (*error != nullptr) return errors::InvalidArgument(*error);
result.scalar<T>()() = value; tensor = ConverterTraits<T>::CreateScalar(ctx, value);
} else { } else {
T* buf = result.flat<T>().data(); tensor = ConverterTraits<T>::CreateTensor(ctx, state->inferred_shape);
if (tensor->NumElements() > 0) {
T* buf = static_cast<T*>(tensor->Data());
*error = Helper(obj, 0, state, &buf); *error = Helper(obj, 0, state, &buf);
if (*error != nullptr) return errors::InvalidArgument(*error); if (*error != nullptr) return errors::InvalidArgument(*error);
} }
tensorflow::TensorHandle* handle = nullptr;
auto status = tensorflow::TensorHandle::CreateLocalHandle(
std::move(result), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
ctx->context, &handle);
if (!status.ok()) {
return status;
} }
*h = new TFE_TensorHandle{std::make_unique<TensorHandleInterface>(handle)}; std::unique_ptr<AbstractTensorHandleInterface> handle;
TF_RETURN_IF_ERROR(
ctx->context->CreateLocalHandle(std::move(tensor), &handle));
*h = new TFE_TensorHandle{std::move(handle)};
return Status::OK(); return Status::OK();
} }
}; };
@ -309,7 +306,15 @@ struct Converter {
template <> template <>
struct ConverterTraits<int64> { struct ConverterTraits<int64> {
static const tensorflow::DataType kTypeEnum = DT_INT64; static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
int64 value) {
return ctx->context->CreateInt64Scalar(value);
}
static std::unique_ptr<AbstractTensorInterface> CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateInt64Tensor(dim_sizes);
}
static const char* ConvertScalar(PyObject* v, int64* out) { static const char* ConvertScalar(PyObject* v, int64* out) {
#if PY_MAJOR_VERSION < 3 #if PY_MAJOR_VERSION < 3
@ -342,7 +347,15 @@ typedef Converter<int64> Int64Converter;
template <> template <>
struct ConverterTraits<uint64> { struct ConverterTraits<uint64> {
static const tensorflow::DataType kTypeEnum = DT_UINT64; static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
uint64 value) {
return ctx->context->CreateUint64Scalar(value);
}
static std::unique_ptr<AbstractTensorInterface> CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateUint64Tensor(dim_sizes);
}
static const char* ConvertScalar(PyObject* v, uint64* out) { static const char* ConvertScalar(PyObject* v, uint64* out) {
#if PY_MAJOR_VERSION < 3 #if PY_MAJOR_VERSION < 3
@ -372,7 +385,15 @@ typedef Converter<uint64> UInt64Converter;
template <> template <>
struct ConverterTraits<int32> { struct ConverterTraits<int32> {
static const tensorflow::DataType kTypeEnum = DT_INT32; static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
int32 value) {
return ctx->context->CreateInt32Scalar(value);
}
static std::unique_ptr<AbstractTensorInterface> CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateInt32Tensor(dim_sizes);
}
static const char* ConvertScalar(PyObject* v, int32* out) { static const char* ConvertScalar(PyObject* v, int32* out) {
int64 i; int64 i;
@ -472,7 +493,16 @@ static const char* ConvertOneFloat(PyObject* v, T* out) {
template <> template <>
struct ConverterTraits<float> { struct ConverterTraits<float> {
static const tensorflow::DataType kTypeEnum = DT_FLOAT; static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
float value) {
return ctx->context->CreateFloatScalar(value);
}
static std::unique_ptr<AbstractTensorInterface> CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateFloatTensor(dim_sizes);
}
static const char* ConvertScalar(PyObject* v, float* out) { static const char* ConvertScalar(PyObject* v, float* out) {
return ConvertOneFloat<float>(v, out); return ConvertOneFloat<float>(v, out);
} }
@ -480,7 +510,16 @@ struct ConverterTraits<float> {
template <> template <>
struct ConverterTraits<double> { struct ConverterTraits<double> {
static const tensorflow::DataType kTypeEnum = DT_DOUBLE; static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
double value) {
return ctx->context->CreateDoubleScalar(value);
}
static std::unique_ptr<AbstractTensorInterface> CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateDoubleTensor(dim_sizes);
}
static const char* ConvertScalar(PyObject* v, double* out) { static const char* ConvertScalar(PyObject* v, double* out) {
return ConvertOneFloat<double>(v, out); return ConvertOneFloat<double>(v, out);
} }
@ -491,7 +530,15 @@ typedef Converter<float> FloatConverter;
template <> template <>
struct ConverterTraits<Eigen::half> { struct ConverterTraits<Eigen::half> {
static const tensorflow::DataType kTypeEnum = DT_HALF; static std::unique_ptr<AbstractTensorInterface> CreateScalar(
TFE_Context* ctx, Eigen::half value) {
return ctx->context->CreateHalfScalar(value);
}
static std::unique_ptr<AbstractTensorInterface> CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateHalfTensor(dim_sizes);
}
static const char* ConvertScalar(PyObject* v, Eigen::half* out) { static const char* ConvertScalar(PyObject* v, Eigen::half* out) {
return ConvertOneFloat<Eigen::half>(v, out); return ConvertOneFloat<Eigen::half>(v, out);
@ -504,7 +551,15 @@ typedef Converter<Eigen::half> NumpyHalfConverter;
template <> template <>
struct ConverterTraits<tstring> { struct ConverterTraits<tstring> {
static const tensorflow::DataType kTypeEnum = DT_STRING; static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
tstring value) {
return ctx->context->CreateStringScalar(value);
}
static std::unique_ptr<AbstractTensorInterface> CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateStringTensor(dim_sizes);
}
static const char* ConvertScalar(PyObject* v, tstring* out) { static const char* ConvertScalar(PyObject* v, tstring* out) {
if (PyBytes_Check(v)) { if (PyBytes_Check(v)) {
@ -563,7 +618,16 @@ bool IsPyDimension(PyObject* obj) {
template <> template <>
struct ConverterTraits<complex128> { struct ConverterTraits<complex128> {
static const tensorflow::DataType kTypeEnum = DT_COMPLEX128; static std::unique_ptr<AbstractTensorInterface> CreateScalar(
TFE_Context* ctx, complex128 value) {
return ctx->context->CreateComplex128Scalar(value);
}
static std::unique_ptr<AbstractTensorInterface> CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateComplex128Tensor(dim_sizes);
}
static const char* ConvertScalar(PyObject* v, complex128* out) { static const char* ConvertScalar(PyObject* v, complex128* out) {
if (PyComplex_Check(v)) { if (PyComplex_Check(v)) {
*out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v)); *out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v));
@ -583,8 +647,15 @@ typedef Converter<complex128> Complex128Converter;
template <> template <>
struct ConverterTraits<bool> { struct ConverterTraits<bool> {
typedef bool Type; static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
static const tensorflow::DataType kTypeEnum = DT_BOOL; bool value) {
return ctx->context->CreateBoolScalar(value);
}
static std::unique_ptr<AbstractTensorInterface> CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateBoolTensor(dim_sizes);
}
static const char* ConvertScalar(PyObject* v, bool* out) { static const char* ConvertScalar(PyObject* v, bool* out) {
if (v == Py_True) { if (v == Py_True) {
@ -606,13 +677,12 @@ typedef Converter<bool> BoolConverter;
// The two may share underlying storage so changes to one may reflect in the // The two may share underlying storage so changes to one may reflect in the
// other. // other.
TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) { TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
tensorflow::TensorHandle* handle; std::unique_ptr<AbstractTensorHandleInterface> handle;
tensorflow::Tensor t; tensorflow::Tensor t;
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t); auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
if (cppstatus.ok()) { if (cppstatus.ok()) {
cppstatus = tensorflow::TensorHandle::CreateLocalHandle( cppstatus = ctx->context->CreateLocalHandle(
std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, std::make_unique<TensorInterface>(std::move(t)), &handle);
ctx->context, &handle);
} }
if (!cppstatus.ok()) { if (!cppstatus.ok()) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
@ -622,8 +692,7 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
.c_str()); .c_str());
return nullptr; return nullptr;
} }
return new TFE_TensorHandle{ return new TFE_TensorHandle{std::move(handle)};
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
} }
} // namespace } // namespace
@ -805,17 +874,16 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
case DT_INVALID: // Only occurs for empty tensors. case DT_INVALID: // Only occurs for empty tensors.
{ {
tensorflow::TensorHandle* h = nullptr; std::unique_ptr<AbstractTensorHandleInterface> handle;
Tensor t(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, Tensor t(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
TensorShape(state.inferred_shape)); TensorShape(state.inferred_shape));
status = tensorflow::TensorHandle::CreateLocalHandle( status = ctx->context->CreateLocalHandle(
std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, std::make_unique<TensorInterface>(std::move(t)), &handle);
ctx->context, &h);
if (!status.ok()) { if (!status.ok()) {
PyErr_SetString(PyExc_ValueError, status.error_message().c_str()); PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
return nullptr; return nullptr;
} }
return new TFE_TensorHandle{std::make_unique<TensorHandleInterface>(h)}; return new TFE_TensorHandle{std::move(handle)};
} }
default: default: