diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index aafa89cc3d8..1c4c0d1e06a 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -334,6 +334,9 @@ tf_cuda_library( ":checkpoint_reader", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:tfe_context_internal", + "//tensorflow/c/eager:tfe_op_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/compiler/jit:flags", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -729,3 +732,11 @@ tf_cuda_library( ], alwayslink = 1, ) + +cc_library( + name = "conversion_macros", + hdrs = [ + "conversion_macros.h", + ], + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index eb7bd61ee89..e623f30b98c 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -21,6 +21,9 @@ limitations under the License. #include "tensorflow/c/checkpoint_reader.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -686,8 +689,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type, std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); status->status = tensorflow::Status::OK(); - return new TFE_TensorHandle{ - tensorflow::TensorHandle::CreateLocalHandle(tensor)}; + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor)); } namespace { @@ -708,7 +710,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, // New server created for new server_def. Unused if updating server_def. tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::GrpcServer* grpc_server = dynamic_cast(context->GetServer()); if (grpc_server == nullptr) { @@ -822,14 +824,13 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, const int num_inputs = input_shapes->num_items; NodeDef node_def; - node_def.set_name(tfe_op->operation->Name()); - node_def.set_op(tfe_op->operation->Name()); + tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op); + node_def.set_name(op->Name()); + node_def.set_op(op->Name()); for (int i = 0; i < num_inputs; ++i) { node_def.add_input("dummy_input"); } - OperationFromInterface(tfe_op->operation) - ->Attrs() - .FillAttrValueMap(node_def.mutable_attr()); + OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr()); const tensorflow::OpRegistrationData* op_reg_data; status->status = diff --git a/tensorflow/c/conversion_macros.h b/tensorflow/c/conversion_macros.h new file mode 100644 index 00000000000..ce8adfadb26 --- /dev/null +++ b/tensorflow/c/conversion_macros.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_ +#define TENSORFLOW_C_CONVERSION_MACROS_H_ + +#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \ + inline cpp_impl *unwrap(wrapper *w) { \ + return reinterpret_cast(w); \ + } \ + \ + inline const cpp_impl *unwrap(const wrapper *w) { \ + return reinterpret_cast(w); \ + } \ + \ + inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast(i); } + +#endif // TENSORFLOW_C_CONVERSION_MACROS_H_ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c20ec9c4769..42a31444380 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -50,7 +50,6 @@ tf_cuda_library( ":tfe_tensor_debug_info_internal", ":tfe_tensorhandle_internal", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "//tensorflow/c:c_api", @@ -110,13 +109,10 @@ filegroup( "operation_interface.h", "tensor_handle_interface.h", "tfe_cancellation_manager_internal.h", - "tfe_context_internal.h", "tfe_executor_internal.h", "tfe_monitoring_internal.h", "tfe_op_attrs_internal.h", - "tfe_op_internal.h", "tfe_tensor_debug_info_internal.h", - "tfe_tensorhandle_internal.h", ], visibility = [ "//tensorflow/core:__pkg__", @@ -205,6 +201,7 @@ cc_library( ], deps = [ ":context_interface", + "//tensorflow/c:conversion_macros", ], ) @@ -249,8 +246,6 @@ cc_library( "//tensorflow:internal", ], deps = [ - ":tfe_context_internal", - ":tfe_op_internal", "//tensorflow/c:tf_status", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:attr_builder", @@ -265,6 +260,7 @@ cc_library( ], deps = [ ":operation_interface", + "//tensorflow/c:conversion_macros", ], ) @@ -287,6 +283,7 @@ cc_library( ], deps = [ ":tensor_handle_interface", + "//tensorflow/c:conversion_macros", ], ) @@ -327,6 +324,8 @@ tf_cuda_cc_test( ":c_api_experimental", ":c_api_internal", ":c_api_test_util", + ":tfe_op_internal", + ":tfe_tensorhandle_internal", "//tensorflow/c:c_test_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -351,6 +350,7 @@ tf_cuda_cc_test( ":c_api_experimental", ":c_api_internal", ":c_api_test_util", + ":tfe_tensorhandle_internal", "//tensorflow/c:c_test_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -384,6 +384,9 @@ tf_cuda_library( "//conditions:default": [ ":c_api", ":c_api_internal", + ":tfe_context_internal", + ":tfe_op_internal", + ":tfe_tensorhandle_internal", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", @@ -548,8 +551,9 @@ cc_library( deps = [ ":c_api", ":c_api_experimental", - ":c_api_internal", + ":tfe_tensorhandle_internal", "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_status_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index df950d7fc61..0a64f3c91a5 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -26,7 +26,6 @@ limitations under the License. // clang-format on #include "absl/algorithm/container.h" -#include "absl/container/fixed_array.h" #include "absl/memory/memory.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" @@ -34,6 +33,9 @@ limitations under the License. #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_tensor_internal.h" #ifdef PLATFORM_GOOGLE #include "tensorflow/c/eager/c_api_tfrt.h" @@ -298,7 +300,7 @@ tensorflow::Status CreateRemoteContexts( std::vector filtered_device_mask; tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->FilterDevicesForRemoteWorkers( remote_worker, base_request.cluster_device_attributes(), &filtered_device_mask); @@ -383,7 +385,7 @@ tensorflow::Status UpdateRemoteContexts( std::vector filtered_device_mask; tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->FilterDevicesForRemoteWorkers( remote_worker, base_request.cluster_device_attributes(), &filtered_device_mask); @@ -464,7 +466,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // New server created for new server_def. Unused if updating server_def. std::unique_ptr new_server; tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::GrpcServer* grpc_server; if (reset_context) { LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server)); @@ -684,7 +686,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { if (opts->use_tfrt) { #ifdef PLATFORM_GOOGLE status->status = tensorflow::Status::OK(); - return new TFE_Context{new tfrt::ContextInterface()}; + return tensorflow::wrap(new tfrt::ContextInterface()); #else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); return nullptr; @@ -701,14 +703,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - return new TFE_Context{new tensorflow::EagerContext( + return tensorflow::wrap(new tensorflow::EagerContext( opts->session_options.options, static_cast( opts->device_placement_policy), static_cast(opts->mirroring_policy), opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), /*device_mgr_owned*/ true, r, - tensorflow::GetDefaultCustomKernelCreator())}; + tensorflow::GetDefaultCustomKernelCreator())); } TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, @@ -719,14 +721,14 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr); - return new TFE_Context{new tensorflow::EagerContext( + return tensorflow::wrap(new tensorflow::EagerContext( opts->session_options.options, static_cast( opts->device_placement_policy), static_cast(opts->mirroring_policy), opts->async, opts->lazy_remote_inputs_copy, device_mgr, /*device_mgr_owned*/ false, r, - tensorflow::GetDefaultCustomKernelCreator())}; + tensorflow::GetDefaultCustomKernelCreator())); } void TFE_DeleteContext(TFE_Context* ctx) { @@ -734,22 +736,19 @@ void TFE_DeleteContext(TFE_Context* ctx) { return; } - // context->RefCountIsOne() should be true here. - // TODO(iga): Remove EagerContext refcounting. - ctx->context->Release(); - - delete ctx; + // ctx->RefCountIsOne() should be true here. + tensorflow::unwrap(ctx)->Release(); } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* l = new TF_DeviceList; - ctx->context->ListDevices(&l->response); + tensorflow::unwrap(ctx)->ListDevices(&l->response); return l; } void TFE_ContextClearCaches(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->ClearCachesAndThreadExecutors(); } @@ -772,7 +771,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, if (server_def.has_cluster_device_filters()) { const auto& cdf = server_def.cluster_device_filters(); for (const auto& jdf : cdf.jobs()) { - const string& remote_prefix = "/job:" + jdf.name() + "/task:"; + const string remote_prefix = "/job:" + jdf.name() + "/task:"; for (const auto& tdf : jdf.tasks()) { const int32_t task_index = tdf.first; std::vector device_filters(tdf.second.device_filters_size()); @@ -781,7 +780,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, } const string remote_worker = remote_prefix + std::to_string(task_index); tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->SetRemoteDeviceFilters(remote_worker, device_filters); } @@ -803,7 +802,7 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx, #else // !defined(IS_MOBILE_PLATFORM) tensorflow::ServerDef server_def; tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); if (!server_def.ParseFromArray(proto, proto_len)) { status->status = tensorflow::errors::InvalidArgument( "Invalid tensorflow.ServerDef protocol buffer"); @@ -833,7 +832,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, return false; #else // !defined(IS_MOBILE_PLATFORM) tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::GrpcServer* grpc_server = static_cast(context->GetServer()); @@ -889,7 +888,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, status->status = tensorflow::Status::OK(); #else // !defined(IS_MOBILE_PLATFORM) tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->SyncExecutors(); #endif // !IS_MOBILE_PLATFORM } @@ -897,7 +896,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetThreadLocalDevicePlacementPolicy( static_cast(policy)); } @@ -908,7 +907,7 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return static_cast( context->GetDevicePlacementPolicy()); } @@ -918,8 +917,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { status->status = tensorflow::TF_TensorToTensor(t, &tensor); if (!status->status.ok()) return nullptr; - return new TFE_TensorHandle{ - tensorflow::TensorHandle::CreateLocalHandle(tensor)}; + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor)); } void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { @@ -927,84 +925,84 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { tensorflow::profiler::TraceMe activity( "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo); - if (h->handle) { - h->handle->Release(); + if (h) { + tensorflow::unwrap(h)->Release(); } - delete h; } TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { - return static_cast(h->handle->DataType()); + return static_cast(tensorflow::unwrap(h)->DataType()); } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return -1; } int num_dims = -1; - status->status = h->handle->NumDims(&num_dims); + status->status = tensorflow::unwrap(h)->NumDims(&num_dims); return num_dims; } int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return -1; } int64 num_elements = -1; - status->status = h->handle->NumElements(&num_elements); + status->status = tensorflow::unwrap(h)->NumElements(&num_elements); return num_elements; } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return -1; } int64 dim = -1; - status->status = h->handle->Dim(dim_index, &dim); + status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim); return dim; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - return h->handle->DeviceName(&status->status); + return tensorflow::unwrap(h)->DeviceName(&status->status); } const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - return h->handle->BackingDeviceName(&status->status); + return tensorflow::unwrap(h)->BackingDeviceName(&status->status); } TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - return new TFE_TensorHandle{h->handle->Copy()}; + return tensorflow::wrap(tensorflow::unwrap(h)->Copy()); } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - tensorflow::AbstractTensorInterface* t = h->handle->Resolve(&status->status); + tensorflow::AbstractTensorInterface* t = + tensorflow::unwrap(h)->Resolve(&status->status); if (t == nullptr) { return nullptr; } @@ -1013,12 +1011,12 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { } void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } tensorflow::TensorHandle* handle = - tensorflow::TensorHandleFromInterface(h->handle); + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h)); if (VariantDeviceIsCustom(handle->device())) { const tensorflow::Tensor* t; status->status = handle->Tensor(&t); @@ -1054,7 +1052,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( void* deallocator_arg, TF_Status* status) { tensorflow::Device* device = nullptr; tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->FindDeviceFromName(device_name, &device); tensorflow::CustomDevice* custom_device = nullptr; if (!status->status.ok()) { @@ -1080,11 +1078,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( tensorflow::TensorShape(dimvec), buf); buf->Unref(); if (custom_device == nullptr) { - return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), device, device, context)}; + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( + std::move(t), device, device, context)); } else { - return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), custom_device, context)}; + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( + std::move(t), custom_device, context)); } } @@ -1093,12 +1091,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( // bytes of the memory pointed to by the device pointer returned above. size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return 0; } tensorflow::TensorHandle* handle = - tensorflow::TensorHandleFromInterface(h->handle); + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h)); if (handle->Type() != tensorflow::TensorHandle::LOCAL) { status->status = tensorflow::errors::InvalidArgument( "TFE_TensorHandleDeviceMemorySize may not be called on a ", @@ -1115,12 +1113,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { - std::unique_ptr new_op(new TFE_Op{ctx->context->CreateOperation()}); - status->status = new_op->operation->Reset(op_or_function_name, nullptr); + tensorflow::AbstractOperationInterface* new_op = + tensorflow::unwrap(ctx)->CreateOperation(); + status->status = new_op->Reset(op_or_function_name, nullptr); if (!status->status.ok()) { - new_op.reset(); + new_op->Release(); + new_op = nullptr; } - return new_op.release(); + return tensorflow::wrap(new_op); } void TFE_DeleteOp(TFE_Op* op) { @@ -1128,24 +1128,20 @@ void TFE_DeleteOp(TFE_Op* op) { return; } - if (op->operation) { - op->operation->Release(); - } - - delete op; + tensorflow::unwrap(op)->Release(); } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { - status->status = op->operation->SetDeviceName(device_name); + status->status = tensorflow::unwrap(op)->SetDeviceName(device_name); } const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { - return op->operation->DeviceName().c_str(); + return tensorflow::unwrap(op)->DeviceName().c_str(); } void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { #ifdef TENSORFLOW_EAGER_USE_XLA - tensorflow::Status s = op->operation->SetUseXla(enable); + tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable); if (!s.ok()) { LOG(ERROR) << "Could not enable XLA compilation for op: " << s; } @@ -1156,18 +1152,13 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { - status->status = op->operation->AddInput(input->handle); + status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input)); } void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status) { - absl::FixedArray handles( - num_inputs); - for (int i = 0; i < num_inputs; ++i) { - handles[i] = inputs[i]->handle; - } - status->status = - op->operation->AddInputList({handles.data(), handles.size()}); + status->status = tensorflow::unwrap(op)->AddInputList( + {tensorflow::unwrap(inputs), static_cast(num_inputs)}); } TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, @@ -1175,8 +1166,8 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, TF_AttrType ret = TF_ATTR_INT; const tensorflow::AttrTypeMap* attr_types_; bool is_function; - status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(), - &attr_types_, &is_function); + status->status = tensorflow::AttrTypeMapForOp( + tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function); if (!status->status.ok()) { return ret; } @@ -1202,7 +1193,7 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, size_t length) { - auto s = op->operation->SetAttrString( + auto s = tensorflow::unwrap(op)->SetAttrString( attr_name, static_cast(value), length); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; @@ -1210,29 +1201,30 @@ void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, } void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { - auto s = op->operation->SetAttrInt(attr_name, value); + auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } } void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) { - auto s = op->operation->SetAttrFloat(attr_name, value); + auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } } void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) { - auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true); + auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name, + (value == 0) ? false : true); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } } void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { - auto s = op->operation->SetAttrType(attr_name, - static_cast(value)); + auto s = tensorflow::unwrap(op)->SetAttrType( + attr_name, static_cast(value)); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1240,12 +1232,14 @@ void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, const int num_dims, TF_Status* out_status) { - out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims); + out_status->status = + tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims); } void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value) { - auto s = op->operation->SetAttrFunction(attr_name, value->operation); + auto s = tensorflow::unwrap(op)->SetAttrFunction( + attr_name, tensorflow::unwrap(const_cast(value))); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1253,7 +1247,7 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, const char* data, size_t length) { - auto s = op->operation->SetAttrFunctionName(attr_name, data, length); + auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1264,14 +1258,14 @@ void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, tensorflow::Tensor t; status->status = TF_TensorToTensor(tensor, &t); tensorflow::TensorInterface interface(t); - status->status = op->operation->SetAttrTensor(attr_name, &interface); + status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface); } void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { - auto s = - op->operation->SetAttrStringList(attr_name, values, lengths, num_values); + auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths, + num_values); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1279,7 +1273,8 @@ void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, const float* values, int num_values) { - auto s = op->operation->SetAttrFloatList(attr_name, values, num_values); + auto s = + tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1287,7 +1282,8 @@ void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, const int64_t* values, int num_values) { - auto s = op->operation->SetAttrIntList(attr_name, values, num_values); + auto s = + tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1295,7 +1291,7 @@ void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, const TF_DataType* values, int num_values) { - auto s = op->operation->SetAttrTypeList( + auto s = tensorflow::unwrap(op)->SetAttrTypeList( attr_name, reinterpret_cast(values), num_values); if (!s.ok()) { @@ -1305,7 +1301,8 @@ void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, const unsigned char* values, int num_values) { - auto s = op->operation->SetAttrBoolList(attr_name, values, num_values); + auto s = + tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1314,19 +1311,14 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, const int64_t** dims, const int* num_dims, int num_values, TF_Status* out_status) { - out_status->status = - op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values); + out_status->status = tensorflow::unwrap(op)->SetAttrShapeList( + attr_name, dims, num_dims, num_values); } void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, const TFE_Op** value, int num_values) { - absl::FixedArray values( - num_values); - for (int i = 0; i < num_values; ++i) { - values[i] = value[i]->operation; - } - auto s = op->operation->SetAttrFunctionList(attr_name, - {values.data(), values.size()}); + auto s = tensorflow::unwrap(op)->SetAttrFunctionList( + attr_name, {tensorflow::unwrap(value), static_cast(num_values)}); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1341,12 +1333,13 @@ void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name, tensorflow::errors::InvalidArgument("Unparseable AttrValue proto"); return; } - if (op == nullptr || op->operation == nullptr) { + if (op == nullptr) { status->status = tensorflow::errors::InvalidArgument( "Got a null or uninitialized `op` argument"); return; } - tensorflow::EagerOperation* operation = OperationFromInterface(op->operation); + tensorflow::EagerOperation* operation = + OperationFromInterface(tensorflow::unwrap(const_cast(op))); operation->MutableAttrs()->Set(attr_name, attr_value); } @@ -1354,7 +1347,7 @@ TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op, const char* input_name, TF_Status* status) { int ret = -1; - status->status = op->operation->InputLength(input_name, &ret); + status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret); return ret; } @@ -1362,36 +1355,29 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, const char* output_name, TF_Status* status) { int ret = -1; - status->status = op->operation->OutputLength(output_name, &ret); + status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret); return ret; } void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { - absl::FixedArray handles( - *num_retvals); - status->status = op->operation->Execute(absl::MakeSpan(handles), num_retvals); - if (!status->status.ok()) { - return; - } - for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = new TFE_TensorHandle{handles[i]}; - } + status->status = tensorflow::unwrap(op)->Execute( + absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals); } TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - auto* result = ctx->context->CopyTensorHandleToDevice(h->handle, device_name, - &status->status); + auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice( + tensorflow::unwrap(h), device_name, &status->status); if (status->status.ok()) { - return new TFE_TensorHandle{result}; + return tensorflow::wrap(result); } return nullptr; } @@ -1406,39 +1392,39 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, return; } tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->AddFunctionDef(function->fdef); } void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, TF_Status* status) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->RemoveFunction(name); } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return context->FindFunctionDef(name) != nullptr; } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetShouldStoreGraphs(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetShouldStoreGraphs(false); } @@ -1446,13 +1432,13 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, TF_Status* status) { - return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(t)}; + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t)); } void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->Executor().WaitForAllPendingNodes(); if (!status->status.ok()) return; tensorflow::mutex_lock ml(*context->MetadataMu()); @@ -1475,20 +1461,21 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, void TFE_ContextStartStep(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->StartStep(); } void TFE_ContextEndStep(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->EndStep(); } void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { tensorflow::AttrValueMap m; attrs->attributes->FillAttrValueMap(&m); - tensorflow::EagerOperation* operation = OperationFromInterface(op->operation); + tensorflow::EagerOperation* operation = + OperationFromInterface(tensorflow::unwrap(op)); tensorflow::AttrBuilder* destination = operation->MutableAttrs(); for (const auto& attribute : m) { destination->Set(attribute.first, attribute.second); @@ -1576,33 +1563,34 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { const string& name() override { return name_; } tensorflow::Status CopyTensorToDevice( - tensorflow::TensorHandle* tensor, + tensorflow::TensorHandle* handle, tensorflow::TensorHandle** result) override { - tensor->Ref(); - TFE_TensorHandle tensor_handle{tensor}; + handle->Ref(); TF_Status status; - TFE_TensorHandle* result_handle = - device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_); - tensor_handle.handle->Release(); + TFE_TensorHandle* result_handle = device_.copy_tensor_to_device( + context_, tensorflow::wrap(handle), &status, info_); + handle->Release(); if (!status.status.ok()) return status.status; - *result = tensorflow::TensorHandleFromInterface(result_handle->handle); + *result = tensorflow::TensorHandleFromInterface( + tensorflow::unwrap(result_handle)); (*result)->Ref(); TFE_DeleteTensorHandle(result_handle); return status.status; } tensorflow::Status CopyTensorFromDevice( - tensorflow::TensorHandle* tensor, + tensorflow::TensorHandle* handle, const tensorflow::string& target_device_name, tensorflow::TensorHandle** result) override { TF_Status status; - tensor->Ref(); - TFE_TensorHandle tensor_handle{tensor}; + handle->Ref(); TFE_TensorHandle* result_handle = device_.copy_tensor_from_device( - context_, &tensor_handle, target_device_name.c_str(), &status, info_); - tensor_handle.handle->Release(); + context_, tensorflow::wrap(handle), target_device_name.c_str(), &status, + info_); + handle->Release(); if (!status.status.ok()) return status.status; - *result = tensorflow::TensorHandleFromInterface(result_handle->handle); + *result = tensorflow::TensorHandleFromInterface( + tensorflow::unwrap(result_handle)); (*result)->Ref(); TFE_DeleteTensorHandle(result_handle); return status.status; @@ -1615,7 +1603,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { inputs.reserve(op->Inputs().size()); for (int i = 0; i < op->Inputs().size(); ++i) { op->Inputs()[i]->Ref(); - inputs.push_back(new TFE_TensorHandle{op->Inputs()[i]}); + inputs.push_back(tensorflow::wrap(op->Inputs()[i])); } std::vector outputs(*num_retvals); TF_Status status; @@ -1624,7 +1612,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { &attributes, num_retvals, outputs.data(), &status, info_); if (status.status.ok()) { for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle); + retvals[i] = tensorflow::TensorHandleFromInterface( + tensorflow::unwrap(outputs[i])); retvals[i]->Ref(); TFE_DeleteTensorHandle(outputs[i]); } @@ -1652,7 +1641,7 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, auto custom_device = std::make_unique(ctx, device, device_info, device_name); tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->RegisterCustomDevice(device_name, std::move(custom_device)); } diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index d3d1126fd9a..6827021455b 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -57,7 +57,8 @@ extern "C" { TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( TFE_TensorHandle* h, TF_Status* status) { - tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle); + tensorflow::TensorHandle* handle = + TensorHandleFromInterface(tensorflow::unwrap(h)); const tensorflow::Tensor* tensor; status->status = handle->Tensor(&tensor); if (!status->status.ok()) { diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index b43af710c04..820650e315f 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -19,6 +19,9 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" @@ -34,9 +37,10 @@ using tensorflow::string; void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, const char* raw_device_name, TF_Status* status) { if (op_to_reset) { - op_to_reset->operation->Clear(); - status->status = - op_to_reset->operation->Reset(op_or_function_name, raw_device_name); + tensorflow::AbstractOperationInterface* op = + tensorflow::unwrap(op_to_reset); + op->Clear(); + status->status = op->Reset(op_or_function_name, raw_device_name); } else { TF_SetStatus(status, TF_INVALID_ARGUMENT, "op_to_reset should not be nullptr"); @@ -45,13 +49,13 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetShouldStoreGraphs(true); } void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetShouldStoreGraphs(false); } @@ -483,7 +487,7 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options, void TFE_ContextSetThreadLocalMirroringPolicy( TFE_Context* ctx, TFE_ContextMirroringPolicy policy) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetThreadLocalMirroringPolicy( static_cast(policy)); } @@ -494,7 +498,7 @@ void TFE_ContextSetThreadLocalMirroringPolicy( extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return static_cast(context->GetMirroringPolicy()); } @@ -530,7 +534,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op, TFE_CancellationManager* cancellation_manager, TF_Status* status) { tensorflow::EagerOperation* operation = - tensorflow::OperationFromInterface(op->operation); + tensorflow::OperationFromInterface(tensorflow::unwrap(op)); operation->SetCancellationManager( &cancellation_manager->cancellation_manager); status->status = tensorflow::Status::OK(); @@ -557,19 +561,19 @@ void TFE_ExecutorClearError(TFE_Executor* executor) { void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetExecutorForThread(executor->executor()); } TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return new TFE_Executor(&context->Executor()); } void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); auto address_space = tensorflow::DeviceNameUtils::AddressSpace( context->HostCPU()->parsed_name()); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); @@ -585,7 +589,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, TF_Buffer* buf, TF_Status* status) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); auto* function_def = context->FindFunctionDef(function_name); if (function_def == nullptr) { status->status = tensorflow::errors::NotFound( @@ -611,13 +615,14 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype, dimvec[i] = static_cast(dims[i]); } - if (ctx == nullptr || ctx->context == nullptr) { + if (ctx == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid Context"); return nullptr; } - tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor( - static_cast(dtype), dimvec); + tensorflow::AbstractTensorInterface* t = + tensorflow::unwrap(ctx)->CreateTensor( + static_cast(dtype), dimvec); if (t == nullptr) { status->status = @@ -630,5 +635,6 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype, TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t, TF_Status* status) { - return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)}; + return tensorflow::wrap( + tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor)); } diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 39b767d53c0..4d9be0c2501 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -19,13 +19,10 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export -#include "tensorflow/c/eager/tfe_context_internal.h" // IWYU pragma: export #include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export #include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export #include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export -#include "tensorflow/c/eager/tfe_op_internal.h" // IWYU pragma: export #include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export -#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" // IWYU pragma: export // TODO(b/154564140): Move this to its own header. This requires splitting // c_api_experimental.h diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index 4f1ef7847d5..7c6836af69b 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/platform/casts.h" @@ -233,7 +234,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) { ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU")); TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle); + auto remote_arg = + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); // The input handles should never change since they have been mirrored. ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr)); } @@ -245,7 +247,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) { // TODO(gjn): Add support for waiting on async local mirrors if (!remote && !async) { - auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle); + auto remote_arg = + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); // The input handles should never change since they have been mirrored. ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr)); } diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 8abbe1382ad..0e4183dad16 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/function.pb.h" @@ -416,8 +418,10 @@ void TensorHandleSilentCopy(bool async, hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle); - auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle); + auto cpu_arg = + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu)); + auto gpu_arg = + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu)); auto gpu_device = absl::get(gpu_arg->device()); ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device)); @@ -1346,7 +1350,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) { tensorflow::AttrValueMap attr_values; tensorflow::EagerOperation* operation = - tensorflow::OperationFromInterface(op->operation); + tensorflow::OperationFromInterface(tensorflow::unwrap(op)); operation->Attrs().FillAttrValueMap(&attr_values); return attr_values; } @@ -1482,10 +1486,10 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) { TFE_TensorHandle* inputs[] = {input1, input2}; TFE_OpAddInput(concatOp, dim, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - CHECK(concatOp->operation->OpDef()); + CHECK(tensorflow::unwrap(concatOp)->OpDef()); TFE_OpAddInput(concatOp, inputs[0], status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_FALSE(concatOp->operation->OpDef()) + EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef()) << "Inference context is still present"; TFE_OpAddInput(concatOp, inputs[1], status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -1590,7 +1594,7 @@ TEST(CAPI, TestTFE_OpAddAttrs) { // There is currently no API to fetch attributes from an operation, fetching // happens only as an implementation detail of custom devices. tensorflow::EagerOperation* operation = - OperationFromInterface(var_op->operation); + OperationFromInterface(tensorflow::unwrap(var_op)); TFE_OpAttrs attributes{&operation->Attrs()}; TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status); @@ -1606,7 +1610,7 @@ TEST(CAPI, TestTFE_OpAddAttrs) { tensorflow::AttrValueMap attr_values; tensorflow::EagerOperation* op = - tensorflow::OperationFromInterface(copy_op->operation); + tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op)); op->Attrs().FillAttrValueMap(&attr_values); EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type()); @@ -1630,7 +1634,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) { // There is currently no API to fetch attributes from an operation, fetching // happens only as an implementation detail of custom devices. tensorflow::EagerOperation* operation = - OperationFromInterface(var_op->operation); + OperationFromInterface(tensorflow::unwrap(var_op)); TFE_OpAttrs attributes{&operation->Attrs()}; TF_Buffer* serialized_attr_values = TF_NewBuffer(); @@ -1657,7 +1661,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) { tensorflow::AttrValueMap attr_values; tensorflow::EagerOperation* op = - tensorflow::OperationFromInterface(var_op_2->operation); + tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2)); op->Attrs().FillAttrValueMap(&attr_values); EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type()); diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 0ac38cee5ff..a0d6fe914c2 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/c/eager/dlpack.h" #include "include/dlpack/dlpack.h" // from @dlpack -#include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_reference.h" @@ -41,12 +43,12 @@ struct TfDlManagedTensorCtx { // Gets tensor from eager tensor handle. const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } tensorflow::TensorHandle* handle = - tensorflow::TensorHandleFromInterface(h->handle); + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h)); if (handle->Type() != TensorHandle::LOCAL) { status->status = tensorflow::errors::InvalidArgument( "DLPack doesn't support ", handle->TypeString(), " tensor"); @@ -107,7 +109,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) { // Gets DLPack's DLContext from eager tensor handle. DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) { DLContext ctx; - const char* device_name = h->handle->DeviceName(&status->status); + const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status); DeviceNameUtils::ParsedName parsed_name; tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name); std::string device_type = parsed_name.type; diff --git a/tensorflow/c/eager/tfe_context_internal.h b/tensorflow/c/eager/tfe_context_internal.h index 4c2e650c879..1d29bee9ee3 100644 --- a/tensorflow/c/eager/tfe_context_internal.h +++ b/tensorflow/c/eager/tfe_context_internal.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/eager/context_interface.h" // Wraps a pointer to a context implementation. @@ -23,8 +24,12 @@ limitations under the License. // interface cannot destruct the underlying context object. Instead, call // TFE_DeleteContext who calls Release() on the context pointer and deletes // the TFE_Context structure. -struct TFE_Context { - tensorflow::AbstractContextInterface* context; -}; +typedef struct TFE_Context TFE_Context; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context); + +} // namespace tensorflow #endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ diff --git a/tensorflow/c/eager/tfe_op_attrs_internal.h b/tensorflow/c/eager/tfe_op_attrs_internal.h index 1d745c0ce4f..935d7d520e5 100644 --- a/tensorflow/c/eager/tfe_op_attrs_internal.h +++ b/tensorflow/c/eager/tfe_op_attrs_internal.h @@ -23,14 +23,15 @@ limitations under the License. #include #include -#include "tensorflow/c/eager/tfe_context_internal.h" -#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/framework/attr_value.pb.h" // An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways // that sometimes do not require serialization. +typedef struct TFE_Context TFE_Context; +typedef struct TFE_Op TFE_Op; + struct TFE_OpAttrs { explicit TFE_OpAttrs() : attributes(nullptr) {} diff --git a/tensorflow/c/eager/tfe_op_internal.h b/tensorflow/c/eager/tfe_op_internal.h index b9292e2e27b..6ca7f741d16 100644 --- a/tensorflow/c/eager/tfe_op_internal.h +++ b/tensorflow/c/eager/tfe_op_internal.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/eager/operation_interface.h" // Wraps a pointer to an operation implementation. @@ -23,8 +24,13 @@ limitations under the License. // interface cannot destruct the underlying operation object. Instead, call // TFE_DeleteOp who calls Release() on the operation pointer and deletes // the TFE_Op structure. -struct TFE_Op { - tensorflow::AbstractOperationInterface* operation; -}; +typedef struct TFE_Op TFE_Op; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*); + +} // namespace tensorflow #endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ diff --git a/tensorflow/c/eager/tfe_tensorhandle_internal.h b/tensorflow/c/eager/tfe_tensorhandle_internal.h index 39843c0b959..543e5f1d932 100644 --- a/tensorflow/c/eager/tfe_tensorhandle_internal.h +++ b/tensorflow/c/eager/tfe_tensorhandle_internal.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/eager/tensor_handle_interface.h" // Wraps a pointer to a tensor handle implementation. @@ -23,8 +24,15 @@ limitations under the License. // interface cannot destruct the underlying handle object. Instead, call // TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes // the TFE_TensorHandle structure. -struct TFE_TensorHandle { - tensorflow::AbstractTensorHandleInterface* handle; -}; +typedef struct TFE_TensorHandle TFE_TensorHandle; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface, + TFE_TensorHandle); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*, + TFE_TensorHandle*); + +} // namespace tensorflow #endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 816a7b3110a..7a694f4f803 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -22,13 +22,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -cc_library( - name = "conversion_macros", - hdrs = [ - "conversion_macros.h", - ], -) - cc_library( name = "concrete_function", srcs = [ @@ -51,6 +44,7 @@ cc_library( "//tensorflow/c:c_api_macros", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:tfe_op_internal", "//tensorflow/c/experimental/saved_model/core:concrete_function", "//tensorflow/c/experimental/saved_model/core:function_metadata", ], @@ -93,7 +87,7 @@ cc_library( "concrete_function_type.h", ], deps = [ - ":conversion_macros", + "//tensorflow/c:conversion_macros", "//tensorflow/c/experimental/saved_model/core:concrete_function", ], ) @@ -123,7 +117,7 @@ cc_library( "function_metadata_type.h", ], deps = [ - ":conversion_macros", + "//tensorflow/c:conversion_macros", "//tensorflow/c/experimental/saved_model/core:function_metadata", ], ) diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index da964900420..4884f9e2e97 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" @@ -24,7 +25,8 @@ limitations under the License. extern "C" { TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) { - return tensorflow::wrap(&tensorflow::unwrap(func)->GetFunctionMetadata()); + return tensorflow::wrap(const_cast( + &tensorflow::unwrap(func)->GetFunctionMetadata())); } TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) { @@ -34,7 +36,7 @@ TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) { } TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) { - return new TFE_Op{tensorflow::unwrap(func)->GetCallOp()}; + return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp()); } } // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function_type.h b/tensorflow/c/experimental/saved_model/internal/concrete_function_type.h index 37973373191..bc36b0c6f08 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function_type.h +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function_type.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_ +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" -#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h" // Internal structures used by the SavedModel C API. These are likely to change // and should not be depended on. diff --git a/tensorflow/c/experimental/saved_model/internal/conversion_macros.h b/tensorflow/c/experimental/saved_model/internal/conversion_macros.h deleted file mode 100644 index 73875f02441..00000000000 --- a/tensorflow/c/experimental/saved_model/internal/conversion_macros.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_ -#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_ - -#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \ - inline cpp_impl *unwrap(wrapper *w) { \ - return reinterpret_cast(w); \ - } \ - \ - inline wrapper *wrap(const cpp_impl *i) { \ - return reinterpret_cast(const_cast(i)); \ - } - -#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/function_metadata_type.h b/tensorflow/c/experimental/saved_model/internal/function_metadata_type.h index ab89cf2d7be..40f05f9117d 100644 --- a/tensorflow/c/experimental/saved_model/internal/function_metadata_type.h +++ b/tensorflow/c/experimental/saved_model/internal/function_metadata_type.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_ +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" -#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h" typedef struct TF_FunctionMetadata TF_FunctionMetadata; diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index dfe0b834c57..cce1b27d9ad 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -36,7 +36,8 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx, std::string saved_model_dir(dirname); std::unique_ptr result = - ctx->context->LoadSavedModelAPI(dirname, absl::nullopt, &status->status); + tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt, + &status->status); if (!status->status.ok()) { return nullptr; } @@ -54,8 +55,8 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, } std::unique_ptr result = - ctx->context->LoadSavedModelAPI(dirname, std::move(tagset), - &status->status); + tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset), + &status->status); if (!status->status.ok()) { return nullptr; } diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 3c37c028279..6f985f1c2f4 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -903,7 +903,8 @@ cc_library( ":safe_ptr", "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:tfe_context_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -1007,7 +1008,8 @@ cc_library( ":safe_ptr", "//tensorflow/c:tensor_interface", "//tensorflow/c:tf_tensor_internal", - "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:tfe_context_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/python_runtime:headers", diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 9d5d3916fc5..c08cb8cc1c3 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -40,6 +40,9 @@ cc_library( "//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:dlpack", "//tensorflow/c/eager:tape", + "//tensorflow/c/eager:tfe_context_internal", + "//tensorflow/c/eager:tfe_op_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/python/eager/pywrap_tensor_conversion.cc b/tensorflow/python/eager/pywrap_tensor_conversion.cc index 7b192214f18..041ddf4ec53 100644 --- a/tensorflow/python/eager/pywrap_tensor_conversion.cc +++ b/tensorflow/python/eager/pywrap_tensor_conversion.cc @@ -17,7 +17,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/hash/hash.h" -#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/platform/logging.h" @@ -41,7 +41,7 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup( PyObject* value, tensorflow::DataType dtype, absl::string_view device_name) const { CHECK_NOTNULL(value); - const auto& it = cache.find(Key{PyObjectPtr{value}, dtype, device_name}); + const auto it = cache.find(Key{PyObjectPtr{value}, dtype, device_name}); if (it == cache.end()) { scalar_cache_misses->GetCell()->IncrementBy(1); return nullptr; @@ -49,7 +49,7 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup( scalar_cache_hits->GetCell()->IncrementBy(1); auto* h = it->second; - return new TFE_TensorHandle{h->handle->Copy()}; + return tensorflow::wrap(tensorflow::unwrap(h)->Copy()); } void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype, @@ -57,7 +57,7 @@ void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype, TFE_TensorHandle* h) { Py_INCREF(value); cache.emplace(Key{PyObjectPtr{value}, dtype, device_name}, - new TFE_TensorHandle{h->handle->Copy()}); + tensorflow::wrap(tensorflow::unwrap(h)->Copy())); } void TFE_TensorHandleCache::Clear() { diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index b4a16f415e8..92a0a200e3d 100755 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -379,7 +379,7 @@ void TFE_Py_EnableInteractivePythonLogging(); // Py_None. // // This function is not thread-safe. -PyObject* TFE_Py_SetEagerContext(PyObject* python_context); +PyObject* TFE_Py_SetEagerContext(PyObject* py_context); // Returns the current eager Context object (defined in eager/context.py) // that was last set using TFE_Py_SetEagerContext. diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 3100c1589b1..316f91ec88b 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -24,6 +24,9 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/tape.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -80,9 +83,10 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name, const char* raw_device_name, TF_Status* status) { auto op = ReleaseThreadLocalOp(ctx); if (!op) { - op.reset(new TFE_Op{ctx->context->CreateOperation()}); + op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation())); } - status->status = op->operation->Reset(op_or_function_name, raw_device_name); + status->status = + tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name); if (!status->status.ok()) { op.reset(); } @@ -91,7 +95,7 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name, void ReturnOp(TFE_Context* ctx, TFE_Op* op) { if (op) { - op->operation->Clear(); + tensorflow::unwrap(op)->Clear(); thread_local_eager_operation_map[ctx].reset(op); } } @@ -1500,7 +1504,7 @@ static PyTypeObject TFE_Py_Tape_Type = { sizeof(TFE_Py_Tape), /* tp_basicsize */ 0, /* tp_itemsize */ &TFE_Py_Tape_Delete, /* tp_dealloc */ - 0, /* tp_print */ + nullptr, /* tp_print */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ nullptr, /* tp_reserved */ @@ -1538,7 +1542,7 @@ static PyTypeObject TFE_Py_ForwardAccumulator_Type = { sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */ 0, /* tp_itemsize */ &TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */ - 0, /* tp_print */ + nullptr, /* tp_print */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ nullptr, /* tp_reserved */ @@ -1573,7 +1577,7 @@ static PyTypeObject TFE_Py_VariableWatcher_Type = { sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */ 0, /* tp_itemsize */ &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */ - 0, /* tp_print */ + nullptr, /* tp_print */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ nullptr, /* tp_reserved */ @@ -1990,21 +1994,22 @@ bool ListContainsNone(PyObject* list) { static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { - TFE_TensorHandle* t = EagerTensor_Handle(tensor); + tensorflow::AbstractTensorHandleInterface* handle = + tensorflow::unwrap(EagerTensor_Handle(tensor)); tensorflow::int64 id = PyEagerTensor_ID(tensor); tensorflow::DataType dtype = - static_cast(t->handle->DataType()); + static_cast(handle->DataType()); if (dtype == tensorflow::DT_VARIANT) { return PyTapeTensor(id, dtype, tensor); } tensorflow::TensorShape tensor_shape; int num_dims; - tensorflow::Status status = t->handle->NumDims(&num_dims); + tensorflow::Status status = handle->NumDims(&num_dims); if (status.ok()) { for (int i = 0; i < num_dims; ++i) { tensorflow::int64 dim_size; - status = t->handle->Dim(i, &dim_size); + status = handle->Dim(i, &dim_size); if (!status.ok()) break; tensor_shape.AddDim(dim_size); } @@ -3511,7 +3516,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) { return nullptr; } - const tensorflow::OpDef* op_def = op->operation->OpDef(); + const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef(); if (op_def == nullptr) return nullptr; if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) { @@ -3850,14 +3855,15 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, bool include_tensor_ranks_only, EncodeResult* result) { if (EagerTensor_CheckExact(arg)) { - TFE_TensorHandle* t = EagerTensor_Handle(arg); + tensorflow::AbstractTensorHandleInterface* handle = + tensorflow::unwrap(EagerTensor_Handle(arg)); absl::StrAppend(&result->str, kDType, - static_cast(t->handle->DataType())); + static_cast(handle->DataType())); absl::StrAppend(&result->str, kShape); int num_dims; - tensorflow::Status status = t->handle->NumDims(&num_dims); + tensorflow::Status status = handle->NumDims(&num_dims); if (!status.ok()) return status; if (include_tensor_ranks_only) { @@ -3865,7 +3871,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, } else { for (int i = 0; i < num_dims; ++i) { tensorflow::int64 dim_size; - status = t->handle->Dim(i, &dim_size); + status = handle->Dim(i, &dim_size); if (!status.ok()) return status; absl::StrAppend(&result->str, dim_size, kShapeDelim); } diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 80b9863ea48..a3c83bb5d59 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -26,7 +26,8 @@ limitations under the License. #include "numpy/arrayobject.h" #include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" @@ -95,8 +96,8 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) { if (call->eager) { Tensor t = call->ins[i]; arg = EagerTensorFromHandle( - new TFE_TensorHandle{TensorHandle::CreateLocalHandle( - std::move(t), ctx->CanonicalDevice(device), nullptr, ctx)}); + tensorflow::wrap(TensorHandle::CreateLocalHandle( + std::move(t), ctx->CanonicalDevice(device), nullptr, ctx))); if (arg == nullptr) { Py_DECREF(lst); return errors::Internal("Unable to procure EagerTensor from Tensor."); @@ -146,7 +147,7 @@ tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, const Device* expected_device, const Tensor** output_tensor) { tensorflow::TensorHandle* handle = tensorflow::TensorHandleFromInterface( - EagerTensor_Handle(eager_tensor)->handle); + tensorflow::unwrap(EagerTensor_Handle(eager_tensor))); if (VariantDeviceIsCustom(handle->device())) { return errors::Unimplemented( "Custom devices are currently not supported with PyFuncs."); @@ -196,7 +197,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { TFE_Context* ctx = reinterpret_cast(PyCapsule_GetPointer( PyObject_GetAttrString(trampoline, "_ctx"), nullptr)); CHECK_NE(ctx, nullptr); - EagerContext* context = ContextFromInterface(ctx->context); + EagerContext* context = ContextFromInterface(tensorflow::unwrap(ctx)); TF_RETURN_IF_ERROR(MakeArgTuple(call, context, &args)); new_executor.reset(new EagerExecutor(call->eager_async)); old_executor = &context->Executor(); @@ -236,7 +237,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { if (new_executor != nullptr) { TFE_Context* ctx = reinterpret_cast(PyCapsule_GetPointer( PyObject_GetAttrString(trampoline, "_ctx"), nullptr)); - EagerContext* context = ContextFromInterface(ctx->context); + EagerContext* context = ContextFromInterface(tensorflow::unwrap(ctx)); s.Update(new_executor->WaitForAllPendingNodes()); context->SetExecutorForThread(old_executor); } diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index f05afeb22e5..ecf4a92f0e7 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/python/lib/core/py_seq_tensor.h" -#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tensor_interface.h" #include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/framework/tensor.h" @@ -305,7 +306,7 @@ struct Converter { } } } - *h = new TFE_TensorHandle{ctx->context->CreateLocalHandle(t)}; + *h = tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(t)); t->Release(); return Status::OK(); } @@ -316,12 +317,12 @@ struct Converter { template <> struct ConverterTraits { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, int64 value) { - return ctx->context->CreateInt64Scalar(value); + return tensorflow::unwrap(ctx)->CreateInt64Scalar(value); } static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateTensor(DT_INT64, dim_sizes); + return tensorflow::unwrap(ctx)->CreateTensor(DT_INT64, dim_sizes); } static const char* ConvertScalar(PyObject* v, int64* out) { @@ -356,12 +357,12 @@ typedef Converter Int64Converter; template <> struct ConverterTraits { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, uint64 value) { - return ctx->context->CreateUint64Scalar(value); + return tensorflow::unwrap(ctx)->CreateUint64Scalar(value); } static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateTensor(DT_UINT64, dim_sizes); + return tensorflow::unwrap(ctx)->CreateTensor(DT_UINT64, dim_sizes); } static const char* ConvertScalar(PyObject* v, uint64* out) { @@ -393,12 +394,12 @@ typedef Converter UInt64Converter; template <> struct ConverterTraits { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, int32 value) { - return ctx->context->CreateInt32Scalar(value); + return tensorflow::unwrap(ctx)->CreateInt32Scalar(value); } static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateTensor(DT_INT32, dim_sizes); + return tensorflow::unwrap(ctx)->CreateTensor(DT_INT32, dim_sizes); } static const char* ConvertScalar(PyObject* v, int32* out) { @@ -500,12 +501,12 @@ static const char* ConvertOneFloat(PyObject* v, T* out) { template <> struct ConverterTraits { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, float value) { - return ctx->context->CreateFloatScalar(value); + return tensorflow::unwrap(ctx)->CreateFloatScalar(value); } static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateTensor(DT_FLOAT, dim_sizes); + return tensorflow::unwrap(ctx)->CreateTensor(DT_FLOAT, dim_sizes); } static const char* ConvertScalar(PyObject* v, float* out) { @@ -516,12 +517,12 @@ struct ConverterTraits { template <> struct ConverterTraits { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, double value) { - return ctx->context->CreateDoubleScalar(value); + return tensorflow::unwrap(ctx)->CreateDoubleScalar(value); } static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateTensor(DT_DOUBLE, dim_sizes); + return tensorflow::unwrap(ctx)->CreateTensor(DT_DOUBLE, dim_sizes); } static const char* ConvertScalar(PyObject* v, double* out) { @@ -536,12 +537,12 @@ template <> struct ConverterTraits { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, Eigen::half value) { - return ctx->context->CreateHalfScalar(value); + return tensorflow::unwrap(ctx)->CreateHalfScalar(value); } static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateTensor(DT_HALF, dim_sizes); + return tensorflow::unwrap(ctx)->CreateTensor(DT_HALF, dim_sizes); } static const char* ConvertScalar(PyObject* v, Eigen::half* out) { @@ -557,12 +558,12 @@ template <> struct ConverterTraits { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, tstring value) { - return ctx->context->CreateStringScalar(value); + return tensorflow::unwrap(ctx)->CreateStringScalar(value); } static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateTensor(DT_STRING, dim_sizes); + return tensorflow::unwrap(ctx)->CreateTensor(DT_STRING, dim_sizes); } static const char* ConvertScalar(PyObject* v, tstring* out) { @@ -624,12 +625,12 @@ template <> struct ConverterTraits { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, complex128 value) { - return ctx->context->CreateComplex128Scalar(value); + return tensorflow::unwrap(ctx)->CreateComplex128Scalar(value); } static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateTensor(DT_COMPLEX128, dim_sizes); + return tensorflow::unwrap(ctx)->CreateTensor(DT_COMPLEX128, dim_sizes); } static const char* ConvertScalar(PyObject* v, complex128* out) { @@ -652,12 +653,12 @@ typedef Converter Complex128Converter; template <> struct ConverterTraits { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, bool value) { - return ctx->context->CreateBoolScalar(value); + return tensorflow::unwrap(ctx)->CreateBoolScalar(value); } static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateTensor(DT_BOOL, dim_sizes); + return tensorflow::unwrap(ctx)->CreateTensor(DT_BOOL, dim_sizes); } static const char* ConvertScalar(PyObject* v, bool* out) { @@ -692,7 +693,7 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) { } TensorInterface t(std::move(tensor)); - return new TFE_TensorHandle{ctx->context->CreateLocalHandle(&t)}; + return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t)); } } // namespace @@ -877,7 +878,7 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj, Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, TensorShape(state.inferred_shape)); TensorInterface t(std::move(tensor)); - return new TFE_TensorHandle{ctx->context->CreateLocalHandle(&t)}; + return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t)); } default: diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 26683c3d433..ec54efa61cf 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -270,7 +270,6 @@ static py::object TFE_ClearScalarCache() { // are only assigning this to functions that return opaque types. PYBIND11_MODULE(_pywrap_tfe, m) { - py::class_ TFE_Context_class(m, "TFE_Context"); py::class_ TFE_Executor_class(m, "TFE_Executor"); py::class_ TFE_ContextOptions_class(m, "TFE_ContextOptions"); @@ -760,7 +759,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) { m.def("TFE_ContextStartStep", [](py::handle& o) { TFE_ContextStartStep(tensorflow::InputTFE_Context(o.ptr())); }); - m.def("TFE_ContextEndStep", &TFE_ContextEndStep); + m.def("TFE_ContextEndStep", [](py::handle& o) { + TFE_ContextEndStep(tensorflow::InputTFE_Context(o.ptr())); + }); m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) { return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr())); });