diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 4e7ba3943ae..132ab718364 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -703,7 +703,8 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, } while (0); // New server created for new server_def. Unused if updating server_def. - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); tensorflow::GrpcServer* grpc_server = dynamic_cast(context->GetServer()); if (grpc_server == nullptr) { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c25cb264ce7..dc7553971fb 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -28,6 +28,8 @@ tf_cuda_library( "c_api_debug.cc", "c_api_experimental.h", "c_api_internal.h", + "context_interface.cc", + "context_interface.h", "operation_interface.cc", "operation_interface.h", "tensor_handle_interface.h", @@ -95,6 +97,7 @@ filegroup( srcs = [ "c_api_experimental.h", "c_api_internal.h", + "context_interface.h", "dlpack.h", "operation_interface.h", "tensor_handle_interface.h", @@ -109,6 +112,7 @@ tf_cuda_library( name = "c_api_internal", srcs = [ "c_api_experimental.h", + "context_interface.h", "operation_interface.h", "tensor_handle_interface.h", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 94a0a76ada1..8743f9327e7 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -305,7 +305,9 @@ tensorflow::Status CreateRemoteContexts( server_def.default_session_config()); std::vector filtered_device_mask; - ctx->context->FilterDevicesForRemoteWorkers( + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->FilterDevicesForRemoteWorkers( remote_worker, base_request.cluster_device_attributes(), &filtered_device_mask); DCHECK_EQ(filtered_device_mask.size(), @@ -388,7 +390,9 @@ tensorflow::Status UpdateRemoteContexts( } std::vector filtered_device_mask; - ctx->context->FilterDevicesForRemoteWorkers( + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->FilterDevicesForRemoteWorkers( remote_worker, base_request.cluster_device_attributes(), &filtered_device_mask); DCHECK_EQ(filtered_device_mask.size(), cluster_device_count); @@ -467,7 +471,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // New server created for new server_def. Unused if updating server_def. std::unique_ptr new_server; - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); tensorflow::GrpcServer* grpc_server; if (reset_context) { LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server)); @@ -696,14 +701,16 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - return new TFE_Context{new tensorflow::EagerContext( - opts->session_options.options, - static_cast( - opts->device_placement_policy), - static_cast(opts->mirroring_policy), - opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), - /*device_mgr_owned*/ true, r, - tensorflow::GetDefaultCustomKernelCreator())}; + return new TFE_Context{std::make_unique( + new tensorflow::EagerContext( + opts->session_options.options, + static_cast( + opts->device_placement_policy), + static_cast( + opts->mirroring_policy), + opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), + /*device_mgr_owned*/ true, r, + tensorflow::GetDefaultCustomKernelCreator()))}; } TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, @@ -714,20 +721,24 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr); - return new TFE_Context{new tensorflow::EagerContext( - opts->session_options.options, - static_cast( - opts->device_placement_policy), - static_cast(opts->mirroring_policy), - opts->async, opts->lazy_remote_inputs_copy, device_mgr, - /*device_mgr_owned*/ false, r, - tensorflow::GetDefaultCustomKernelCreator())}; + return new TFE_Context{std::make_unique( + new tensorflow::EagerContext( + opts->session_options.options, + static_cast( + opts->device_placement_policy), + static_cast( + opts->mirroring_policy), + opts->async, opts->lazy_remote_inputs_copy, device_mgr, + /*device_mgr_owned*/ false, r, + tensorflow::GetDefaultCustomKernelCreator()))}; } void TFE_DeleteContext(TFE_Context* ctx) { // context->RefCountIsOne() should be true here. // TODO(iga): Remove EagerContext refcounting. - ctx->context->Unref(); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->Unref(); delete ctx; } @@ -739,7 +750,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { } void TFE_ContextClearCaches(TFE_Context* ctx) { - ctx->context->ClearCachesAndThreadExecutors(); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->ClearCachesAndThreadExecutors(); } // Set server_def on the context, possibly updating it. @@ -769,8 +782,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, device_filters[i] = tdf.second.device_filters(i); } const string remote_worker = remote_prefix + std::to_string(task_index); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); status->status = - ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters); + context->SetRemoteDeviceFilters(remote_worker, device_filters); } } } @@ -789,11 +804,13 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx, "TFE_ContextSetServerDef not supported on mobile"); #else // !defined(IS_MOBILE_PLATFORM) tensorflow::ServerDef server_def; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); if (!server_def.ParseFromArray(proto, proto_len)) { status->status = tensorflow::errors::InvalidArgument( "Invalid tensorflow.ServerDef protocol buffer"); return; - } else if (ctx->context->GetContextId() == + } else if (context->GetContextId() == tensorflow::EagerContext::kInvalidContextId) { status->status = tensorflow::errors::InvalidArgument( "Trying to update a context with invalid context id."); @@ -817,7 +834,8 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, "TFE_ContextSetServerDef not supported on mobile"); return false; #else // !defined(IS_MOBILE_PLATFORM) - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); tensorflow::GrpcServer* grpc_server = static_cast(context->GetServer()); @@ -872,13 +890,17 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, #if defined(IS_MOBILE_PLATFORM) status->status = tensorflow::Status::OK(); #else // !defined(IS_MOBILE_PLATFORM) - status->status = ctx->context->SyncExecutors(); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + status->status = context->SyncExecutors(); #endif // !IS_MOBILE_PLATFORM } void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - ctx->context->SetThreadLocalDevicePlacementPolicy( + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetThreadLocalDevicePlacementPolicy( static_cast(policy)); } @@ -887,8 +909,10 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( // safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); return static_cast( - ctx->context->GetDevicePlacementPolicy()); + context->GetDevicePlacementPolicy()); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { @@ -1178,7 +1202,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( void (*deallocator)(void* data, size_t len, void* arg), void* deallocator_arg, TF_Status* status) { tensorflow::Device* device = nullptr; - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); status->status = context->FindDeviceFromName(device_name, &device); tensorflow::CustomDevice* custom_device = nullptr; if (!status->status.ok()) { @@ -1248,8 +1273,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { - std::unique_ptr new_op( - new TFE_Op{std::make_unique(ctx)}); + std::unique_ptr new_op(new TFE_Op{ctx->context->CreateOperation()}); status->status = new_op->operation->Reset(op_or_function_name, nullptr); if (!status->status.ok()) { new_op.reset(); @@ -1497,7 +1521,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TF_Status* status) { tensorflow::TensorHandle* handle = nullptr; tensorflow::Device* device; - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); status->status = context->FindDeviceFromName(device_name, &device); if (!status->status.ok()) { tensorflow::CustomDevice* dev; @@ -1556,29 +1581,41 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - status->status = ctx->context->AddFunctionDef(function_def); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + status->status = context->AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - status->status = ctx->context->AddFunctionDef(function->fdef); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + status->status = context->AddFunctionDef(function->fdef); } void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, TF_Status* status) { - status->status = ctx->context->RemoveFunction(name); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + status->status = context->RemoveFunction(name); } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { - return ctx->context->FindFunctionDef(name) != nullptr; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + return context->FindFunctionDef(name) != nullptr; } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->context->SetShouldStoreGraphs(true); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetShouldStoreGraphs(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - ctx->context->SetShouldStoreGraphs(false); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetShouldStoreGraphs(false); } } // extern "C" @@ -1590,7 +1627,8 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { - tensorflow::EagerContext* context = ctx->context; + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); status->status = context->Executor().WaitForAllPendingNodes(); if (!status->status.ok()) return; tensorflow::mutex_lock ml(*context->MetadataMu()); @@ -1611,9 +1649,17 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace -void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); } +void TFE_ContextStartStep(TFE_Context* ctx) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->StartStep(); +} -void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); } +void TFE_ContextEndStep(TFE_Context* ctx) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->EndStep(); +} void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) { auto operation = tensorflow::down_cast( @@ -1793,6 +1839,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, TF_Status* status) { auto custom_device = std::make_unique(ctx, device, device_info, device_name); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); status->status = - ctx->context->RegisterCustomDevice(device_name, std::move(custom_device)); + context->RegisterCustomDevice(device_name, std::move(custom_device)); } diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index afa36fe1210..d3baf174563 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -40,11 +40,15 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, } void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { - ctx->context->SetShouldStoreGraphs(true); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetShouldStoreGraphs(true); } void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { - ctx->context->SetShouldStoreGraphs(false); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetShouldStoreGraphs(false); } void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell, @@ -474,7 +478,9 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options, void TFE_ContextSetThreadLocalMirroringPolicy( TFE_Context* ctx, TFE_ContextMirroringPolicy policy) { - ctx->context->SetThreadLocalMirroringPolicy( + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetThreadLocalMirroringPolicy( static_cast(policy)); } @@ -483,8 +489,9 @@ void TFE_ContextSetThreadLocalMirroringPolicy( // safe to call this function from the async EagerExecutor threads. extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( TFE_Context* ctx) { - return static_cast( - ctx->context->GetMirroringPolicy()); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + return static_cast(context->GetMirroringPolicy()); } void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options, @@ -537,16 +544,22 @@ void TFE_ExecutorClearError(TFE_Executor* executor) { } void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { - ctx->context->SetExecutorForThread(executor->executor()); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + context->SetExecutorForThread(executor->executor()); } TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { - return new TFE_Executor(&ctx->context->Executor()); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + return new TFE_Executor(&context->Executor()); } void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); auto address_space = tensorflow::DeviceNameUtils::AddressSpace( - ctx->context->HostCPU()->parsed_name()); + context->HostCPU()->parsed_name()); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); void* data = tensorflow::port::Malloc(str.length()); str.copy(static_cast(data), str.length(), 0); @@ -565,7 +578,9 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h, void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, TF_Buffer* buf, TF_Status* status) { - auto* function_def = ctx->context->FindFunctionDef(function_name); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(ctx->context); + auto* function_def = context->FindFunctionDef(function_name); if (function_def == nullptr) { status->status = tensorflow::errors::NotFound( "Unable to find FunctionDef with name: ", function_name); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 05b0a143025..5d8f17d2702 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/context_interface.h" #include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -62,7 +63,7 @@ struct TFE_ContextOptions { }; struct TFE_Context { - tensorflow::EagerContext* context; + std::unique_ptr context; }; struct TFE_TensorHandle { diff --git a/tensorflow/c/eager/context_interface.cc b/tensorflow/c/eager/context_interface.cc new file mode 100644 index 00000000000..7e7d656816b --- /dev/null +++ b/tensorflow/c/eager/context_interface.cc @@ -0,0 +1,150 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/context_interface.h" + +#include "tensorflow/c/eager/operation_interface.h" +#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/core/framework/tensor_interface.h" +#include "tensorflow/core/platform/casts.h" + +namespace tensorflow { + +std::unique_ptr ContextInterface::CreateInt64Scalar( + int64 value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateUint64Scalar( + uint64 value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateInt32Scalar( + int32 value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateFloatScalar( + float value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateDoubleScalar( + double value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateHalfScalar( + Eigen::half value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateStringScalar( + tstring value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr +ContextInterface::CreateComplex128Scalar(complex128 value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateBoolScalar( + bool value) { + return std::make_unique(Tensor(value)); +} + +std::unique_ptr ContextInterface::CreateInt64Tensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_INT64, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateUint64Tensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_UINT64, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateInt32Tensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_INT32, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateFloatTensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_FLOAT, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateDoubleTensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_DOUBLE, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateHalfTensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_HALF, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateStringTensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_STRING, TensorShape(dim_sizes))); +} + +std::unique_ptr +ContextInterface::CreateComplex128Tensor(absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_COMPLEX128, TensorShape(dim_sizes))); +} + +std::unique_ptr ContextInterface::CreateBoolTensor( + absl::Span dim_sizes) { + return std::make_unique( + Tensor(DT_BOOL, TensorShape(dim_sizes))); +} + +Status ContextInterface::CreateLocalHandle( + const std::unique_ptr t, + std::unique_ptr* h) { + Tensor tensor = tensorflow::down_cast(t.get())->Tensor(); + tensorflow::TensorHandle* handle = nullptr; + auto status = + TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/ctx_->HostCPU(), + /*op_device=*/nullptr, ctx_, &handle); + if (!status.ok()) { + return status; + } + *h = std::make_unique(handle); + + return status; +} + +std::unique_ptr +ContextInterface::CreateOperation() { + return std::make_unique(ctx_); +} + +void ContextInterface::ListDevices( + std::vector* devices) { + ctx_->ListDevices(devices); +} + +} // namespace tensorflow diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h new file mode 100644 index 00000000000..bac2b0cd3ec --- /dev/null +++ b/tensorflow/c/eager/context_interface.h @@ -0,0 +1,157 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ +#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ + +#include + +#include "tensorflow/c/eager/operation_interface.h" +#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor_interface.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/tstring.h" + +namespace tensorflow { + +// Abstract interface to a context. +// +// A context is responsible for creating key objects such as Tensors, +// TensorHandles & Operations. +class AbstractContextInterface { + public: + virtual ~AbstractContextInterface() {} + + // Scalar creation functions + virtual std::unique_ptr CreateInt64Scalar( + int64 value) = 0; + virtual std::unique_ptr CreateUint64Scalar( + uint64 value) = 0; + virtual std::unique_ptr CreateInt32Scalar( + int32 value) = 0; + virtual std::unique_ptr CreateFloatScalar( + float value) = 0; + virtual std::unique_ptr CreateDoubleScalar( + double value) = 0; + virtual std::unique_ptr CreateHalfScalar( + Eigen::half value) = 0; + virtual std::unique_ptr CreateStringScalar( + tstring value) = 0; + virtual std::unique_ptr CreateComplex128Scalar( + complex128 value) = 0; + virtual std::unique_ptr CreateBoolScalar( + bool value) = 0; + + // Tensor creation functions + virtual std::unique_ptr CreateInt64Tensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateUint64Tensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateInt32Tensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateFloatTensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateDoubleTensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateHalfTensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateStringTensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateComplex128Tensor( + absl::Span dim_sizes) = 0; + virtual std::unique_ptr CreateBoolTensor( + absl::Span dim_sizes) = 0; + + // Create a handle to wrap and manage a Tensor + virtual tensorflow::Status CreateLocalHandle( + const std::unique_ptr t, + std::unique_ptr* handle) = 0; + + // Create an operation to perform op execution + virtual std::unique_ptr CreateOperation() = 0; + + // List attributes of available devices + virtual void ListDevices( + std::vector* devices) = 0; +}; + +// TODO(gjn): Try to move these all to EagerContext and make it implement +// AbstractContextInterface. Currently, this is not so straightforward because +// of various BUILD file dependencies. +class ContextInterface : public AbstractContextInterface { + public: + explicit ContextInterface(EagerContext* ctx) : ctx_(ctx) {} + ~ContextInterface() override {} + + std::unique_ptr CreateInt64Scalar( + int64 value) override; + std::unique_ptr CreateUint64Scalar( + uint64 value) override; + std::unique_ptr CreateInt32Scalar( + int32 value) override; + std::unique_ptr CreateFloatScalar( + float value) override; + std::unique_ptr CreateDoubleScalar( + double value) override; + std::unique_ptr CreateHalfScalar( + Eigen::half value) override; + std::unique_ptr CreateStringScalar( + tensorflow::tstring value) override; + std::unique_ptr CreateComplex128Scalar( + tensorflow::complex128 value) override; + std::unique_ptr CreateBoolScalar( + bool value) override; + + std::unique_ptr CreateInt64Tensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateUint64Tensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateInt32Tensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateFloatTensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateDoubleTensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateHalfTensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateStringTensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateComplex128Tensor( + absl::Span dim_sizes) override; + std::unique_ptr CreateBoolTensor( + absl::Span dim_sizes) override; + + tensorflow::Status CreateLocalHandle( + const std::unique_ptr t, + std::unique_ptr* h) override; + std::unique_ptr CreateOperation() override; + + void ListDevices(std::vector* devices) override; + + // For runtime specific APIs, provide ability to get the underlying context. + EagerContext* Context() const { return ctx_; } + + private: + EagerContext* ctx_; +}; + +inline EagerContext* ContextFromInterface( + const std::unique_ptr& context) { + return down_cast(context.get())->Context(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ diff --git a/tensorflow/c/eager/operation_interface.cc b/tensorflow/c/eager/operation_interface.cc index 5703d3231bd..d34f2e29ac0 100644 --- a/tensorflow/c/eager/operation_interface.cc +++ b/tensorflow/c/eager/operation_interface.cc @@ -26,8 +26,7 @@ limitations under the License. namespace tensorflow { -OperationInterface::OperationInterface(TFE_Context* ctx) - : operation_(ctx->context) {} +OperationInterface::OperationInterface(EagerContext* ctx) : operation_(ctx) {} const string& OperationInterface::DeviceName() const { absl::variant variant_device = diff --git a/tensorflow/c/eager/operation_interface.h b/tensorflow/c/eager/operation_interface.h index 900c5112c08..7dc21a628db 100644 --- a/tensorflow/c/eager/operation_interface.h +++ b/tensorflow/c/eager/operation_interface.h @@ -112,7 +112,7 @@ class OpDef; class OperationInterface : public AbstractOperationInterface { public: - explicit OperationInterface(TFE_Context* ctx); + explicit OperationInterface(EagerContext* ctx); ~OperationInterface() override{}; void Clear() override { operation_.Clear(); } diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 4e75beceb3e..03833368102 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -381,7 +381,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { ->ToTensor(dst); } -Status TensorInterface::ToTensor(Tensor* dst) const { +Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const { if (tensor_.dtype() == DT_RESOURCE) { if (tensor_.dims() != 0) { return InvalidArgument( @@ -389,7 +389,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const { "shape ", tensor_.shape().DebugString()); } - *dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape()); + *dst = tensorflow::Tensor(tensorflow::DT_RESOURCE, tensor_.shape()); if (!dst->scalar()().ParseFromString( string(static_cast(Data()), ByteSize()))) { return InvalidArgument( @@ -414,7 +414,7 @@ Status TensorInterface::ToTensor(Tensor* dst) const { const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; const char* limit = input + src_size; - *dst = Tensor(tensor_.dtype(), tensor_.shape()); + *dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape()); auto dstarray = dst->flat(); for (tensorflow::int64 i = 0; i < num_elements; ++i) { tensorflow::uint64 offset = diff --git a/tensorflow/core/framework/tensor_interface.h b/tensorflow/core/framework/tensor_interface.h index f5d7bf53370..115e1a01b22 100644 --- a/tensorflow/core/framework/tensor_interface.h +++ b/tensorflow/core/framework/tensor_interface.h @@ -54,7 +54,7 @@ namespace tensorflow { class TensorInterface : public AbstractTensorInterface { public: TensorInterface() {} - explicit TensorInterface(Tensor t) : tensor_(std::move(t)) {} + explicit TensorInterface(tensorflow::Tensor t) : tensor_(std::move(t)) {} ~TensorInterface() override {} TF_DataType Type() const override; @@ -66,12 +66,16 @@ class TensorInterface : public AbstractTensorInterface { bool IsAligned() const override; bool CanMove() const override; - Status ToTensor(Tensor* dst) const; + Status ToTensor(tensorflow::Tensor* dst) const; Status BitcastFrom(const TensorInterface& from, TF_DataType type, const int64_t* new_dims, int num_new_dims); + // TODO(gjn): This is not a very generic interface, but is needed for specific + // use cases. + tensorflow::Tensor Tensor() { return tensor_; } + private: - Tensor tensor_; + tensorflow::Tensor tensor_; }; } // namespace tensorflow diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index a3f9b0bed5c..a120c1ccdd9 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -75,7 +75,7 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name, const char* raw_device_name, TF_Status* status) { std::unique_ptr op = ReleaseThreadLocalOp(ctx); if (!op) { - op.reset(new TFE_Op{std::make_unique(ctx)}); + op.reset(new TFE_Op{ctx->context->CreateOperation()}); } status->status = op->operation->Reset(op_or_function_name, raw_device_name); if (!status->status.ok()) { diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 5faa07baf94..2fc70b8c6b6 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -191,18 +191,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { // Prepare the argument. PyObject* args = nullptr; - TFE_Context* ctx = nullptr; std::unique_ptr new_executor = nullptr; EagerExecutor* old_executor = nullptr; if (call->eager) { // See FuncRegistry._ctx. - ctx = reinterpret_cast(PyCapsule_GetPointer( + TFE_Context* ctx = reinterpret_cast(PyCapsule_GetPointer( PyObject_GetAttrString(trampoline, "_ctx"), nullptr)); CHECK_NE(ctx, nullptr); - TF_RETURN_IF_ERROR(MakeArgTuple(call, ctx->context, &args)); + EagerContext* context = ContextFromInterface(ctx->context); + TF_RETURN_IF_ERROR(MakeArgTuple(call, context, &args)); new_executor.reset(new EagerExecutor(call->eager_async)); - old_executor = &ctx->context->Executor(); - ctx->context->SetExecutorForThread(new_executor.get()); + old_executor = &context->Executor(); + context->SetExecutorForThread(new_executor.get()); } else { TF_RETURN_IF_ERROR(MakeArgTuple(call, nullptr, &args)); } @@ -236,8 +236,11 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { } if (new_executor != nullptr) { + TFE_Context* ctx = reinterpret_cast(PyCapsule_GetPointer( + PyObject_GetAttrString(trampoline, "_ctx"), nullptr)); + EagerContext* context = ContextFromInterface(ctx->context); s.Update(new_executor->WaitForAllPendingNodes()); - ctx->context->SetExecutorForThread(old_executor); + context->SetExecutorForThread(old_executor); } TF_RETURN_IF_ERROR(s); diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index e81102847ea..42d36bc4d51 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -17,12 +17,12 @@ limitations under the License. #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" @@ -278,29 +278,26 @@ struct Converter { static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state, TFE_TensorHandle** h, const char** error) { // TODO(josh11b): Allocator & attributes - // TODO(gjn): Use optimized scalar constructors when possible. - Tensor result(ConverterTraits::kTypeEnum, - TensorShape(state->inferred_shape)); + std::unique_ptr tensor; if (state->inferred_shape.empty()) { /* Scalar case */ T value; auto scalar = ZeroDimArrayToScalar(obj, state); *error = ConverterTraits::ConvertScalar(scalar, &value); Py_DECREF(scalar); if (*error != nullptr) return errors::InvalidArgument(*error); - result.scalar()() = value; + tensor = ConverterTraits::CreateScalar(ctx, value); } else { - T* buf = result.flat().data(); - *error = Helper(obj, 0, state, &buf); - if (*error != nullptr) return errors::InvalidArgument(*error); + tensor = ConverterTraits::CreateTensor(ctx, state->inferred_shape); + if (tensor->NumElements() > 0) { + T* buf = static_cast(tensor->Data()); + *error = Helper(obj, 0, state, &buf); + if (*error != nullptr) return errors::InvalidArgument(*error); + } } - tensorflow::TensorHandle* handle = nullptr; - auto status = tensorflow::TensorHandle::CreateLocalHandle( - std::move(result), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, - ctx->context, &handle); - if (!status.ok()) { - return status; - } - *h = new TFE_TensorHandle{std::make_unique(handle)}; + std::unique_ptr handle; + TF_RETURN_IF_ERROR( + ctx->context->CreateLocalHandle(std::move(tensor), &handle)); + *h = new TFE_TensorHandle{std::move(handle)}; return Status::OK(); } }; @@ -309,7 +306,15 @@ struct Converter { template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_INT64; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + int64 value) { + return ctx->context->CreateInt64Scalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateInt64Tensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, int64* out) { #if PY_MAJOR_VERSION < 3 @@ -342,7 +347,15 @@ typedef Converter Int64Converter; template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_UINT64; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + uint64 value) { + return ctx->context->CreateUint64Scalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateUint64Tensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, uint64* out) { #if PY_MAJOR_VERSION < 3 @@ -372,7 +385,15 @@ typedef Converter UInt64Converter; template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_INT32; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + int32 value) { + return ctx->context->CreateInt32Scalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateInt32Tensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, int32* out) { int64 i; @@ -472,7 +493,16 @@ static const char* ConvertOneFloat(PyObject* v, T* out) { template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_FLOAT; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + float value) { + return ctx->context->CreateFloatScalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateFloatTensor(dim_sizes); + } + static const char* ConvertScalar(PyObject* v, float* out) { return ConvertOneFloat(v, out); } @@ -480,7 +510,16 @@ struct ConverterTraits { template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_DOUBLE; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + double value) { + return ctx->context->CreateDoubleScalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateDoubleTensor(dim_sizes); + } + static const char* ConvertScalar(PyObject* v, double* out) { return ConvertOneFloat(v, out); } @@ -491,7 +530,15 @@ typedef Converter FloatConverter; template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_HALF; + static std::unique_ptr CreateScalar( + TFE_Context* ctx, Eigen::half value) { + return ctx->context->CreateHalfScalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateHalfTensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, Eigen::half* out) { return ConvertOneFloat(v, out); @@ -504,7 +551,15 @@ typedef Converter NumpyHalfConverter; template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_STRING; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + tstring value) { + return ctx->context->CreateStringScalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateStringTensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, tstring* out) { if (PyBytes_Check(v)) { @@ -563,7 +618,16 @@ bool IsPyDimension(PyObject* obj) { template <> struct ConverterTraits { - static const tensorflow::DataType kTypeEnum = DT_COMPLEX128; + static std::unique_ptr CreateScalar( + TFE_Context* ctx, complex128 value) { + return ctx->context->CreateComplex128Scalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateComplex128Tensor(dim_sizes); + } + static const char* ConvertScalar(PyObject* v, complex128* out) { if (PyComplex_Check(v)) { *out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v)); @@ -583,8 +647,15 @@ typedef Converter Complex128Converter; template <> struct ConverterTraits { - typedef bool Type; - static const tensorflow::DataType kTypeEnum = DT_BOOL; + static std::unique_ptr CreateScalar(TFE_Context* ctx, + bool value) { + return ctx->context->CreateBoolScalar(value); + } + + static std::unique_ptr CreateTensor( + TFE_Context* ctx, absl::Span dim_sizes) { + return ctx->context->CreateBoolTensor(dim_sizes); + } static const char* ConvertScalar(PyObject* v, bool* out) { if (v == Py_True) { @@ -606,13 +677,12 @@ typedef Converter BoolConverter; // The two may share underlying storage so changes to one may reflect in the // other. TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) { - tensorflow::TensorHandle* handle; + std::unique_ptr handle; tensorflow::Tensor t; auto cppstatus = tensorflow::NdarrayToTensor(obj, &t); if (cppstatus.ok()) { - cppstatus = tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, - ctx->context, &handle); + cppstatus = ctx->context->CreateLocalHandle( + std::make_unique(std::move(t)), &handle); } if (!cppstatus.ok()) { PyErr_SetString(PyExc_ValueError, @@ -622,8 +692,7 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) { .c_str()); return nullptr; } - return new TFE_TensorHandle{ - std::make_unique(handle)}; + return new TFE_TensorHandle{std::move(handle)}; } } // namespace @@ -805,17 +874,16 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj, case DT_INVALID: // Only occurs for empty tensors. { - tensorflow::TensorHandle* h = nullptr; + std::unique_ptr handle; Tensor t(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, TensorShape(state.inferred_shape)); - status = tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, - ctx->context, &h); + status = ctx->context->CreateLocalHandle( + std::make_unique(std::move(t)), &handle); if (!status.ok()) { PyErr_SetString(PyExc_ValueError, status.error_message().c_str()); return nullptr; } - return new TFE_TensorHandle{std::make_unique(h)}; + return new TFE_TensorHandle{std::move(handle)}; } default: