From a31ffb72e178ed072d119b604f45e09df92eac18 Mon Sep 17 00:00:00 2001
From: Gaurav Jain <gjn@google.com>
Date: Thu, 26 Mar 2020 20:41:15 -0700
Subject: [PATCH] Add AbstractContextInterface to abstract runtime

We want to be able to switch between different runtime backends. In
order to do that we introduce the AbstractContextInterface for
different runtimes to implement. This interface handles the creation of
key objects such as Tensors, TensorHandles and Operations, as well as
device management.

PiperOrigin-RevId: 303251247
Change-Id: Ib37c5c7bb3c49d418ad6d6d71fb70f6c2063f569
---
 tensorflow/c/c_api_experimental.cc           |   3 +-
 tensorflow/c/eager/BUILD                     |   4 +
 tensorflow/c/eager/c_api.cc                  | 130 ++++++++++-----
 tensorflow/c/eager/c_api_experimental.cc     |  33 ++--
 tensorflow/c/eager/c_api_internal.h          |   3 +-
 tensorflow/c/eager/context_interface.cc      | 150 ++++++++++++++++++
 tensorflow/c/eager/context_interface.h       | 157 +++++++++++++++++++
 tensorflow/c/eager/operation_interface.cc    |   3 +-
 tensorflow/c/eager/operation_interface.h     |   2 +-
 tensorflow/c/tf_tensor.cc                    |   6 +-
 tensorflow/core/framework/tensor_interface.h |  10 +-
 tensorflow/python/eager/pywrap_tfe_src.cc    |   2 +-
 tensorflow/python/lib/core/py_func.cc        |  15 +-
 tensorflow/python/lib/core/py_seq_tensor.cc  | 144 ++++++++++++-----
 14 files changed, 556 insertions(+), 106 deletions(-)
 create mode 100644 tensorflow/c/eager/context_interface.cc
 create mode 100644 tensorflow/c/eager/context_interface.h
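A quick sketch of what the abstraction enables (annotation for readers, not
part of the patch; it assumes a valid TFE_Context* ctx whose context member
now holds the interface — the interface names are the ones introduced below,
the surrounding glue is illustrative only):

    // Create a scalar tensor, wrap it in a handle, and build an op without
    // naming EagerContext; the same code could target another runtime.
    std::unique_ptr<tensorflow::AbstractTensorInterface> t =
        ctx->context->CreateFloatScalar(1.0f);
    std::unique_ptr<tensorflow::AbstractTensorHandleInterface> h;
    tensorflow::Status s = ctx->context->CreateLocalHandle(std::move(t), &h);
    if (s.ok()) {
      std::unique_ptr<tensorflow::AbstractOperationInterface> op =
          ctx->context->CreateOperation();
    }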
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 4e7ba3943ae..132ab718364 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -703,7 +703,8 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
   } while (0);
 
   // New server created for new server_def. Unused if updating server_def.
-  tensorflow::EagerContext* context = ctx->context;
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
   tensorflow::GrpcServer* grpc_server =
       dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
   if (grpc_server == nullptr) {
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index c25cb264ce7..dc7553971fb 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -28,6 +28,8 @@ tf_cuda_library(
         "c_api_debug.cc",
         "c_api_experimental.h",
         "c_api_internal.h",
+        "context_interface.cc",
+        "context_interface.h",
         "operation_interface.cc",
         "operation_interface.h",
         "tensor_handle_interface.h",
@@ -95,6 +97,7 @@ filegroup(
     srcs = [
         "c_api_experimental.h",
        "c_api_internal.h",
+        "context_interface.h",
         "dlpack.h",
         "operation_interface.h",
         "tensor_handle_interface.h",
@@ -109,6 +112,7 @@ tf_cuda_library(
     name = "c_api_internal",
     srcs = [
         "c_api_experimental.h",
+        "context_interface.h",
         "operation_interface.h",
         "tensor_handle_interface.h",
     ],
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 94a0a76ada1..8743f9327e7 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -305,7 +305,9 @@ tensorflow::Status CreateRemoteContexts(
         server_def.default_session_config());
 
     std::vector<bool> filtered_device_mask;
-    ctx->context->FilterDevicesForRemoteWorkers(
+    tensorflow::EagerContext* context =
+        tensorflow::ContextFromInterface(ctx->context);
+    context->FilterDevicesForRemoteWorkers(
         remote_worker, base_request.cluster_device_attributes(),
         &filtered_device_mask);
     DCHECK_EQ(filtered_device_mask.size(),
@@ -388,7 +390,9 @@ tensorflow::Status UpdateRemoteContexts(
     }
 
     std::vector<bool> filtered_device_mask;
-    ctx->context->FilterDevicesForRemoteWorkers(
+    tensorflow::EagerContext* context =
+        tensorflow::ContextFromInterface(ctx->context);
+    context->FilterDevicesForRemoteWorkers(
         remote_worker, base_request.cluster_device_attributes(),
         &filtered_device_mask);
     DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
@@ -467,7 +471,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
 
   // New server created for new server_def. Unused if updating server_def.
   std::unique_ptr<tensorflow::ServerInterface> new_server;
-  tensorflow::EagerContext* context = ctx->context;
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
   tensorflow::GrpcServer* grpc_server;
   if (reset_context) {
     LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@@ -696,14 +701,16 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   tensorflow::Rendezvous* r =
       new tensorflow::IntraProcessRendezvous(device_mgr.get());
 
-  return new TFE_Context{new tensorflow::EagerContext(
-      opts->session_options.options,
-      static_cast<tensorflow::ContextDevicePlacementPolicy>(
-          opts->device_placement_policy),
-      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
-      opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
-      /*device_mgr_owned*/ true, r,
-      tensorflow::GetDefaultCustomKernelCreator())};
+  return new TFE_Context{std::make_unique<tensorflow::ContextInterface>(
+      new tensorflow::EagerContext(
+          opts->session_options.options,
+          static_cast<tensorflow::ContextDevicePlacementPolicy>(
+              opts->device_placement_policy),
+          static_cast<tensorflow::ContextMirroringPolicy>(
+              opts->mirroring_policy),
+          opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
+          /*device_mgr_owned*/ true, r,
+          tensorflow::GetDefaultCustomKernelCreator()))};
 }
 
 TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@@ -714,20 +721,24 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
   tensorflow::Rendezvous* r =
       new tensorflow::IntraProcessRendezvous(device_mgr);
 
-  return new TFE_Context{new tensorflow::EagerContext(
-      opts->session_options.options,
-      static_cast<tensorflow::ContextDevicePlacementPolicy>(
-          opts->device_placement_policy),
-      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
-      opts->async, opts->lazy_remote_inputs_copy, device_mgr,
-      /*device_mgr_owned*/ false, r,
-      tensorflow::GetDefaultCustomKernelCreator())};
+  return new TFE_Context{std::make_unique<tensorflow::ContextInterface>(
+      new tensorflow::EagerContext(
+          opts->session_options.options,
+          static_cast<tensorflow::ContextDevicePlacementPolicy>(
+              opts->device_placement_policy),
+          static_cast<tensorflow::ContextMirroringPolicy>(
+              opts->mirroring_policy),
+          opts->async, opts->lazy_remote_inputs_copy, device_mgr,
+          /*device_mgr_owned*/ false, r,
+          tensorflow::GetDefaultCustomKernelCreator()))};
 }
 
 void TFE_DeleteContext(TFE_Context* ctx) {
   // context->RefCountIsOne() should be true here.
   // TODO(iga): Remove EagerContext refcounting.
-  ctx->context->Unref();
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->Unref();
 
   delete ctx;
 }
@@ -739,7 +750,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
 }
 
 void TFE_ContextClearCaches(TFE_Context* ctx) {
-  ctx->context->ClearCachesAndThreadExecutors();
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->ClearCachesAndThreadExecutors();
 }
 
 // Set server_def on the context, possibly updating it.
@@ -769,8 +782,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
           device_filters[i] = tdf.second.device_filters(i);
         }
         const string remote_worker = remote_prefix + std::to_string(task_index);
+        tensorflow::EagerContext* context =
+            tensorflow::ContextFromInterface(ctx->context);
         status->status =
-            ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
+            context->SetRemoteDeviceFilters(remote_worker, device_filters);
       }
     }
   }
@@ -789,11 +804,13 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
       "TFE_ContextSetServerDef not supported on mobile");
 #else   // !defined(IS_MOBILE_PLATFORM)
   tensorflow::ServerDef server_def;
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
   if (!server_def.ParseFromArray(proto, proto_len)) {
     status->status = tensorflow::errors::InvalidArgument(
         "Invalid tensorflow.ServerDef protocol buffer");
     return;
-  } else if (ctx->context->GetContextId() ==
+  } else if (context->GetContextId() ==
              tensorflow::EagerContext::kInvalidContextId) {
     status->status = tensorflow::errors::InvalidArgument(
         "Trying to update a context with invalid context id.");
@@ -817,7 +834,8 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
       "TFE_ContextSetServerDef not supported on mobile");
   return false;
 #else   // !defined(IS_MOBILE_PLATFORM)
-  tensorflow::EagerContext* context = ctx->context;
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
   tensorflow::GrpcServer* grpc_server =
       static_cast<tensorflow::GrpcServer*>(context->GetServer());
 
@@ -872,13 +890,17 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
 #if defined(IS_MOBILE_PLATFORM)
   status->status = tensorflow::Status::OK();
 #else   // !defined(IS_MOBILE_PLATFORM)
-  status->status = ctx->context->SyncExecutors();
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  status->status = context->SyncExecutors();
 #endif  // !IS_MOBILE_PLATFORM
 }
 
 void TFE_ContextSetThreadLocalDevicePlacementPolicy(
     TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
-  ctx->context->SetThreadLocalDevicePlacementPolicy(
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->SetThreadLocalDevicePlacementPolicy(
       static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
 }
 
@@ -887,8 +909,10 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
 // safe to call this function from the async EagerExecutor threads.
 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
     TFE_Context* ctx) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
   return static_cast<TFE_ContextDevicePlacementPolicy>(
-      ctx->context->GetDevicePlacementPolicy());
+      context->GetDevicePlacementPolicy());
 }
 
 TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
@@ -1178,7 +1202,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
     void (*deallocator)(void* data, size_t len, void* arg),
     void* deallocator_arg, TF_Status* status) {
   tensorflow::Device* device = nullptr;
-  tensorflow::EagerContext* context = ctx->context;
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
   status->status = context->FindDeviceFromName(device_name, &device);
   tensorflow::CustomDevice* custom_device = nullptr;
   if (!status->status.ok()) {
@@ -1248,8 +1273,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
 
 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                   TF_Status* status) {
-  std::unique_ptr<TFE_Op> new_op(
-      new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
+  std::unique_ptr<TFE_Op> new_op(new TFE_Op{ctx->context->CreateOperation()});
   status->status = new_op->operation->Reset(op_or_function_name, nullptr);
   if (!status->status.ok()) {
     new_op.reset();
@@ -1497,7 +1521,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                                TF_Status* status) {
   tensorflow::TensorHandle* handle = nullptr;
   tensorflow::Device* device;
-  tensorflow::EagerContext* context = ctx->context;
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
   status->status = context->FindDeviceFromName(device_name, &device);
   if (!status->status.ok()) {
     tensorflow::CustomDevice* dev;
@@ -1556,29 +1581,41 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
         tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
     return;
   }
-  status->status = ctx->context->AddFunctionDef(function_def);
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  status->status = context->AddFunctionDef(function_def);
 }
 
 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                             TF_Status* status) {
-  status->status = ctx->context->AddFunctionDef(function->fdef);
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  status->status = context->AddFunctionDef(function->fdef);
 }
 
 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
                                TF_Status* status) {
-  status->status = ctx->context->RemoveFunction(name);
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  status->status = context->RemoveFunction(name);
 }
 
 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
-  return ctx->context->FindFunctionDef(name) != nullptr;
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  return context->FindFunctionDef(name) != nullptr;
 }
 
 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
-  ctx->context->SetShouldStoreGraphs(true);
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->SetShouldStoreGraphs(true);
 }
 
 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
-  ctx->context->SetShouldStoreGraphs(false);
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->SetShouldStoreGraphs(false);
 }
 
 }  // extern "C"
@@ -1590,7 +1627,8 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                   TF_Status* status) {
-  tensorflow::EagerContext* context = ctx->context;
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
   status->status = context->Executor().WaitForAllPendingNodes();
   if (!status->status.ok()) return;
   tensorflow::mutex_lock ml(*context->MetadataMu());
@@ -1611,9 +1649,17 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
 }
 }  // namespace
 
-void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
+void TFE_ContextStartStep(TFE_Context* ctx) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->StartStep();
+}
 
-void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
+void TFE_ContextEndStep(TFE_Context* ctx) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->EndStep();
+}
 
 void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
   auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
@@ -1793,6 +1839,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
                               TF_Status* status) {
   auto custom_device =
       std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
   status->status =
-      ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
+      context->RegisterCustomDevice(device_name, std::move(custom_device));
 }
diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index afa36fe1210..d3baf174563 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -40,11 +40,15 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
 }
 
 void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
-  ctx->context->SetShouldStoreGraphs(true);
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->SetShouldStoreGraphs(true);
 }
 
 void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
-  ctx->context->SetShouldStoreGraphs(false);
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->SetShouldStoreGraphs(false);
 }
 
 void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
@@ -474,7 +478,9 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
 
 void TFE_ContextSetThreadLocalMirroringPolicy(
     TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
-  ctx->context->SetThreadLocalMirroringPolicy(
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->SetThreadLocalMirroringPolicy(
       static_cast<tensorflow::ContextMirroringPolicy>(policy));
 }
 
@@ -483,8 +489,9 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
 // safe to call this function from the async EagerExecutor threads.
 extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
     TFE_Context* ctx) {
-  return static_cast<TFE_ContextMirroringPolicy>(
-      ctx->context->GetMirroringPolicy());
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
 }
 
 void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
@@ -537,16 +544,22 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
 }
 
 void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
-  ctx->context->SetExecutorForThread(executor->executor());
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  context->SetExecutorForThread(executor->executor());
 }
 
 TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
-  return new TFE_Executor(&ctx->context->Executor());
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  return new TFE_Executor(&context->Executor());
 }
 
 void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
   auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
-      ctx->context->HostCPU()->parsed_name());
+      context->HostCPU()->parsed_name());
   auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
   void* data = tensorflow::port::Malloc(str.length());
   str.copy(static_cast<char*>(data), str.length(), 0);
@@ -565,7 +578,9 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
 
 void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
                                TF_Buffer* buf, TF_Status* status) {
-  auto* function_def = ctx->context->FindFunctionDef(function_name);
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(ctx->context);
+  auto* function_def = context->FindFunctionDef(function_name);
   if (function_def == nullptr) {
     status->status = tensorflow::errors::NotFound(
         "Unable to find FunctionDef with name: ", function_name);
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 05b0a143025..5d8f17d2702 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/context_interface.h"
 #include "tensorflow/c/eager/operation_interface.h"
 #include "tensorflow/c/eager/tensor_handle_interface.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
@@ -62,7 +63,7 @@ struct TFE_ContextOptions {
 };
 
 struct TFE_Context {
-  tensorflow::EagerContext* context;
+  std::unique_ptr<tensorflow::AbstractContextInterface> context;
 };
 
 struct TFE_TensorHandle {
diff --git a/tensorflow/c/eager/context_interface.cc b/tensorflow/c/eager/context_interface.cc
new file mode 100644
index 00000000000..7e7d656816b
--- /dev/null
+++ b/tensorflow/c/eager/context_interface.cc
@@ -0,0 +1,150 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/context_interface.h"
+
+#include "tensorflow/c/eager/operation_interface.h"
+#include "tensorflow/c/eager/tensor_handle_interface.h"
+#include "tensorflow/core/framework/tensor_interface.h"
+#include "tensorflow/core/platform/casts.h"
+
+namespace tensorflow {
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Scalar(
+    int64 value) {
+  return std::make_unique<TensorInterface>(Tensor(value));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Scalar(
+    uint64 value) {
+  return std::make_unique<TensorInterface>(Tensor(value));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Scalar(
+    int32 value) {
+  return std::make_unique<TensorInterface>(Tensor(value));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatScalar(
+    float value) {
+  return std::make_unique<TensorInterface>(Tensor(value));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleScalar(
+    double value) {
+  return std::make_unique<TensorInterface>(Tensor(value));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfScalar(
+    Eigen::half value) {
+  return std::make_unique<TensorInterface>(Tensor(value));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringScalar(
+    tstring value) {
+  return std::make_unique<TensorInterface>(Tensor(value));
+}
+
+std::unique_ptr<AbstractTensorInterface>
+ContextInterface::CreateComplex128Scalar(complex128 value) {
+  return std::make_unique<TensorInterface>(Tensor(value));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolScalar(
+    bool value) {
+  return std::make_unique<TensorInterface>(Tensor(value));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Tensor(
+    absl::Span<const int64> dim_sizes) {
+  return std::make_unique<TensorInterface>(
+      Tensor(DT_INT64, TensorShape(dim_sizes)));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Tensor(
+    absl::Span<const int64> dim_sizes) {
+  return std::make_unique<TensorInterface>(
+      Tensor(DT_UINT64, TensorShape(dim_sizes)));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Tensor(
+    absl::Span<const int64> dim_sizes) {
+  return std::make_unique<TensorInterface>(
+      Tensor(DT_INT32, TensorShape(dim_sizes)));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatTensor(
+    absl::Span<const int64> dim_sizes) {
+  return std::make_unique<TensorInterface>(
+      Tensor(DT_FLOAT, TensorShape(dim_sizes)));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleTensor(
+    absl::Span<const int64> dim_sizes) {
+  return std::make_unique<TensorInterface>(
+      Tensor(DT_DOUBLE, TensorShape(dim_sizes)));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfTensor(
+    absl::Span<const int64> dim_sizes) {
+  return std::make_unique<TensorInterface>(
+      Tensor(DT_HALF, TensorShape(dim_sizes)));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringTensor(
+    absl::Span<const int64> dim_sizes) {
+  return std::make_unique<TensorInterface>(
+      Tensor(DT_STRING, TensorShape(dim_sizes)));
+}
+
+std::unique_ptr<AbstractTensorInterface>
+ContextInterface::CreateComplex128Tensor(absl::Span<const int64> dim_sizes) {
+  return std::make_unique<TensorInterface>(
+      Tensor(DT_COMPLEX128, TensorShape(dim_sizes)));
+}
+
+std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolTensor(
+    absl::Span<const int64> dim_sizes) {
+  return std::make_unique<TensorInterface>(
+      Tensor(DT_BOOL, TensorShape(dim_sizes)));
+}
+
+Status ContextInterface::CreateLocalHandle(
+    const std::unique_ptr<AbstractTensorInterface> t,
+    std::unique_ptr<AbstractTensorHandleInterface>* h) {
+  Tensor tensor = tensorflow::down_cast<TensorInterface*>(t.get())->Tensor();
+  tensorflow::TensorHandle* handle = nullptr;
+  auto status =
+      TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/ctx_->HostCPU(),
+                                      /*op_device=*/nullptr, ctx_, &handle);
+  if (!status.ok()) {
+    return status;
+  }
+  *h = std::make_unique<TensorHandleInterface>(handle);
+
+  return status;
+}
+
+std::unique_ptr<AbstractOperationInterface>
+ContextInterface::CreateOperation() {
+  return std::make_unique<OperationInterface>(ctx_);
+}
+
+void ContextInterface::ListDevices(
+    std::vector<tensorflow::DeviceAttributes>* devices) {
+  ctx_->ListDevices(devices);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h
new file mode 100644
index 00000000000..bac2b0cd3ec
--- /dev/null
+++ b/tensorflow/c/eager/context_interface.h
@@ -0,0 +1,157 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
+#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
+
+#include <memory>
+
+#include "tensorflow/c/eager/operation_interface.h"
+#include "tensorflow/c/eager/tensor_handle_interface.h"
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/tensor_interface.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/tstring.h"
+
+namespace tensorflow {
+
+// Abstract interface to a context.
+//
+// A context is responsible for creating key objects such as Tensors,
+// TensorHandles & Operations.
+class AbstractContextInterface {
+ public:
+  virtual ~AbstractContextInterface() {}
+
+  // Scalar creation functions
+  virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
+      int64 value) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
+      uint64 value) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
+      int32 value) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
+      float value) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
+      double value) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
+      Eigen::half value) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
+      tstring value) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
+      complex128 value) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
+      bool value) = 0;
+
+  // Tensor creation functions
+  virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
+      absl::Span<const int64> dim_sizes) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
+      absl::Span<const int64> dim_sizes) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
+      absl::Span<const int64> dim_sizes) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
+      absl::Span<const int64> dim_sizes) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
+      absl::Span<const int64> dim_sizes) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
+      absl::Span<const int64> dim_sizes) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
+      absl::Span<const int64> dim_sizes) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
+      absl::Span<const int64> dim_sizes) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
+      absl::Span<const int64> dim_sizes) = 0;
+
+  // Create a handle to wrap and manage a Tensor
+  virtual tensorflow::Status CreateLocalHandle(
+      const std::unique_ptr<AbstractTensorInterface> t,
+      std::unique_ptr<AbstractTensorHandleInterface>* handle) = 0;
+
+  // Create an operation to perform op execution
+  virtual std::unique_ptr<AbstractOperationInterface> CreateOperation() = 0;
+
+  // List attributes of available devices
+  virtual void ListDevices(
+      std::vector<tensorflow::DeviceAttributes>* devices) = 0;
+};
+
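+// Illustration only (not exercised by this patch): an alternative runtime
+// would subclass AbstractContextInterface and hand back its own tensor,
+// handle and operation implementations, e.g.
+//
+//   class MyRuntimeContext : public AbstractContextInterface {
+//    public:
+//     std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
+//         float value) override;
+//     ...  // remaining factories, CreateLocalHandle, CreateOperation, etc.
+//   };
+//
+// MyRuntimeContext is a hypothetical name, used here purely as a sketch.
+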
+// TODO(gjn): Try to move these all to EagerContext and make it implement
+// AbstractContextInterface. Currently, this is not so straightforward because
+// of various BUILD file dependencies.
+class ContextInterface : public AbstractContextInterface {
+ public:
+  explicit ContextInterface(EagerContext* ctx) : ctx_(ctx) {}
+  ~ContextInterface() override {}
+
+  std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
+      int64 value) override;
+  std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
+      uint64 value) override;
+  std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
+      int32 value) override;
+  std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
+      float value) override;
+  std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
+      double value) override;
+  std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
+      Eigen::half value) override;
+  std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
+      tensorflow::tstring value) override;
+  std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
+      tensorflow::complex128 value) override;
+  std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
+      bool value) override;
+
+  std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
+      absl::Span<const int64> dim_sizes) override;
+  std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
+      absl::Span<const int64> dim_sizes) override;
+  std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
+      absl::Span<const int64> dim_sizes) override;
+  std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
+      absl::Span<const int64> dim_sizes) override;
+  std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
+      absl::Span<const int64> dim_sizes) override;
+  std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
+      absl::Span<const int64> dim_sizes) override;
+  std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
+      absl::Span<const int64> dim_sizes) override;
+  std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
+      absl::Span<const int64> dim_sizes) override;
+  std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
+      absl::Span<const int64> dim_sizes) override;
+
+  tensorflow::Status CreateLocalHandle(
+      const std::unique_ptr<AbstractTensorInterface> t,
+      std::unique_ptr<AbstractTensorHandleInterface>* h) override;
+  std::unique_ptr<AbstractOperationInterface> CreateOperation() override;
+
+  void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices) override;
+
+  // For runtime specific APIs, provide ability to get the underlying context.
+  EagerContext* Context() const { return ctx_; }
+
+ private:
+  EagerContext* ctx_;
+};
+
+inline EagerContext* ContextFromInterface(
+    const std::unique_ptr<AbstractContextInterface>& context) {
+  return down_cast<ContextInterface*>(context.get())->Context();
+}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
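All of the C API call sites touched above then use the same two-line unwrap,
keeping EagerContext-specific APIs reachable until EagerContext itself
implements the interface. For reference, the recurring pattern (copied from
the diff, shown once here):

    tensorflow::EagerContext* context =
        tensorflow::ContextFromInterface(ctx->context);
    context->SetShouldStoreGraphs(true);  // or any other EagerContext-only call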
diff --git a/tensorflow/c/eager/operation_interface.cc b/tensorflow/c/eager/operation_interface.cc
index 5703d3231bd..d34f2e29ac0 100644
--- a/tensorflow/c/eager/operation_interface.cc
+++ b/tensorflow/c/eager/operation_interface.cc
@@ -26,8 +26,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-OperationInterface::OperationInterface(TFE_Context* ctx)
-    : operation_(ctx->context) {}
+OperationInterface::OperationInterface(EagerContext* ctx) : operation_(ctx) {}
 
 const string& OperationInterface::DeviceName() const {
   absl::variant<Device*, CustomDevice*> variant_device =
diff --git a/tensorflow/c/eager/operation_interface.h b/tensorflow/c/eager/operation_interface.h
index 900c5112c08..7dc21a628db 100644
--- a/tensorflow/c/eager/operation_interface.h
+++ b/tensorflow/c/eager/operation_interface.h
@@ -112,7 +112,7 @@ class OpDef;
 
 class OperationInterface : public AbstractOperationInterface {
  public:
-  explicit OperationInterface(TFE_Context* ctx);
+  explicit OperationInterface(EagerContext* ctx);
   ~OperationInterface() override{};
 
   void Clear() override { operation_.Clear(); }
diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc
index 4e75beceb3e..03833368102 100644
--- a/tensorflow/c/tf_tensor.cc
+++ b/tensorflow/c/tf_tensor.cc
@@ -381,7 +381,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
       ->ToTensor(dst);
 }
 
-Status TensorInterface::ToTensor(Tensor* dst) const {
+Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const {
   if (tensor_.dtype() == DT_RESOURCE) {
     if (tensor_.dims() != 0) {
       return InvalidArgument(
@@ -389,7 +389,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const {
           "shape ", tensor_.shape().DebugString());
     }
-    *dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
+    *dst = tensorflow::Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
     if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
             string(static_cast<const char*>(Data()), ByteSize()))) {
       return InvalidArgument(
@@ -414,7 +414,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const {
   const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
   const char* limit = input + src_size;
 
-  *dst = Tensor(tensor_.dtype(), tensor_.shape());
+  *dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape());
   auto dstarray = dst->flat<tstring>();
   for (tensorflow::int64 i = 0; i < num_elements; ++i) {
     tensorflow::uint64 offset =
diff --git a/tensorflow/core/framework/tensor_interface.h b/tensorflow/core/framework/tensor_interface.h
index f5d7bf53370..115e1a01b22 100644
--- a/tensorflow/core/framework/tensor_interface.h
+++ b/tensorflow/core/framework/tensor_interface.h
@@ -54,7 +54,7 @@ namespace tensorflow {
 class TensorInterface : public AbstractTensorInterface {
  public:
   TensorInterface() {}
-  explicit TensorInterface(Tensor t) : tensor_(std::move(t)) {}
+  explicit TensorInterface(tensorflow::Tensor t) : tensor_(std::move(t)) {}
   ~TensorInterface() override {}
 
   TF_DataType Type() const override;
@@ -66,12 +66,16 @@ class TensorInterface : public AbstractTensorInterface {
   bool IsAligned() const override;
   bool CanMove() const override;
 
-  Status ToTensor(Tensor* dst) const;
+  Status ToTensor(tensorflow::Tensor* dst) const;
   Status BitcastFrom(const TensorInterface& from, TF_DataType type,
                      const int64_t* new_dims, int num_new_dims);
 
+  // TODO(gjn): This is not a very generic interface, but is needed for specific
+  // use cases.
+  tensorflow::Tensor Tensor() { return tensor_; }
+
  private:
-  Tensor tensor_;
+  tensorflow::Tensor tensor_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index a3f9b0bed5c..a120c1ccdd9 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -75,7 +75,7 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
               const char* raw_device_name, TF_Status* status) {
   std::unique_ptr<TFE_Op> op = ReleaseThreadLocalOp(ctx);
   if (!op) {
-    op.reset(new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
+    op.reset(new TFE_Op{ctx->context->CreateOperation()});
   }
   status->status = op->operation->Reset(op_or_function_name, raw_device_name);
   if (!status->status.ok()) {
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 5faa07baf94..2fc70b8c6b6 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -191,18 +191,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
 
   // Prepare the argument.
   PyObject* args = nullptr;
-  TFE_Context* ctx = nullptr;
   std::unique_ptr<EagerExecutor> new_executor = nullptr;
   EagerExecutor* old_executor = nullptr;
   if (call->eager) {
     // See FuncRegistry._ctx.
-    ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
+    TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
         PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
     CHECK_NE(ctx, nullptr);
-    TF_RETURN_IF_ERROR(MakeArgTuple(call, ctx->context, &args));
+    EagerContext* context = ContextFromInterface(ctx->context);
+    TF_RETURN_IF_ERROR(MakeArgTuple(call, context, &args));
     new_executor.reset(new EagerExecutor(call->eager_async));
-    old_executor = &ctx->context->Executor();
-    ctx->context->SetExecutorForThread(new_executor.get());
+    old_executor = &context->Executor();
+    context->SetExecutorForThread(new_executor.get());
   } else {
     TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args));
   }
@@ -236,8 +236,11 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
   }
 
   if (new_executor != nullptr) {
+    TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
+        PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
+    EagerContext* context = ContextFromInterface(ctx->context);
     s.Update(new_executor->WaitForAllPendingNodes());
-    ctx->context->SetExecutorForThread(old_executor);
+    context->SetExecutorForThread(old_executor);
   }
 
   TF_RETURN_IF_ERROR(s);
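The py_func.cc change above preserves the existing save/swap/restore
discipline around the per-thread eager executor, now reached through the
interface. In condensed sketch form (same names as the diff):

    EagerContext* context = ContextFromInterface(ctx->context);
    EagerExecutor* old_executor = &context->Executor();  // remember current
    EagerExecutor new_executor(call->eager_async);
    context->SetExecutorForThread(&new_executor);
    // ... run the Python trampoline ...
    s.Update(new_executor.WaitForAllPendingNodes());  // drain pending nodes
    context->SetExecutorForThread(old_executor);      // restore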
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" @@ -278,29 +278,26 @@ struct Converter { static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state, TFE_TensorHandle** h, const char** error) { // TODO(josh11b): Allocator & attributes - // TODO(gjn): Use optimized scalar constructors when possible. - Tensor result(ConverterTraits::kTypeEnum, - TensorShape(state->inferred_shape)); + std::unique_ptr tensor; if (state->inferred_shape.empty()) { /* Scalar case */ T value; auto scalar = ZeroDimArrayToScalar(obj, state); *error = ConverterTraits::ConvertScalar(scalar, &value); Py_DECREF(scalar); if (*error != nullptr) return errors::InvalidArgument(*error); - result.scalar()() = value; + tensor = ConverterTraits::CreateScalar(ctx, value); } else { - T* buf = result.flat().data(); - *error = Helper(obj, 0, state, &buf); - if (*error != nullptr) return errors::InvalidArgument(*error); + tensor = ConverterTraits::CreateTensor(ctx, state->inferred_shape); + if (tensor->NumElements() > 0) { + T* buf = static_cast(tensor->Data()); + *error = Helper(obj, 0, state, &buf); + if (*error != nullptr) return errors::InvalidArgument(*error); + } } - tensorflow::TensorHandle* handle = nullptr; - auto status = tensorflow::TensorHandle::CreateLocalHandle( - std::move(result), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, - ctx->context, &handle); - if (!status.ok()) { - return status; - } - *h = new TFE_TensorHandle{std::make_unique(handle)}; + std::unique_ptr handle; + TF_RETURN_IF_ERROR( + ctx->context->CreateLocalHandle(std::move(tensor), &handle)); + *h = new TFE_TensorHandle{std::move(handle)}; return Status::OK(); } }; @@ -309,7 +306,15 @@ struct Converter { template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_INT64; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + int64 value) { + return ctx->context->CreateInt64Scalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateInt64Tensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, int64* out) { #if PY_MAJOR_VERSION < 3 @@ -342,7 +347,15 @@ typedef Converter Int64Converter; template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_UINT64; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + uint64 value) { + return ctx->context->CreateUint64Scalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateUint64Tensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, uint64* out) { #if PY_MAJOR_VERSION < 3 @@ -372,7 +385,15 @@ typedef Converter UInt64Converter; template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_INT32; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + int32 value) { + return ctx->context->CreateInt32Scalar(value); + } + + static std::unique_ptr 
+      TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
+    return ctx->context->CreateInt32Tensor(dim_sizes);
+  }
 
   static const char* ConvertScalar(PyObject* v, int32* out) {
     int64 i;
@@ -472,7 +493,16 @@ static const char* ConvertOneFloat(PyObject* v, T* out) {
 
 template <>
 struct ConverterTraits<float> {
-  static const tensorflow::DataType kTypeEnum = DT_FLOAT;
+  static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
+                                                               float value) {
+    return ctx->context->CreateFloatScalar(value);
+  }
+
+  static std::unique_ptr<AbstractTensorInterface> CreateTensor(
+      TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
+    return ctx->context->CreateFloatTensor(dim_sizes);
+  }
+
   static const char* ConvertScalar(PyObject* v, float* out) {
     return ConvertOneFloat(v, out);
   }
@@ -480,7 +510,16 @@ struct ConverterTraits<float> {
 
 template <>
 struct ConverterTraits<double> {
-  static const tensorflow::DataType kTypeEnum = DT_DOUBLE;
+  static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
+                                                               double value) {
+    return ctx->context->CreateDoubleScalar(value);
+  }
+
+  static std::unique_ptr<AbstractTensorInterface> CreateTensor(
+      TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
+    return ctx->context->CreateDoubleTensor(dim_sizes);
+  }
+
   static const char* ConvertScalar(PyObject* v, double* out) {
     return ConvertOneFloat(v, out);
   }
@@ -491,7 +530,15 @@ typedef Converter<float> FloatConverter;
 
 template <>
 struct ConverterTraits<Eigen::half> {
-  static const tensorflow::DataType kTypeEnum = DT_HALF;
+  static std::unique_ptr<AbstractTensorInterface> CreateScalar(
+      TFE_Context* ctx, Eigen::half value) {
+    return ctx->context->CreateHalfScalar(value);
+  }
+
+  static std::unique_ptr<AbstractTensorInterface> CreateTensor(
+      TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
+    return ctx->context->CreateHalfTensor(dim_sizes);
+  }
 
   static const char* ConvertScalar(PyObject* v, Eigen::half* out) {
     return ConvertOneFloat(v, out);
@@ -504,7 +551,15 @@ typedef Converter<Eigen::half> NumpyHalfConverter;
 
 template <>
 struct ConverterTraits<tstring> {
-  static const tensorflow::DataType kTypeEnum = DT_STRING;
+  static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
+                                                               tstring value) {
+    return ctx->context->CreateStringScalar(value);
+  }
+
+  static std::unique_ptr<AbstractTensorInterface> CreateTensor(
+      TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
+    return ctx->context->CreateStringTensor(dim_sizes);
+  }
 
   static const char* ConvertScalar(PyObject* v, tstring* out) {
     if (PyBytes_Check(v)) {
@@ -563,7 +618,16 @@ bool IsPyDimension(PyObject* obj) {
 
 template <>
 struct ConverterTraits<complex128> {
-  static const tensorflow::DataType kTypeEnum = DT_COMPLEX128;
+  static std::unique_ptr<AbstractTensorInterface> CreateScalar(
+      TFE_Context* ctx, complex128 value) {
+    return ctx->context->CreateComplex128Scalar(value);
+  }
+
+  static std::unique_ptr<AbstractTensorInterface> CreateTensor(
+      TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
+    return ctx->context->CreateComplex128Tensor(dim_sizes);
+  }
+
   static const char* ConvertScalar(PyObject* v, complex128* out) {
     if (PyComplex_Check(v)) {
       *out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v));
@@ -583,8 +647,15 @@ typedef Converter<complex128> Complex128Converter;
 
 template <>
 struct ConverterTraits<bool> {
-  typedef bool Type;
-  static const tensorflow::DataType kTypeEnum = DT_BOOL;
+  static std::unique_ptr<AbstractTensorInterface> CreateScalar(TFE_Context* ctx,
+                                                               bool value) {
+    return ctx->context->CreateBoolScalar(value);
+  }
+
+  static std::unique_ptr<AbstractTensorInterface> CreateTensor(
+      TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
+    return ctx->context->CreateBoolTensor(dim_sizes);
+  }
 
   static const char* ConvertScalar(PyObject* v, bool* out) {
     if (v == Py_True) {
@@ -606,13 +677,12 @@ typedef Converter<bool> BoolConverter;
 
 // The two may share underlying storage so changes to one may reflect in the
 // other.
 TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
-  tensorflow::TensorHandle* handle;
+  std::unique_ptr<AbstractTensorHandleInterface> handle;
   tensorflow::Tensor t;
   auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
   if (cppstatus.ok()) {
-    cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
-        std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
-        ctx->context, &handle);
+    cppstatus = ctx->context->CreateLocalHandle(
+        std::make_unique<TensorInterface>(std::move(t)), &handle);
   }
   if (!cppstatus.ok()) {
     PyErr_SetString(PyExc_ValueError,
@@ -622,8 +692,7 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
                         .c_str());
     return nullptr;
   }
-  return new TFE_TensorHandle{
-      std::make_unique<tensorflow::TensorHandleInterface>(handle)};
+  return new TFE_TensorHandle{std::move(handle)};
 }
 
 }  // namespace
@@ -805,17 +874,16 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
     case DT_INVALID:  // Only occurs for empty tensors.
     {
-      tensorflow::TensorHandle* h = nullptr;
+      std::unique_ptr<AbstractTensorHandleInterface> handle;
       Tensor t(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
                TensorShape(state.inferred_shape));
-      status = tensorflow::TensorHandle::CreateLocalHandle(
-          std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
-          ctx->context, &h);
+      status = ctx->context->CreateLocalHandle(
+          std::make_unique<TensorInterface>(std::move(t)), &handle);
       if (!status.ok()) {
         PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
         return nullptr;
       }
-      return new TFE_TensorHandle{std::make_unique<TensorHandleInterface>(h)};
+      return new TFE_TensorHandle{std::move(handle)};
     }
 
     default: