Avoid pointer indirection in handle, context & op

Since the struct lifetime is bound to the wrapped pointer this is fine.

PiperOrigin-RevId: 308941521
Change-Id: I0604fff4fcba6a03cc4a2242ab9f182fbfbf8bae
This commit is contained in:
Gaurav Jain 2020-04-28 18:54:46 -07:00 committed by TensorFlower Gardener
parent 510f0f9a6c
commit 9ec6997dfb
29 changed files with 352 additions and 301 deletions

View File

@ -334,6 +334,9 @@ tf_cuda_library(
":checkpoint_reader",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/compiler/jit:flags",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@ -729,3 +732,11 @@ tf_cuda_library(
],
alwayslink = 1,
)
cc_library(
name = "conversion_macros",
hdrs = [
"conversion_macros.h",
],
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -21,6 +21,9 @@ limitations under the License.
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -686,8 +689,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
status->status = tensorflow::Status::OK();
return new TFE_TensorHandle{
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
namespace {
@ -708,7 +710,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
// New server created for new server_def. Unused if updating server_def.
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
@ -822,14 +824,13 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items;
NodeDef node_def;
node_def.set_name(tfe_op->operation->Name());
node_def.set_op(tfe_op->operation->Name());
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
node_def.set_name(op->Name());
node_def.set_op(op->Name());
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
OperationFromInterface(tfe_op->operation)
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());
OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());
const tensorflow::OpRegistrationData* op_reg_data;
status->status =

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
#define TENSORFLOW_C_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline const cpp_impl *unwrap(const wrapper *w) { \
return reinterpret_cast<const cpp_impl *>(w); \
} \
\
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); }
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_

View File

@ -50,7 +50,6 @@ tf_cuda_library(
":tfe_tensor_debug_info_internal",
":tfe_tensorhandle_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"//tensorflow/c:c_api",
@ -110,13 +109,10 @@ filegroup(
"operation_interface.h",
"tensor_handle_interface.h",
"tfe_cancellation_manager_internal.h",
"tfe_context_internal.h",
"tfe_executor_internal.h",
"tfe_monitoring_internal.h",
"tfe_op_attrs_internal.h",
"tfe_op_internal.h",
"tfe_tensor_debug_info_internal.h",
"tfe_tensorhandle_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
@ -205,6 +201,7 @@ cc_library(
],
deps = [
":context_interface",
"//tensorflow/c:conversion_macros",
],
)
@ -249,8 +246,6 @@ cc_library(
"//tensorflow:internal",
],
deps = [
":tfe_context_internal",
":tfe_op_internal",
"//tensorflow/c:tf_status",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
@ -265,6 +260,7 @@ cc_library(
],
deps = [
":operation_interface",
"//tensorflow/c:conversion_macros",
],
)
@ -287,6 +283,7 @@ cc_library(
],
deps = [
":tensor_handle_interface",
"//tensorflow/c:conversion_macros",
],
)
@ -327,6 +324,8 @@ tf_cuda_cc_test(
":c_api_experimental",
":c_api_internal",
":c_api_test_util",
":tfe_op_internal",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@ -351,6 +350,7 @@ tf_cuda_cc_test(
":c_api_experimental",
":c_api_internal",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -384,6 +384,9 @@ tf_cuda_library(
"//conditions:default": [
":c_api",
":c_api_internal",
":tfe_context_internal",
":tfe_op_internal",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
@ -548,8 +551,9 @@ cc_library(
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
":tfe_tensorhandle_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",

View File

@ -26,7 +26,6 @@ limitations under the License.
// clang-format on
#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
@ -34,6 +33,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#ifdef PLATFORM_GOOGLE
#include "tensorflow/c/eager/c_api_tfrt.h"
@ -298,7 +300,7 @@ tensorflow::Status CreateRemoteContexts(
std::vector<bool> filtered_device_mask;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
@ -383,7 +385,7 @@ tensorflow::Status UpdateRemoteContexts(
std::vector<bool> filtered_device_mask;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
@ -464,7 +466,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@ -684,7 +686,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
status->status = tensorflow::Status::OK();
return new TFE_Context{new tfrt::ContextInterface()};
return tensorflow::wrap(new tfrt::ContextInterface());
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;
@ -701,14 +703,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context{new tensorflow::EagerContext(
return tensorflow::wrap(new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator())};
tensorflow::GetDefaultCustomKernelCreator()));
}
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@ -719,14 +721,14 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context{new tensorflow::EagerContext(
return tensorflow::wrap(new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator())};
tensorflow::GetDefaultCustomKernelCreator()));
}
void TFE_DeleteContext(TFE_Context* ctx) {
@ -734,22 +736,19 @@ void TFE_DeleteContext(TFE_Context* ctx) {
return;
}
// context->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting.
ctx->context->Release();
delete ctx;
// ctx->RefCountIsOne() should be true here.
tensorflow::unwrap(ctx)->Release();
}
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* l = new TF_DeviceList;
ctx->context->ListDevices(&l->response);
tensorflow::unwrap(ctx)->ListDevices(&l->response);
return l;
}
void TFE_ContextClearCaches(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->ClearCachesAndThreadExecutors();
}
@ -772,7 +771,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
if (server_def.has_cluster_device_filters()) {
const auto& cdf = server_def.cluster_device_filters();
for (const auto& jdf : cdf.jobs()) {
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
const string remote_prefix = "/job:" + jdf.name() + "/task:";
for (const auto& tdf : jdf.tasks()) {
const int32_t task_index = tdf.first;
std::vector<string> device_filters(tdf.second.device_filters_size());
@ -781,7 +780,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
}
const string remote_worker = remote_prefix + std::to_string(task_index);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status =
context->SetRemoteDeviceFilters(remote_worker, device_filters);
}
@ -803,7 +802,7 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::ServerDef server_def;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer");
@ -833,7 +832,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
return false;
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server =
static_cast<tensorflow::GrpcServer*>(context->GetServer());
@ -889,7 +888,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->SyncExecutors();
#endif // !IS_MOBILE_PLATFORM
}
@ -897,7 +896,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
@ -908,7 +907,7 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextDevicePlacementPolicy>(
context->GetDevicePlacementPolicy());
}
@ -918,8 +917,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;
return new TFE_TensorHandle{
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
@ -927,84 +925,84 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
tensorflow::profiler::TraceMe activity(
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
if (h->handle) {
h->handle->Release();
if (h) {
tensorflow::unwrap(h)->Release();
}
delete h;
}
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->handle->DataType());
return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
int num_dims = -1;
status->status = h->handle->NumDims(&num_dims);
status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
return num_dims;
}
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
int64 num_elements = -1;
status->status = h->handle->NumElements(&num_elements);
status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
return num_elements;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
int64 dim = -1;
status->status = h->handle->Dim(dim_index, &dim);
status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
return dim;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return h->handle->DeviceName(&status->status);
return tensorflow::unwrap(h)->DeviceName(&status->status);
}
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return h->handle->BackingDeviceName(&status->status);
return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
}
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return new TFE_TensorHandle{h->handle->Copy()};
return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::AbstractTensorInterface* t = h->handle->Resolve(&status->status);
tensorflow::AbstractTensorInterface* t =
tensorflow::unwrap(h)->Resolve(&status->status);
if (t == nullptr) {
return nullptr;
}
@ -1013,12 +1011,12 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
}
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle);
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (VariantDeviceIsCustom(handle->device())) {
const tensorflow::Tensor* t;
status->status = handle->Tensor(&t);
@ -1054,7 +1052,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device = nullptr;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) {
@ -1080,11 +1078,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
if (custom_device == nullptr) {
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context)};
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context));
} else {
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context)};
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context));
}
}
@ -1093,12 +1091,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return 0;
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle);
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDeviceMemorySize may not be called on a ",
@ -1115,12 +1113,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
std::unique_ptr<TFE_Op> new_op(new TFE_Op{ctx->context->CreateOperation()});
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
tensorflow::AbstractOperationInterface* new_op =
tensorflow::unwrap(ctx)->CreateOperation();
status->status = new_op->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) {
new_op.reset();
new_op->Release();
new_op = nullptr;
}
return new_op.release();
return tensorflow::wrap(new_op);
}
void TFE_DeleteOp(TFE_Op* op) {
@ -1128,24 +1128,20 @@ void TFE_DeleteOp(TFE_Op* op) {
return;
}
if (op->operation) {
op->operation->Release();
}
delete op;
tensorflow::unwrap(op)->Release();
}
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = op->operation->SetDeviceName(device_name);
status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
return op->operation->DeviceName().c_str();
return tensorflow::unwrap(op)->DeviceName().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Status s = op->operation->SetUseXla(enable);
tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
if (!s.ok()) {
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
}
@ -1156,18 +1152,13 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
status->status = op->operation->AddInput(input->handle);
status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
}
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
num_inputs);
for (int i = 0; i < num_inputs; ++i) {
handles[i] = inputs[i]->handle;
}
status->status =
op->operation->AddInputList({handles.data(), handles.size()});
status->status = tensorflow::unwrap(op)->AddInputList(
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@ -1175,8 +1166,8 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
TF_AttrType ret = TF_ATTR_INT;
const tensorflow::AttrTypeMap* attr_types_;
bool is_function;
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
&attr_types_, &is_function);
status->status = tensorflow::AttrTypeMapForOp(
tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
if (!status->status.ok()) {
return ret;
}
@ -1202,7 +1193,7 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
size_t length) {
auto s = op->operation->SetAttrString(
auto s = tensorflow::unwrap(op)->SetAttrString(
attr_name, static_cast<const char*>(value), length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
@ -1210,29 +1201,30 @@ void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
auto s = op->operation->SetAttrInt(attr_name, value);
auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
auto s = op->operation->SetAttrFloat(attr_name, value);
auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
(value == 0) ? false : true);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
auto s = op->operation->SetAttrType(attr_name,
static_cast<tensorflow::DataType>(value));
auto s = tensorflow::unwrap(op)->SetAttrType(
attr_name, static_cast<tensorflow::DataType>(value));
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1240,12 +1232,14 @@ void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
const int num_dims, TF_Status* out_status) {
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
out_status->status =
tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
}
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) {
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
auto s = tensorflow::unwrap(op)->SetAttrFunction(
attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1253,7 +1247,7 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
const char* data, size_t length) {
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1264,14 +1258,14 @@ void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
tensorflow::Tensor t;
status->status = TF_TensorToTensor(tensor, &t);
tensorflow::TensorInterface interface(t);
status->status = op->operation->SetAttrTensor(attr_name, &interface);
status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
}
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values) {
auto s =
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1279,7 +1273,8 @@ void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values) {
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
auto s =
tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1287,7 +1282,8 @@ void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
auto s =
tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1295,7 +1291,7 @@ void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
auto s = op->operation->SetAttrTypeList(
auto s = tensorflow::unwrap(op)->SetAttrTypeList(
attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
num_values);
if (!s.ok()) {
@ -1305,7 +1301,8 @@ void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values) {
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
auto s =
tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1314,19 +1311,14 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status) {
out_status->status =
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
attr_name, dims, num_dims, num_values);
}
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) {
absl::FixedArray<const tensorflow::AbstractOperationInterface*> values(
num_values);
for (int i = 0; i < num_values; ++i) {
values[i] = value[i]->operation;
}
auto s = op->operation->SetAttrFunctionList(attr_name,
{values.data(), values.size()});
auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1341,12 +1333,13 @@ void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
return;
}
if (op == nullptr || op->operation == nullptr) {
if (op == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Got a null or uninitialized `op` argument");
return;
}
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
operation->MutableAttrs()->Set(attr_name, attr_value);
}
@ -1354,7 +1347,7 @@ TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status) {
int ret = -1;
status->status = op->operation->InputLength(input_name, &ret);
status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
return ret;
}
@ -1362,36 +1355,29 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status) {
int ret = -1;
status->status = op->operation->OutputLength(output_name, &ret);
status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
return ret;
}
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
*num_retvals);
status->status = op->operation->Execute(absl::MakeSpan(handles), num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle{handles[i]};
}
status->status = tensorflow::unwrap(op)->Execute(
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
}
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
auto* result = ctx->context->CopyTensorHandleToDevice(h->handle, device_name,
&status->status);
auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice(
tensorflow::unwrap(h), device_name, &status->status);
if (status->status.ok()) {
return new TFE_TensorHandle{result};
return tensorflow::wrap(result);
}
return nullptr;
}
@ -1406,39 +1392,39 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
return;
}
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function->fdef);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->RemoveFunction(name);
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->FindFunctionDef(name) != nullptr;
}
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
}
@ -1446,13 +1432,13 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status) {
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(t)};
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
}
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*context->MetadataMu());
@ -1475,20 +1461,21 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
void TFE_ContextStartStep(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->StartStep();
}
void TFE_ContextEndStep(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->EndStep();
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(op));
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (const auto& attribute : m) {
destination->Set(attribute.first, attribute.second);
@ -1576,33 +1563,34 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
const string& name() override { return name_; }
tensorflow::Status CopyTensorToDevice(
tensorflow::TensorHandle* tensor,
tensorflow::TensorHandle* handle,
tensorflow::TensorHandle** result) override {
tensor->Ref();
TFE_TensorHandle tensor_handle{tensor};
handle->Ref();
TF_Status status;
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
tensor_handle.handle->Release();
TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
context_, tensorflow::wrap(handle), &status, info_);
handle->Release();
if (!status.status.ok()) return status.status;
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
*result = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(result_handle));
(*result)->Ref();
TFE_DeleteTensorHandle(result_handle);
return status.status;
}
tensorflow::Status CopyTensorFromDevice(
tensorflow::TensorHandle* tensor,
tensorflow::TensorHandle* handle,
const tensorflow::string& target_device_name,
tensorflow::TensorHandle** result) override {
TF_Status status;
tensor->Ref();
TFE_TensorHandle tensor_handle{tensor};
handle->Ref();
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
tensor_handle.handle->Release();
context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
info_);
handle->Release();
if (!status.status.ok()) return status.status;
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
*result = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(result_handle));
(*result)->Ref();
TFE_DeleteTensorHandle(result_handle);
return status.status;
@ -1615,7 +1603,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
inputs.reserve(op->Inputs().size());
for (int i = 0; i < op->Inputs().size(); ++i) {
op->Inputs()[i]->Ref();
inputs.push_back(new TFE_TensorHandle{op->Inputs()[i]});
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
@ -1624,7 +1612,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
&attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle);
retvals[i] = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(outputs[i]));
retvals[i]->Ref();
TFE_DeleteTensorHandle(outputs[i]);
}
@ -1652,7 +1641,7 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
auto custom_device =
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status =
context->RegisterCustomDevice(device_name, std::move(custom_device));
}

View File

@ -57,7 +57,8 @@ extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status) {
tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
tensorflow::TensorHandle* handle =
TensorHandleFromInterface(tensorflow::unwrap(h));
const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {

View File

@ -19,6 +19,9 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
@ -34,9 +37,10 @@ using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
op_to_reset->operation->Clear();
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
tensorflow::AbstractOperationInterface* op =
tensorflow::unwrap(op_to_reset);
op->Clear();
status->status = op->Reset(op_or_function_name, raw_device_name);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
@ -45,13 +49,13 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
}
@ -483,7 +487,7 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalMirroringPolicy(
static_cast<tensorflow::ContextMirroringPolicy>(policy));
}
@ -494,7 +498,7 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
}
@ -530,7 +534,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(op->operation);
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
operation->SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = tensorflow::Status::OK();
@ -557,19 +561,19 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetExecutorForThread(executor->executor());
}
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return new TFE_Executor(&context->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
context->HostCPU()->parsed_name());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
@ -585,7 +589,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto* function_def = context->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
@ -611,13 +615,14 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
if (ctx == nullptr || ctx->context == nullptr) {
if (ctx == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid Context");
return nullptr;
}
tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor(
static_cast<tensorflow::DataType>(dtype), dimvec);
tensorflow::AbstractTensorInterface* t =
tensorflow::unwrap(ctx)->CreateTensor(
static_cast<tensorflow::DataType>(dtype), dimvec);
if (t == nullptr) {
status->status =
@ -630,5 +635,6 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
TF_Status* status) {
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)};
return tensorflow::wrap(
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
}

View File

@ -19,13 +19,10 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_context_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_op_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" // IWYU pragma: export
// TODO(b/154564140): Move this to its own header. This requires splitting
// c_api_experimental.h

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
@ -233,7 +234,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
@ -245,7 +247,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async) {
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}

View File

@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/function.pb.h"
@ -416,8 +418,10 @@ void TensorHandleSilentCopy(bool async,
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle);
auto cpu_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
auto gpu_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
@ -1346,7 +1350,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
tensorflow::AttrValueMap attr_values;
tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(op->operation);
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
operation->Attrs().FillAttrValueMap(&attr_values);
return attr_values;
}
@ -1482,10 +1486,10 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->operation->OpDef());
CHECK(tensorflow::unwrap(concatOp)->OpDef());
TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->operation->OpDef())
EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef())
<< "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -1590,7 +1594,7 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
// There is currently no API to fetch attributes from an operation, fetching
// happens only as an implementation detail of custom devices.
tensorflow::EagerOperation* operation =
OperationFromInterface(var_op->operation);
OperationFromInterface(tensorflow::unwrap(var_op));
TFE_OpAttrs attributes{&operation->Attrs()};
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
@ -1606,7 +1610,7 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
tensorflow::AttrValueMap attr_values;
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(copy_op->operation);
tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
@ -1630,7 +1634,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
// There is currently no API to fetch attributes from an operation, fetching
// happens only as an implementation detail of custom devices.
tensorflow::EagerOperation* operation =
OperationFromInterface(var_op->operation);
OperationFromInterface(tensorflow::unwrap(var_op));
TFE_OpAttrs attributes{&operation->Attrs()};
TF_Buffer* serialized_attr_values = TF_NewBuffer();
@ -1657,7 +1661,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
tensorflow::AttrValueMap attr_values;
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(var_op_2->operation);
tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2));
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());

View File

@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
@ -41,12 +43,12 @@ struct TfDlManagedTensorCtx {
// Gets tensor from eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle);
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (handle->Type() != TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support ", handle->TypeString(), " tensor");
@ -107,7 +109,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
// Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx;
const char* device_name = h->handle->DeviceName(&status->status);
const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name;
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
std::string device_type = parsed_name.type;

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/context_interface.h"
// Wraps a pointer to a context implementation.
@ -23,8 +24,12 @@ limitations under the License.
// interface cannot destruct the underlying context object. Instead, call
// TFE_DeleteContext who calls Release() on the context pointer and deletes
// the TFE_Context structure.
struct TFE_Context {
tensorflow::AbstractContextInterface* context;
};
typedef struct TFE_Context TFE_Context;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_

View File

@ -23,14 +23,15 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h"
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization.
typedef struct TFE_Context TFE_Context;
typedef struct TFE_Op TFE_Op;
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : attributes(nullptr) {}

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/operation_interface.h"
// Wraps a pointer to an operation implementation.
@ -23,8 +24,13 @@ limitations under the License.
// interface cannot destruct the underlying operation object. Instead, call
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
// the TFE_Op structure.
struct TFE_Op {
tensorflow::AbstractOperationInterface* operation;
};
typedef struct TFE_Op TFE_Op;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
// Wraps a pointer to a tensor handle implementation.
@ -23,8 +24,15 @@ limitations under the License.
// interface cannot destruct the underlying handle object. Instead, call
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
// the TFE_TensorHandle structure.
struct TFE_TensorHandle {
tensorflow::AbstractTensorHandleInterface* handle;
};
typedef struct TFE_TensorHandle TFE_TensorHandle;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
TFE_TensorHandle);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
TFE_TensorHandle*);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_

View File

@ -22,13 +22,6 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "conversion_macros",
hdrs = [
"conversion_macros.h",
],
)
cc_library(
name = "concrete_function",
srcs = [
@ -51,6 +44,7 @@ cc_library(
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
@ -93,7 +87,7 @@ cc_library(
"concrete_function_type.h",
],
deps = [
":conversion_macros",
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
],
)
@ -123,7 +117,7 @@ cc_library(
"function_metadata_type.h",
],
deps = [
":conversion_macros",
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
],
)

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
@ -24,7 +25,8 @@ limitations under the License.
extern "C" {
TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
return tensorflow::wrap(&tensorflow::unwrap(func)->GetFunctionMetadata());
return tensorflow::wrap(const_cast<tensorflow::FunctionMetadata*>(
&tensorflow::unwrap(func)->GetFunctionMetadata()));
}
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
@ -34,7 +36,7 @@ TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
}
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
return new TFE_Op{tensorflow::unwrap(func)->GetCallOp()};
return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp());
}
} // end extern "C"

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.

View File

@ -1,28 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline wrapper *wrap(const cpp_impl *i) { \
return reinterpret_cast<wrapper *>(const_cast<cpp_impl *>(i)); \
}
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
typedef struct TF_FunctionMetadata TF_FunctionMetadata;

View File

@ -36,7 +36,8 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
std::string saved_model_dir(dirname);
std::unique_ptr<tensorflow::SavedModelAPI> result =
ctx->context->LoadSavedModelAPI(dirname, absl::nullopt, &status->status);
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
&status->status);
if (!status->status.ok()) {
return nullptr;
}
@ -54,8 +55,8 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
ctx->context->LoadSavedModelAPI(dirname, std::move(tagset),
&status->status);
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
&status->status);
if (!status->status.ok()) {
return nullptr;
}

View File

@ -903,7 +903,8 @@ cc_library(
":safe_ptr",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -1007,7 +1008,8 @@ cc_library(
":safe_ptr",
"//tensorflow/c:tensor_interface",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/python_runtime:headers",

View File

@ -40,6 +40,9 @@ cc_library(
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:dlpack",
"//tensorflow/c/eager:tape",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -17,7 +17,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/hash/hash.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/logging.h"
@ -41,7 +41,7 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
PyObject* value, tensorflow::DataType dtype,
absl::string_view device_name) const {
CHECK_NOTNULL(value);
const auto& it = cache.find(Key{PyObjectPtr{value}, dtype, device_name});
const auto it = cache.find(Key{PyObjectPtr{value}, dtype, device_name});
if (it == cache.end()) {
scalar_cache_misses->GetCell()->IncrementBy(1);
return nullptr;
@ -49,7 +49,7 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
scalar_cache_hits->GetCell()->IncrementBy(1);
auto* h = it->second;
return new TFE_TensorHandle{h->handle->Copy()};
return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
}
void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
@ -57,7 +57,7 @@ void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
TFE_TensorHandle* h) {
Py_INCREF(value);
cache.emplace(Key{PyObjectPtr{value}, dtype, device_name},
new TFE_TensorHandle{h->handle->Copy()});
tensorflow::wrap(tensorflow::unwrap(h)->Copy()));
}
void TFE_TensorHandleCache::Clear() {

View File

@ -379,7 +379,7 @@ void TFE_Py_EnableInteractivePythonLogging();
// Py_None.
//
// This function is not thread-safe.
PyObject* TFE_Py_SetEagerContext(PyObject* python_context);
PyObject* TFE_Py_SetEagerContext(PyObject* py_context);
// Returns the current eager Context object (defined in eager/context.py)
// that was last set using TFE_Py_SetEagerContext.

View File

@ -24,6 +24,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@ -80,9 +83,10 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
auto op = ReleaseThreadLocalOp(ctx);
if (!op) {
op.reset(new TFE_Op{ctx->context->CreateOperation()});
op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
}
status->status = op->operation->Reset(op_or_function_name, raw_device_name);
status->status =
tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
if (!status->status.ok()) {
op.reset();
}
@ -91,7 +95,7 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
if (op) {
op->operation->Clear();
tensorflow::unwrap(op)->Clear();
thread_local_eager_operation_map[ctx].reset(op);
}
}
@ -1500,7 +1504,7 @@ static PyTypeObject TFE_Py_Tape_Type = {
sizeof(TFE_Py_Tape), /* tp_basicsize */
0, /* tp_itemsize */
&TFE_Py_Tape_Delete, /* tp_dealloc */
0, /* tp_print */
nullptr, /* tp_print */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
@ -1538,7 +1542,7 @@ static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
0, /* tp_itemsize */
&TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
0, /* tp_print */
nullptr, /* tp_print */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
@ -1573,7 +1577,7 @@ static PyTypeObject TFE_Py_VariableWatcher_Type = {
sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
0, /* tp_itemsize */
&TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
0, /* tp_print */
nullptr, /* tp_print */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
@ -1990,21 +1994,22 @@ bool ListContainsNone(PyObject* list) {
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::AbstractTensorHandleInterface* handle =
tensorflow::unwrap(EagerTensor_Handle(tensor));
tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(t->handle->DataType());
static_cast<tensorflow::DataType>(handle->DataType());
if (dtype == tensorflow::DT_VARIANT) {
return PyTapeTensor(id, dtype, tensor);
}
tensorflow::TensorShape tensor_shape;
int num_dims;
tensorflow::Status status = t->handle->NumDims(&num_dims);
tensorflow::Status status = handle->NumDims(&num_dims);
if (status.ok()) {
for (int i = 0; i < num_dims; ++i) {
tensorflow::int64 dim_size;
status = t->handle->Dim(i, &dim_size);
status = handle->Dim(i, &dim_size);
if (!status.ok()) break;
tensor_shape.AddDim(dim_size);
}
@ -3511,7 +3516,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
return nullptr;
}
const tensorflow::OpDef* op_def = op->operation->OpDef();
const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
if (op_def == nullptr) return nullptr;
if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
@ -3850,14 +3855,15 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
bool include_tensor_ranks_only,
EncodeResult* result) {
if (EagerTensor_CheckExact(arg)) {
TFE_TensorHandle* t = EagerTensor_Handle(arg);
tensorflow::AbstractTensorHandleInterface* handle =
tensorflow::unwrap(EagerTensor_Handle(arg));
absl::StrAppend(&result->str, kDType,
static_cast<tensorflow::DataType>(t->handle->DataType()));
static_cast<tensorflow::DataType>(handle->DataType()));
absl::StrAppend(&result->str, kShape);
int num_dims;
tensorflow::Status status = t->handle->NumDims(&num_dims);
tensorflow::Status status = handle->NumDims(&num_dims);
if (!status.ok()) return status;
if (include_tensor_ranks_only) {
@ -3865,7 +3871,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
} else {
for (int i = 0; i < num_dims; ++i) {
tensorflow::int64 dim_size;
status = t->handle->Dim(i, &dim_size);
status = handle->Dim(i, &dim_size);
if (!status.ok()) return status;
absl::StrAppend(&result->str, dim_size, kShapeDelim);
}

View File

@ -26,7 +26,8 @@ limitations under the License.
#include "numpy/arrayobject.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
@ -95,8 +96,8 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
if (call->eager) {
Tensor t = call->ins[i];
arg = EagerTensorFromHandle(
new TFE_TensorHandle{TensorHandle::CreateLocalHandle(
std::move(t), ctx->CanonicalDevice(device), nullptr, ctx)});
tensorflow::wrap(TensorHandle::CreateLocalHandle(
std::move(t), ctx->CanonicalDevice(device), nullptr, ctx)));
if (arg == nullptr) {
Py_DECREF(lst);
return errors::Internal("Unable to procure EagerTensor from Tensor.");
@ -146,7 +147,7 @@ tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
const Device* expected_device,
const Tensor** output_tensor) {
tensorflow::TensorHandle* handle = tensorflow::TensorHandleFromInterface(
EagerTensor_Handle(eager_tensor)->handle);
tensorflow::unwrap(EagerTensor_Handle(eager_tensor)));
if (VariantDeviceIsCustom(handle->device())) {
return errors::Unimplemented(
"Custom devices are currently not supported with PyFuncs.");
@ -196,7 +197,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
CHECK_NE(ctx, nullptr);
EagerContext* context = ContextFromInterface(ctx->context);
EagerContext* context = ContextFromInterface(tensorflow::unwrap(ctx));
TF_RETURN_IF_ERROR(MakeArgTuple(call, context, &args));
new_executor.reset(new EagerExecutor(call->eager_async));
old_executor = &context->Executor();
@ -236,7 +237,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
if (new_executor != nullptr) {
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
EagerContext* context = ContextFromInterface(ctx->context);
EagerContext* context = ContextFromInterface(tensorflow::unwrap(ctx));
s.Update(new_executor->WaitForAllPendingNodes());
context->SetExecutorForThread(old_executor);
}

View File

@ -15,7 +15,8 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_seq_tensor.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.h"
@ -305,7 +306,7 @@ struct Converter {
}
}
}
*h = new TFE_TensorHandle{ctx->context->CreateLocalHandle(t)};
*h = tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(t));
t->Release();
return Status::OK();
}
@ -316,12 +317,12 @@ struct Converter {
template <>
struct ConverterTraits<int64> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, int64 value) {
return ctx->context->CreateInt64Scalar(value);
return tensorflow::unwrap(ctx)->CreateInt64Scalar(value);
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_INT64, dim_sizes);
return tensorflow::unwrap(ctx)->CreateTensor(DT_INT64, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, int64* out) {
@ -356,12 +357,12 @@ typedef Converter<int64> Int64Converter;
template <>
struct ConverterTraits<uint64> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, uint64 value) {
return ctx->context->CreateUint64Scalar(value);
return tensorflow::unwrap(ctx)->CreateUint64Scalar(value);
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_UINT64, dim_sizes);
return tensorflow::unwrap(ctx)->CreateTensor(DT_UINT64, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, uint64* out) {
@ -393,12 +394,12 @@ typedef Converter<uint64> UInt64Converter;
template <>
struct ConverterTraits<int32> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, int32 value) {
return ctx->context->CreateInt32Scalar(value);
return tensorflow::unwrap(ctx)->CreateInt32Scalar(value);
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_INT32, dim_sizes);
return tensorflow::unwrap(ctx)->CreateTensor(DT_INT32, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, int32* out) {
@ -500,12 +501,12 @@ static const char* ConvertOneFloat(PyObject* v, T* out) {
template <>
struct ConverterTraits<float> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, float value) {
return ctx->context->CreateFloatScalar(value);
return tensorflow::unwrap(ctx)->CreateFloatScalar(value);
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_FLOAT, dim_sizes);
return tensorflow::unwrap(ctx)->CreateTensor(DT_FLOAT, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, float* out) {
@ -516,12 +517,12 @@ struct ConverterTraits<float> {
template <>
struct ConverterTraits<double> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, double value) {
return ctx->context->CreateDoubleScalar(value);
return tensorflow::unwrap(ctx)->CreateDoubleScalar(value);
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_DOUBLE, dim_sizes);
return tensorflow::unwrap(ctx)->CreateTensor(DT_DOUBLE, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, double* out) {
@ -536,12 +537,12 @@ template <>
struct ConverterTraits<Eigen::half> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx,
Eigen::half value) {
return ctx->context->CreateHalfScalar(value);
return tensorflow::unwrap(ctx)->CreateHalfScalar(value);
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_HALF, dim_sizes);
return tensorflow::unwrap(ctx)->CreateTensor(DT_HALF, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, Eigen::half* out) {
@ -557,12 +558,12 @@ template <>
struct ConverterTraits<tstring> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx,
tstring value) {
return ctx->context->CreateStringScalar(value);
return tensorflow::unwrap(ctx)->CreateStringScalar(value);
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_STRING, dim_sizes);
return tensorflow::unwrap(ctx)->CreateTensor(DT_STRING, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, tstring* out) {
@ -624,12 +625,12 @@ template <>
struct ConverterTraits<complex128> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx,
complex128 value) {
return ctx->context->CreateComplex128Scalar(value);
return tensorflow::unwrap(ctx)->CreateComplex128Scalar(value);
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_COMPLEX128, dim_sizes);
return tensorflow::unwrap(ctx)->CreateTensor(DT_COMPLEX128, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, complex128* out) {
@ -652,12 +653,12 @@ typedef Converter<complex128> Complex128Converter;
template <>
struct ConverterTraits<bool> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, bool value) {
return ctx->context->CreateBoolScalar(value);
return tensorflow::unwrap(ctx)->CreateBoolScalar(value);
}
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_BOOL, dim_sizes);
return tensorflow::unwrap(ctx)->CreateTensor(DT_BOOL, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, bool* out) {
@ -692,7 +693,7 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
}
TensorInterface t(std::move(tensor));
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(&t)};
return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t));
}
} // namespace
@ -877,7 +878,7 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
TensorShape(state.inferred_shape));
TensorInterface t(std::move(tensor));
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(&t)};
return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t));
}
default:

View File

@ -270,7 +270,6 @@ static py::object TFE_ClearScalarCache() {
// are only assigning this to functions that return opaque types.
PYBIND11_MODULE(_pywrap_tfe, m) {
py::class_<TFE_Context> TFE_Context_class(m, "TFE_Context");
py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
"TFE_ContextOptions");
@ -760,7 +759,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def("TFE_ContextStartStep", [](py::handle& o) {
TFE_ContextStartStep(tensorflow::InputTFE_Context(o.ptr()));
});
m.def("TFE_ContextEndStep", &TFE_ContextEndStep);
m.def("TFE_ContextEndStep", [](py::handle& o) {
TFE_ContextEndStep(tensorflow::InputTFE_Context(o.ptr()));
});
m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr()));
});