Avoid pointer indirection in handle, context & op

Since the struct lifetime is bound to the wrapped pointer this is fine.

PiperOrigin-RevId: 308941521
Change-Id: I0604fff4fcba6a03cc4a2242ab9f182fbfbf8bae
This commit is contained in:
Gaurav Jain 2020-04-28 18:54:46 -07:00 committed by TensorFlower Gardener
parent 510f0f9a6c
commit 9ec6997dfb
29 changed files with 352 additions and 301 deletions

View File

@ -334,6 +334,9 @@ tf_cuda_library(
":checkpoint_reader", ":checkpoint_reader",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:flags",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -729,3 +732,11 @@ tf_cuda_library(
], ],
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "conversion_macros",
hdrs = [
"conversion_macros.h",
],
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -21,6 +21,9 @@ limitations under the License.
#include "tensorflow/c/checkpoint_reader.h" #include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
@ -686,8 +689,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
status->status = tensorflow::Status::OK(); status->status = tensorflow::Status::OK();
return new TFE_TensorHandle{ return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
} }
namespace { namespace {
@ -708,7 +710,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
// New server created for new server_def. Unused if updating server_def. // New server created for new server_def. Unused if updating server_def.
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server = tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer()); dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) { if (grpc_server == nullptr) {
@ -822,14 +824,13 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items; const int num_inputs = input_shapes->num_items;
NodeDef node_def; NodeDef node_def;
node_def.set_name(tfe_op->operation->Name()); tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
node_def.set_op(tfe_op->operation->Name()); node_def.set_name(op->Name());
node_def.set_op(op->Name());
for (int i = 0; i < num_inputs; ++i) { for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input"); node_def.add_input("dummy_input");
} }
OperationFromInterface(tfe_op->operation) OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());
const tensorflow::OpRegistrationData* op_reg_data; const tensorflow::OpRegistrationData* op_reg_data;
status->status = status->status =

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
#define TENSORFLOW_C_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline const cpp_impl *unwrap(const wrapper *w) { \
return reinterpret_cast<const cpp_impl *>(w); \
} \
\
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); }
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_

View File

@ -50,7 +50,6 @@ tf_cuda_library(
":tfe_tensor_debug_info_internal", ":tfe_tensor_debug_info_internal",
":tfe_tensorhandle_internal", ":tfe_tensorhandle_internal",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant", "@com_google_absl//absl/types:variant",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
@ -110,13 +109,10 @@ filegroup(
"operation_interface.h", "operation_interface.h",
"tensor_handle_interface.h", "tensor_handle_interface.h",
"tfe_cancellation_manager_internal.h", "tfe_cancellation_manager_internal.h",
"tfe_context_internal.h",
"tfe_executor_internal.h", "tfe_executor_internal.h",
"tfe_monitoring_internal.h", "tfe_monitoring_internal.h",
"tfe_op_attrs_internal.h", "tfe_op_attrs_internal.h",
"tfe_op_internal.h",
"tfe_tensor_debug_info_internal.h", "tfe_tensor_debug_info_internal.h",
"tfe_tensorhandle_internal.h",
], ],
visibility = [ visibility = [
"//tensorflow/core:__pkg__", "//tensorflow/core:__pkg__",
@ -205,6 +201,7 @@ cc_library(
], ],
deps = [ deps = [
":context_interface", ":context_interface",
"//tensorflow/c:conversion_macros",
], ],
) )
@ -249,8 +246,6 @@ cc_library(
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
":tfe_context_internal",
":tfe_op_internal",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:attr_builder",
@ -265,6 +260,7 @@ cc_library(
], ],
deps = [ deps = [
":operation_interface", ":operation_interface",
"//tensorflow/c:conversion_macros",
], ],
) )
@ -287,6 +283,7 @@ cc_library(
], ],
deps = [ deps = [
":tensor_handle_interface", ":tensor_handle_interface",
"//tensorflow/c:conversion_macros",
], ],
) )
@ -327,6 +324,8 @@ tf_cuda_cc_test(
":c_api_experimental", ":c_api_experimental",
":c_api_internal", ":c_api_internal",
":c_api_test_util", ":c_api_test_util",
":tfe_op_internal",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util", "//tensorflow/c:c_test_util",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
@ -351,6 +350,7 @@ tf_cuda_cc_test(
":c_api_experimental", ":c_api_experimental",
":c_api_internal", ":c_api_internal",
":c_api_test_util", ":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util", "//tensorflow/c:c_test_util",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
@ -384,6 +384,9 @@ tf_cuda_library(
"//conditions:default": [ "//conditions:default": [
":c_api", ":c_api",
":c_api_internal", ":c_api_internal",
":tfe_context_internal",
":tfe_op_internal",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal", "//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
@ -548,8 +551,9 @@ cc_library(
deps = [ deps = [
":c_api", ":c_api",
":c_api_experimental", ":c_api_experimental",
":c_api_internal", ":tfe_tensorhandle_internal",
"//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_helper",
"//tensorflow/c:tf_status_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",

View File

@ -26,7 +26,6 @@ limitations under the License.
// clang-format on // clang-format on
#include "absl/algorithm/container.h" #include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
@ -34,6 +33,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/c/tf_tensor_internal.h"
#ifdef PLATFORM_GOOGLE #ifdef PLATFORM_GOOGLE
#include "tensorflow/c/eager/c_api_tfrt.h" #include "tensorflow/c/eager/c_api_tfrt.h"
@ -298,7 +300,7 @@ tensorflow::Status CreateRemoteContexts(
std::vector<bool> filtered_device_mask; std::vector<bool> filtered_device_mask;
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->FilterDevicesForRemoteWorkers( context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(), remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask); &filtered_device_mask);
@ -383,7 +385,7 @@ tensorflow::Status UpdateRemoteContexts(
std::vector<bool> filtered_device_mask; std::vector<bool> filtered_device_mask;
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->FilterDevicesForRemoteWorkers( context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(), remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask); &filtered_device_mask);
@ -464,7 +466,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// New server created for new server_def. Unused if updating server_def. // New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server; std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server; tensorflow::GrpcServer* grpc_server;
if (reset_context) { if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server)); LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
@ -684,7 +686,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) { if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE #ifdef PLATFORM_GOOGLE
status->status = tensorflow::Status::OK(); status->status = tensorflow::Status::OK();
return new TFE_Context{new tfrt::ContextInterface()}; return tensorflow::wrap(new tfrt::ContextInterface());
#else #else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr; return nullptr;
@ -701,14 +703,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r = tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get()); new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context{new tensorflow::EagerContext( return tensorflow::wrap(new tensorflow::EagerContext(
opts->session_options.options, opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>( static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy), opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy), static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r, /*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator())}; tensorflow::GetDefaultCustomKernelCreator()));
} }
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@ -719,14 +721,14 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
tensorflow::Rendezvous* r = tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr); new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context{new tensorflow::EagerContext( return tensorflow::wrap(new tensorflow::EagerContext(
opts->session_options.options, opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>( static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy), opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy), static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr, opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r, /*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator())}; tensorflow::GetDefaultCustomKernelCreator()));
} }
void TFE_DeleteContext(TFE_Context* ctx) { void TFE_DeleteContext(TFE_Context* ctx) {
@ -734,22 +736,19 @@ void TFE_DeleteContext(TFE_Context* ctx) {
return; return;
} }
// context->RefCountIsOne() should be true here. // ctx->RefCountIsOne() should be true here.
// TODO(iga): Remove EagerContext refcounting. tensorflow::unwrap(ctx)->Release();
ctx->context->Release();
delete ctx;
} }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* l = new TF_DeviceList; TF_DeviceList* l = new TF_DeviceList;
ctx->context->ListDevices(&l->response); tensorflow::unwrap(ctx)->ListDevices(&l->response);
return l; return l;
} }
void TFE_ContextClearCaches(TFE_Context* ctx) { void TFE_ContextClearCaches(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->ClearCachesAndThreadExecutors(); context->ClearCachesAndThreadExecutors();
} }
@ -772,7 +771,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
if (server_def.has_cluster_device_filters()) { if (server_def.has_cluster_device_filters()) {
const auto& cdf = server_def.cluster_device_filters(); const auto& cdf = server_def.cluster_device_filters();
for (const auto& jdf : cdf.jobs()) { for (const auto& jdf : cdf.jobs()) {
const string& remote_prefix = "/job:" + jdf.name() + "/task:"; const string remote_prefix = "/job:" + jdf.name() + "/task:";
for (const auto& tdf : jdf.tasks()) { for (const auto& tdf : jdf.tasks()) {
const int32_t task_index = tdf.first; const int32_t task_index = tdf.first;
std::vector<string> device_filters(tdf.second.device_filters_size()); std::vector<string> device_filters(tdf.second.device_filters_size());
@ -781,7 +780,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
} }
const string remote_worker = remote_prefix + std::to_string(task_index); const string remote_worker = remote_prefix + std::to_string(task_index);
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = status->status =
context->SetRemoteDeviceFilters(remote_worker, device_filters); context->SetRemoteDeviceFilters(remote_worker, device_filters);
} }
@ -803,7 +802,7 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
#else // !defined(IS_MOBILE_PLATFORM) #else // !defined(IS_MOBILE_PLATFORM)
tensorflow::ServerDef server_def; tensorflow::ServerDef server_def;
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
if (!server_def.ParseFromArray(proto, proto_len)) { if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer"); "Invalid tensorflow.ServerDef protocol buffer");
@ -833,7 +832,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
return false; return false;
#else // !defined(IS_MOBILE_PLATFORM) #else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server = tensorflow::GrpcServer* grpc_server =
static_cast<tensorflow::GrpcServer*>(context->GetServer()); static_cast<tensorflow::GrpcServer*>(context->GetServer());
@ -889,7 +888,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
status->status = tensorflow::Status::OK(); status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM) #else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->SyncExecutors(); status->status = context->SyncExecutors();
#endif // !IS_MOBILE_PLATFORM #endif // !IS_MOBILE_PLATFORM
} }
@ -897,7 +896,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
void TFE_ContextSetThreadLocalDevicePlacementPolicy( void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalDevicePlacementPolicy( context->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy)); static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
} }
@ -908,7 +907,7 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) { TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextDevicePlacementPolicy>( return static_cast<TFE_ContextDevicePlacementPolicy>(
context->GetDevicePlacementPolicy()); context->GetDevicePlacementPolicy());
} }
@ -918,8 +917,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
status->status = tensorflow::TF_TensorToTensor(t, &tensor); status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr; if (!status->status.ok()) return nullptr;
return new TFE_TensorHandle{ return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
} }
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
@ -927,84 +925,84 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
tensorflow::profiler::TraceMe activity( tensorflow::profiler::TraceMe activity(
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo); "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
if (h->handle) { if (h) {
h->handle->Release(); tensorflow::unwrap(h)->Release();
} }
delete h;
} }
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->handle->DataType()); return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
} }
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1; return -1;
} }
int num_dims = -1; int num_dims = -1;
status->status = h->handle->NumDims(&num_dims); status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
return num_dims; return num_dims;
} }
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) { int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1; return -1;
} }
int64 num_elements = -1; int64 num_elements = -1;
status->status = h->handle->NumElements(&num_elements); status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
return num_elements; return num_elements;
} }
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) { TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1; return -1;
} }
int64 dim = -1; int64 dim = -1;
status->status = h->handle->Dim(dim_index, &dim); status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
return dim; return dim;
} }
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr; return nullptr;
} }
return h->handle->DeviceName(&status->status); return tensorflow::unwrap(h)->DeviceName(&status->status);
} }
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h, const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) { TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr; return nullptr;
} }
return h->handle->BackingDeviceName(&status->status); return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
} }
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) { TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr; return nullptr;
} }
return new TFE_TensorHandle{h->handle->Copy()}; return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
} }
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr; return nullptr;
} }
tensorflow::AbstractTensorInterface* t = h->handle->Resolve(&status->status); tensorflow::AbstractTensorInterface* t =
tensorflow::unwrap(h)->Resolve(&status->status);
if (t == nullptr) { if (t == nullptr) {
return nullptr; return nullptr;
} }
@ -1013,12 +1011,12 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
} }
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr; return nullptr;
} }
tensorflow::TensorHandle* handle = tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle); tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (VariantDeviceIsCustom(handle->device())) { if (VariantDeviceIsCustom(handle->device())) {
const tensorflow::Tensor* t; const tensorflow::Tensor* t;
status->status = handle->Tensor(&t); status->status = handle->Tensor(&t);
@ -1054,7 +1052,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void* deallocator_arg, TF_Status* status) { void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device = nullptr; tensorflow::Device* device = nullptr;
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->FindDeviceFromName(device_name, &device); status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr; tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) { if (!status->status.ok()) {
@ -1080,11 +1078,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::TensorShape(dimvec), buf); tensorflow::TensorShape(dimvec), buf);
buf->Unref(); buf->Unref();
if (custom_device == nullptr) { if (custom_device == nullptr) {
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle( return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context)}; std::move(t), device, device, context));
} else { } else {
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle( return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context)}; std::move(t), custom_device, context));
} }
} }
@ -1093,12 +1091,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
// bytes of the memory pointed to by the device pointer returned above. // bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TF_Status* status) { TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return 0; return 0;
} }
tensorflow::TensorHandle* handle = tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle); tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (handle->Type() != tensorflow::TensorHandle::LOCAL) { if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDeviceMemorySize may not be called on a ", "TFE_TensorHandleDeviceMemorySize may not be called on a ",
@ -1115,12 +1113,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) { TF_Status* status) {
std::unique_ptr<TFE_Op> new_op(new TFE_Op{ctx->context->CreateOperation()}); tensorflow::AbstractOperationInterface* new_op =
status->status = new_op->operation->Reset(op_or_function_name, nullptr); tensorflow::unwrap(ctx)->CreateOperation();
status->status = new_op->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) { if (!status->status.ok()) {
new_op.reset(); new_op->Release();
new_op = nullptr;
} }
return new_op.release(); return tensorflow::wrap(new_op);
} }
void TFE_DeleteOp(TFE_Op* op) { void TFE_DeleteOp(TFE_Op* op) {
@ -1128,24 +1128,20 @@ void TFE_DeleteOp(TFE_Op* op) {
return; return;
} }
if (op->operation) { tensorflow::unwrap(op)->Release();
op->operation->Release();
}
delete op;
} }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = op->operation->SetDeviceName(device_name); status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
} }
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
return op->operation->DeviceName().c_str(); return tensorflow::unwrap(op)->DeviceName().c_str();
} }
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
#ifdef TENSORFLOW_EAGER_USE_XLA #ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Status s = op->operation->SetUseXla(enable); tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
if (!s.ok()) { if (!s.ok()) {
LOG(ERROR) << "Could not enable XLA compilation for op: " << s; LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
} }
@ -1156,18 +1152,13 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
} }
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
status->status = op->operation->AddInput(input->handle); status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
} }
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) { TF_Status* status) {
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles( status->status = tensorflow::unwrap(op)->AddInputList(
num_inputs); {tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
for (int i = 0; i < num_inputs; ++i) {
handles[i] = inputs[i]->handle;
}
status->status =
op->operation->AddInputList({handles.data(), handles.size()});
} }
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@ -1175,8 +1166,8 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
TF_AttrType ret = TF_ATTR_INT; TF_AttrType ret = TF_ATTR_INT;
const tensorflow::AttrTypeMap* attr_types_; const tensorflow::AttrTypeMap* attr_types_;
bool is_function; bool is_function;
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(), status->status = tensorflow::AttrTypeMapForOp(
&attr_types_, &is_function); tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
if (!status->status.ok()) { if (!status->status.ok()) {
return ret; return ret;
} }
@ -1202,7 +1193,7 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
size_t length) { size_t length) {
auto s = op->operation->SetAttrString( auto s = tensorflow::unwrap(op)->SetAttrString(
attr_name, static_cast<const char*>(value), length); attr_name, static_cast<const char*>(value), length);
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
@ -1210,29 +1201,30 @@ void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
} }
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
auto s = op->operation->SetAttrInt(attr_name, value); auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
} }
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) { void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
auto s = op->operation->SetAttrFloat(attr_name, value); auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
} }
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) { void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true); auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
(value == 0) ? false : true);
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
} }
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
auto s = op->operation->SetAttrType(attr_name, auto s = tensorflow::unwrap(op)->SetAttrType(
static_cast<tensorflow::DataType>(value)); attr_name, static_cast<tensorflow::DataType>(value));
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
@ -1240,12 +1232,14 @@ void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
const int num_dims, TF_Status* out_status) { const int num_dims, TF_Status* out_status) {
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims); out_status->status =
tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
} }
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) { const TFE_Op* value) {
auto s = op->operation->SetAttrFunction(attr_name, value->operation); auto s = tensorflow::unwrap(op)->SetAttrFunction(
attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
@ -1253,7 +1247,7 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
const char* data, size_t length) { const char* data, size_t length) {
auto s = op->operation->SetAttrFunctionName(attr_name, data, length); auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
@ -1264,14 +1258,14 @@ void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
tensorflow::Tensor t; tensorflow::Tensor t;
status->status = TF_TensorToTensor(tensor, &t); status->status = TF_TensorToTensor(tensor, &t);
tensorflow::TensorInterface interface(t); tensorflow::TensorInterface interface(t);
status->status = op->operation->SetAttrTensor(attr_name, &interface); status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
} }
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const void* const* values, const size_t* lengths, const void* const* values, const size_t* lengths,
int num_values) { int num_values) {
auto s = auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
op->operation->SetAttrStringList(attr_name, values, lengths, num_values); num_values);
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
@ -1279,7 +1273,8 @@ void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values) { const float* values, int num_values) {
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values); auto s =
tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
@ -1287,7 +1282,8 @@ void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) { const int64_t* values, int num_values) {
auto s = op->operation->SetAttrIntList(attr_name, values, num_values); auto s =
tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
@ -1295,7 +1291,7 @@ void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) { const TF_DataType* values, int num_values) {
auto s = op->operation->SetAttrTypeList( auto s = tensorflow::unwrap(op)->SetAttrTypeList(
attr_name, reinterpret_cast<const tensorflow::DataType*>(values), attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
num_values); num_values);
if (!s.ok()) { if (!s.ok()) {
@ -1305,7 +1301,8 @@ void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values) { const unsigned char* values, int num_values) {
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values); auto s =
tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
@ -1314,19 +1311,14 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims, const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status) { int num_values, TF_Status* out_status) {
out_status->status = out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values); attr_name, dims, num_dims, num_values);
} }
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) { const TFE_Op** value, int num_values) {
absl::FixedArray<const tensorflow::AbstractOperationInterface*> values( auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
num_values); attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
for (int i = 0; i < num_values; ++i) {
values[i] = value[i]->operation;
}
auto s = op->operation->SetAttrFunctionList(attr_name,
{values.data(), values.size()});
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
@ -1341,12 +1333,13 @@ void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto"); tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
return; return;
} }
if (op == nullptr || op->operation == nullptr) { if (op == nullptr) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
"Got a null or uninitialized `op` argument"); "Got a null or uninitialized `op` argument");
return; return;
} }
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation); tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
operation->MutableAttrs()->Set(attr_name, attr_value); operation->MutableAttrs()->Set(attr_name, attr_value);
} }
@ -1354,7 +1347,7 @@ TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name, const char* input_name,
TF_Status* status) { TF_Status* status) {
int ret = -1; int ret = -1;
status->status = op->operation->InputLength(input_name, &ret); status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
return ret; return ret;
} }
@ -1362,36 +1355,29 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name, const char* output_name,
TF_Status* status) { TF_Status* status) {
int ret = -1; int ret = -1;
status->status = op->operation->OutputLength(output_name, &ret); status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
return ret; return ret;
} }
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) { TF_Status* status) {
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles( status->status = tensorflow::unwrap(op)->Execute(
*num_retvals); absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
status->status = op->operation->Execute(absl::MakeSpan(handles), num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle{handles[i]};
}
} }
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx, TFE_Context* ctx,
const char* device_name, const char* device_name,
TF_Status* status) { TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr; return nullptr;
} }
auto* result = ctx->context->CopyTensorHandleToDevice(h->handle, device_name, auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice(
&status->status); tensorflow::unwrap(h), device_name, &status->status);
if (status->status.ok()) { if (status->status.ok()) {
return new TFE_TensorHandle{result}; return tensorflow::wrap(result);
} }
return nullptr; return nullptr;
} }
@ -1406,39 +1392,39 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
return; return;
} }
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function_def); status->status = context->AddFunctionDef(function_def);
} }
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) { TF_Status* status) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function->fdef); status->status = context->AddFunctionDef(function->fdef);
} }
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) { TF_Status* status) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->RemoveFunction(name); status->status = context->RemoveFunction(name);
} }
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->FindFunctionDef(name) != nullptr; return context->FindFunctionDef(name) != nullptr;
} }
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true); context->SetShouldStoreGraphs(true);
} }
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false); context->SetShouldStoreGraphs(false);
} }
@ -1446,13 +1432,13 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status) { TF_Status* status) {
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(t)}; return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
} }
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) { TF_Status* status) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->Executor().WaitForAllPendingNodes(); status->status = context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return; if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*context->MetadataMu()); tensorflow::mutex_lock ml(*context->MetadataMu());
@ -1475,20 +1461,21 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
void TFE_ContextStartStep(TFE_Context* ctx) { void TFE_ContextStartStep(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->StartStep(); context->StartStep();
} }
void TFE_ContextEndStep(TFE_Context* ctx) { void TFE_ContextEndStep(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->EndStep(); context->EndStep();
} }
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m; tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m); attrs->attributes->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation); tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(op));
tensorflow::AttrBuilder* destination = operation->MutableAttrs(); tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (const auto& attribute : m) { for (const auto& attribute : m) {
destination->Set(attribute.first, attribute.second); destination->Set(attribute.first, attribute.second);
@ -1576,33 +1563,34 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
const string& name() override { return name_; } const string& name() override { return name_; }
tensorflow::Status CopyTensorToDevice( tensorflow::Status CopyTensorToDevice(
tensorflow::TensorHandle* tensor, tensorflow::TensorHandle* handle,
tensorflow::TensorHandle** result) override { tensorflow::TensorHandle** result) override {
tensor->Ref(); handle->Ref();
TFE_TensorHandle tensor_handle{tensor};
TF_Status status; TF_Status status;
TFE_TensorHandle* result_handle = TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_); context_, tensorflow::wrap(handle), &status, info_);
tensor_handle.handle->Release(); handle->Release();
if (!status.status.ok()) return status.status; if (!status.status.ok()) return status.status;
*result = tensorflow::TensorHandleFromInterface(result_handle->handle); *result = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(result_handle));
(*result)->Ref(); (*result)->Ref();
TFE_DeleteTensorHandle(result_handle); TFE_DeleteTensorHandle(result_handle);
return status.status; return status.status;
} }
tensorflow::Status CopyTensorFromDevice( tensorflow::Status CopyTensorFromDevice(
tensorflow::TensorHandle* tensor, tensorflow::TensorHandle* handle,
const tensorflow::string& target_device_name, const tensorflow::string& target_device_name,
tensorflow::TensorHandle** result) override { tensorflow::TensorHandle** result) override {
TF_Status status; TF_Status status;
tensor->Ref(); handle->Ref();
TFE_TensorHandle tensor_handle{tensor};
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device( TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
context_, &tensor_handle, target_device_name.c_str(), &status, info_); context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
tensor_handle.handle->Release(); info_);
handle->Release();
if (!status.status.ok()) return status.status; if (!status.status.ok()) return status.status;
*result = tensorflow::TensorHandleFromInterface(result_handle->handle); *result = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(result_handle));
(*result)->Ref(); (*result)->Ref();
TFE_DeleteTensorHandle(result_handle); TFE_DeleteTensorHandle(result_handle);
return status.status; return status.status;
@ -1615,7 +1603,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
inputs.reserve(op->Inputs().size()); inputs.reserve(op->Inputs().size());
for (int i = 0; i < op->Inputs().size(); ++i) { for (int i = 0; i < op->Inputs().size(); ++i) {
op->Inputs()[i]->Ref(); op->Inputs()[i]->Ref();
inputs.push_back(new TFE_TensorHandle{op->Inputs()[i]}); inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
} }
std::vector<TFE_TensorHandle*> outputs(*num_retvals); std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status; TF_Status status;
@ -1624,7 +1612,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
&attributes, num_retvals, outputs.data(), &status, info_); &attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) { if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) { for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle); retvals[i] = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(outputs[i]));
retvals[i]->Ref(); retvals[i]->Ref();
TFE_DeleteTensorHandle(outputs[i]); TFE_DeleteTensorHandle(outputs[i]);
} }
@ -1652,7 +1641,7 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
auto custom_device = auto custom_device =
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name); std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = status->status =
context->RegisterCustomDevice(device_name, std::move(custom_device)); context->RegisterCustomDevice(device_name, std::move(custom_device));
} }

View File

@ -57,7 +57,8 @@ extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status) { TFE_TensorHandle* h, TF_Status* status) {
tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle); tensorflow::TensorHandle* handle =
TensorHandleFromInterface(tensorflow::unwrap(h));
const tensorflow::Tensor* tensor; const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor); status->status = handle->Tensor(&tensor);
if (!status->status.ok()) { if (!status->status.ok()) {

View File

@ -19,6 +19,9 @@ limitations under the License.
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h"
@ -34,9 +37,10 @@ using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) { const char* raw_device_name, TF_Status* status) {
if (op_to_reset) { if (op_to_reset) {
op_to_reset->operation->Clear(); tensorflow::AbstractOperationInterface* op =
status->status = tensorflow::unwrap(op_to_reset);
op_to_reset->operation->Reset(op_or_function_name, raw_device_name); op->Clear();
status->status = op->Reset(op_or_function_name, raw_device_name);
} else { } else {
TF_SetStatus(status, TF_INVALID_ARGUMENT, TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr"); "op_to_reset should not be nullptr");
@ -45,13 +49,13 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true); context->SetShouldStoreGraphs(true);
} }
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false); context->SetShouldStoreGraphs(false);
} }
@ -483,7 +487,7 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
void TFE_ContextSetThreadLocalMirroringPolicy( void TFE_ContextSetThreadLocalMirroringPolicy(
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) { TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalMirroringPolicy( context->SetThreadLocalMirroringPolicy(
static_cast<tensorflow::ContextMirroringPolicy>(policy)); static_cast<tensorflow::ContextMirroringPolicy>(policy));
} }
@ -494,7 +498,7 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TFE_Context* ctx) { TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy()); return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
} }
@ -530,7 +534,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager, TFE_CancellationManager* cancellation_manager,
TF_Status* status) { TF_Status* status) {
tensorflow::EagerOperation* operation = tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(op->operation); tensorflow::OperationFromInterface(tensorflow::unwrap(op));
operation->SetCancellationManager( operation->SetCancellationManager(
&cancellation_manager->cancellation_manager); &cancellation_manager->cancellation_manager);
status->status = tensorflow::Status::OK(); status->status = tensorflow::Status::OK();
@ -557,19 +561,19 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetExecutorForThread(executor->executor()); context->SetExecutorForThread(executor->executor());
} }
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return new TFE_Executor(&context->Executor()); return new TFE_Executor(&context->Executor());
} }
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto address_space = tensorflow::DeviceNameUtils::AddressSpace( auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
context->HostCPU()->parsed_name()); context->HostCPU()->parsed_name());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
@ -585,7 +589,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) { TF_Buffer* buf, TF_Status* status) {
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(ctx->context); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto* function_def = context->FindFunctionDef(function_name); auto* function_def = context->FindFunctionDef(function_name);
if (function_def == nullptr) { if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound( status->status = tensorflow::errors::NotFound(
@ -611,12 +615,13 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
dimvec[i] = static_cast<tensorflow::int64>(dims[i]); dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
} }
if (ctx == nullptr || ctx->context == nullptr) { if (ctx == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid Context"); status->status = tensorflow::errors::InvalidArgument("Invalid Context");
return nullptr; return nullptr;
} }
tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor( tensorflow::AbstractTensorInterface* t =
tensorflow::unwrap(ctx)->CreateTensor(
static_cast<tensorflow::DataType>(dtype), dimvec); static_cast<tensorflow::DataType>(dtype), dimvec);
if (t == nullptr) { if (t == nullptr) {
@ -630,5 +635,6 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t, TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
TF_Status* status) { TF_Status* status) {
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)}; return tensorflow::wrap(
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
} }

View File

@ -19,13 +19,10 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export #include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_context_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export #include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export #include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export #include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_op_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export #include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" // IWYU pragma: export
// TODO(b/154564140): Move this to its own header. This requires splitting // TODO(b/154564140): Move this to its own header. This requires splitting
// c_api_experimental.h // c_api_experimental.h

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/casts.h"
@ -233,7 +234,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU")); ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status); TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle); auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored. // The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr)); ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
} }
@ -245,7 +247,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
// TODO(gjn): Add support for waiting on async local mirrors // TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async) { if (!remote && !async) {
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle); auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored. // The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr)); ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
} }

View File

@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function.pb.h"
@ -416,8 +418,10 @@ void TensorHandleSilentCopy(bool async,
hcpu, ctx, gpu_device_name.c_str(), status.get()); hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle); auto cpu_arg =
auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle); tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
auto gpu_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device()); auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device)); ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
@ -1346,7 +1350,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) { tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
tensorflow::AttrValueMap attr_values; tensorflow::AttrValueMap attr_values;
tensorflow::EagerOperation* operation = tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(op->operation); tensorflow::OperationFromInterface(tensorflow::unwrap(op));
operation->Attrs().FillAttrValueMap(&attr_values); operation->Attrs().FillAttrValueMap(&attr_values);
return attr_values; return attr_values;
} }
@ -1482,10 +1486,10 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2}; TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status); TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->operation->OpDef()); CHECK(tensorflow::unwrap(concatOp)->OpDef());
TFE_OpAddInput(concatOp, inputs[0], status); TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->operation->OpDef()) EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef())
<< "Inference context is still present"; << "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status); TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -1590,7 +1594,7 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
// There is currently no API to fetch attributes from an operation, fetching // There is currently no API to fetch attributes from an operation, fetching
// happens only as an implementation detail of custom devices. // happens only as an implementation detail of custom devices.
tensorflow::EagerOperation* operation = tensorflow::EagerOperation* operation =
OperationFromInterface(var_op->operation); OperationFromInterface(tensorflow::unwrap(var_op));
TFE_OpAttrs attributes{&operation->Attrs()}; TFE_OpAttrs attributes{&operation->Attrs()};
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status); TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
@ -1606,7 +1610,7 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
tensorflow::AttrValueMap attr_values; tensorflow::AttrValueMap attr_values;
tensorflow::EagerOperation* op = tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(copy_op->operation); tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
op->Attrs().FillAttrValueMap(&attr_values); op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type()); EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
@ -1630,7 +1634,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
// There is currently no API to fetch attributes from an operation, fetching // There is currently no API to fetch attributes from an operation, fetching
// happens only as an implementation detail of custom devices. // happens only as an implementation detail of custom devices.
tensorflow::EagerOperation* operation = tensorflow::EagerOperation* operation =
OperationFromInterface(var_op->operation); OperationFromInterface(tensorflow::unwrap(var_op));
TFE_OpAttrs attributes{&operation->Attrs()}; TFE_OpAttrs attributes{&operation->Attrs()};
TF_Buffer* serialized_attr_values = TF_NewBuffer(); TF_Buffer* serialized_attr_values = TF_NewBuffer();
@ -1657,7 +1661,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
tensorflow::AttrValueMap attr_values; tensorflow::AttrValueMap attr_values;
tensorflow::EagerOperation* op = tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(var_op_2->operation); tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2));
op->Attrs().FillAttrValueMap(&attr_values); op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type()); EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());

View File

@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/c/eager/dlpack.h" #include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // from @dlpack #include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h" #include "tensorflow/core/framework/tensor_reference.h"
@ -41,12 +43,12 @@ struct TfDlManagedTensorCtx {
// Gets tensor from eager tensor handle. // Gets tensor from eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) { if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle"); status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr; return nullptr;
} }
tensorflow::TensorHandle* handle = tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle); tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (handle->Type() != TensorHandle::LOCAL) { if (handle->Type() != TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support ", handle->TypeString(), " tensor"); "DLPack doesn't support ", handle->TypeString(), " tensor");
@ -107,7 +109,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
// Gets DLPack's DLContext from eager tensor handle. // Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) { DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx; DLContext ctx;
const char* device_name = h->handle->DeviceName(&status->status); const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name; DeviceNameUtils::ParsedName parsed_name;
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name); tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
std::string device_type = parsed_name.type; std::string device_type = parsed_name.type;

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ #ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/context_interface.h" #include "tensorflow/c/eager/context_interface.h"
// Wraps a pointer to a context implementation. // Wraps a pointer to a context implementation.
@ -23,8 +24,12 @@ limitations under the License.
// interface cannot destruct the underlying context object. Instead, call // interface cannot destruct the underlying context object. Instead, call
// TFE_DeleteContext who calls Release() on the context pointer and deletes // TFE_DeleteContext who calls Release() on the context pointer and deletes
// the TFE_Context structure. // the TFE_Context structure.
struct TFE_Context { typedef struct TFE_Context TFE_Context;
tensorflow::AbstractContextInterface* context;
}; namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ #endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_

View File

@ -23,14 +23,15 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value.pb.h"
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways // An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization. // that sometimes do not require serialization.
typedef struct TFE_Context TFE_Context;
typedef struct TFE_Op TFE_Op;
struct TFE_OpAttrs { struct TFE_OpAttrs {
explicit TFE_OpAttrs() : attributes(nullptr) {} explicit TFE_OpAttrs() : attributes(nullptr) {}

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ #ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/operation_interface.h"
// Wraps a pointer to an operation implementation. // Wraps a pointer to an operation implementation.
@ -23,8 +24,13 @@ limitations under the License.
// interface cannot destruct the underlying operation object. Instead, call // interface cannot destruct the underlying operation object. Instead, call
// TFE_DeleteOp who calls Release() on the operation pointer and deletes // TFE_DeleteOp who calls Release() on the operation pointer and deletes
// the TFE_Op structure. // the TFE_Op structure.
struct TFE_Op { typedef struct TFE_Op TFE_Op;
tensorflow::AbstractOperationInterface* operation;
}; namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ #endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ #ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h"
// Wraps a pointer to a tensor handle implementation. // Wraps a pointer to a tensor handle implementation.
@ -23,8 +24,15 @@ limitations under the License.
// interface cannot destruct the underlying handle object. Instead, call // interface cannot destruct the underlying handle object. Instead, call
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes // TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
// the TFE_TensorHandle structure. // the TFE_TensorHandle structure.
struct TFE_TensorHandle { typedef struct TFE_TensorHandle TFE_TensorHandle;
tensorflow::AbstractTensorHandleInterface* handle;
}; namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
TFE_TensorHandle);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
TFE_TensorHandle*);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ #endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_

View File

@ -22,13 +22,6 @@ package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
cc_library(
name = "conversion_macros",
hdrs = [
"conversion_macros.h",
],
)
cc_library( cc_library(
name = "concrete_function", name = "concrete_function",
srcs = [ srcs = [
@ -51,6 +44,7 @@ cc_library(
"//tensorflow/c:c_api_macros", "//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/experimental/saved_model/core:concrete_function", "//tensorflow/c/experimental/saved_model/core:concrete_function",
"//tensorflow/c/experimental/saved_model/core:function_metadata", "//tensorflow/c/experimental/saved_model/core:function_metadata",
], ],
@ -93,7 +87,7 @@ cc_library(
"concrete_function_type.h", "concrete_function_type.h",
], ],
deps = [ deps = [
":conversion_macros", "//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:concrete_function", "//tensorflow/c/experimental/saved_model/core:concrete_function",
], ],
) )
@ -123,7 +117,7 @@ cc_library(
"function_metadata_type.h", "function_metadata_type.h",
], ],
deps = [ deps = [
":conversion_macros", "//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:function_metadata", "//tensorflow/c/experimental/saved_model/core:function_metadata",
], ],
) )

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
@ -24,7 +25,8 @@ limitations under the License.
extern "C" { extern "C" {
TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) { TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
return tensorflow::wrap(&tensorflow::unwrap(func)->GetFunctionMetadata()); return tensorflow::wrap(const_cast<tensorflow::FunctionMetadata*>(
&tensorflow::unwrap(func)->GetFunctionMetadata()));
} }
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) { TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
@ -34,7 +36,7 @@ TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
} }
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) { TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
return new TFE_Op{tensorflow::unwrap(func)->GetCallOp()}; return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp());
} }
} // end extern "C" } // end extern "C"

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_ #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
// Internal structures used by the SavedModel C API. These are likely to change // Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on. // and should not be depended on.

View File

@ -1,28 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline wrapper *wrap(const cpp_impl *i) { \
return reinterpret_cast<wrapper *>(const_cast<cpp_impl *>(i)); \
}
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_ #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
typedef struct TF_FunctionMetadata TF_FunctionMetadata; typedef struct TF_FunctionMetadata TF_FunctionMetadata;

View File

@ -36,7 +36,8 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
std::string saved_model_dir(dirname); std::string saved_model_dir(dirname);
std::unique_ptr<tensorflow::SavedModelAPI> result = std::unique_ptr<tensorflow::SavedModelAPI> result =
ctx->context->LoadSavedModelAPI(dirname, absl::nullopt, &status->status); tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
&status->status);
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;
} }
@ -54,7 +55,7 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
} }
std::unique_ptr<tensorflow::SavedModelAPI> result = std::unique_ptr<tensorflow::SavedModelAPI> result =
ctx->context->LoadSavedModelAPI(dirname, std::move(tagset), tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
&status->status); &status->status);
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;

View File

@ -903,7 +903,8 @@ cc_library(
":safe_ptr", ":safe_ptr",
"//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
@ -1007,7 +1008,8 @@ cc_library(
":safe_ptr", ":safe_ptr",
"//tensorflow/c:tensor_interface", "//tensorflow/c:tensor_interface",
"//tensorflow/c:tf_tensor_internal", "//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//third_party/python_runtime:headers", "//third_party/python_runtime:headers",

View File

@ -40,6 +40,9 @@ cc_library(
"//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:dlpack", "//tensorflow/c/eager:dlpack",
"//tensorflow/c/eager:tape", "//tensorflow/c/eager:tape",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",

View File

@ -17,7 +17,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/hash/hash.h" #include "absl/hash/hash.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -41,7 +41,7 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
PyObject* value, tensorflow::DataType dtype, PyObject* value, tensorflow::DataType dtype,
absl::string_view device_name) const { absl::string_view device_name) const {
CHECK_NOTNULL(value); CHECK_NOTNULL(value);
const auto& it = cache.find(Key{PyObjectPtr{value}, dtype, device_name}); const auto it = cache.find(Key{PyObjectPtr{value}, dtype, device_name});
if (it == cache.end()) { if (it == cache.end()) {
scalar_cache_misses->GetCell()->IncrementBy(1); scalar_cache_misses->GetCell()->IncrementBy(1);
return nullptr; return nullptr;
@ -49,7 +49,7 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
scalar_cache_hits->GetCell()->IncrementBy(1); scalar_cache_hits->GetCell()->IncrementBy(1);
auto* h = it->second; auto* h = it->second;
return new TFE_TensorHandle{h->handle->Copy()}; return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
} }
void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype, void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
@ -57,7 +57,7 @@ void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
TFE_TensorHandle* h) { TFE_TensorHandle* h) {
Py_INCREF(value); Py_INCREF(value);
cache.emplace(Key{PyObjectPtr{value}, dtype, device_name}, cache.emplace(Key{PyObjectPtr{value}, dtype, device_name},
new TFE_TensorHandle{h->handle->Copy()}); tensorflow::wrap(tensorflow::unwrap(h)->Copy()));
} }
void TFE_TensorHandleCache::Clear() { void TFE_TensorHandleCache::Clear() {

View File

@ -379,7 +379,7 @@ void TFE_Py_EnableInteractivePythonLogging();
// Py_None. // Py_None.
// //
// This function is not thread-safe. // This function is not thread-safe.
PyObject* TFE_Py_SetEagerContext(PyObject* python_context); PyObject* TFE_Py_SetEagerContext(PyObject* py_context);
// Returns the current eager Context object (defined in eager/context.py) // Returns the current eager Context object (defined in eager/context.py)
// that was last set using TFE_Py_SetEagerContext. // that was last set using TFE_Py_SetEagerContext.

View File

@ -24,6 +24,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h" #include "tensorflow/c/eager/tape.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
@ -80,9 +83,10 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) { const char* raw_device_name, TF_Status* status) {
auto op = ReleaseThreadLocalOp(ctx); auto op = ReleaseThreadLocalOp(ctx);
if (!op) { if (!op) {
op.reset(new TFE_Op{ctx->context->CreateOperation()}); op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
} }
status->status = op->operation->Reset(op_or_function_name, raw_device_name); status->status =
tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
if (!status->status.ok()) { if (!status->status.ok()) {
op.reset(); op.reset();
} }
@ -91,7 +95,7 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
void ReturnOp(TFE_Context* ctx, TFE_Op* op) { void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
if (op) { if (op) {
op->operation->Clear(); tensorflow::unwrap(op)->Clear();
thread_local_eager_operation_map[ctx].reset(op); thread_local_eager_operation_map[ctx].reset(op);
} }
} }
@ -1500,7 +1504,7 @@ static PyTypeObject TFE_Py_Tape_Type = {
sizeof(TFE_Py_Tape), /* tp_basicsize */ sizeof(TFE_Py_Tape), /* tp_basicsize */
0, /* tp_itemsize */ 0, /* tp_itemsize */
&TFE_Py_Tape_Delete, /* tp_dealloc */ &TFE_Py_Tape_Delete, /* tp_dealloc */
0, /* tp_print */ nullptr, /* tp_print */
nullptr, /* tp_getattr */ nullptr, /* tp_getattr */
nullptr, /* tp_setattr */ nullptr, /* tp_setattr */
nullptr, /* tp_reserved */ nullptr, /* tp_reserved */
@ -1538,7 +1542,7 @@ static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */ sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
0, /* tp_itemsize */ 0, /* tp_itemsize */
&TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */ &TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
0, /* tp_print */ nullptr, /* tp_print */
nullptr, /* tp_getattr */ nullptr, /* tp_getattr */
nullptr, /* tp_setattr */ nullptr, /* tp_setattr */
nullptr, /* tp_reserved */ nullptr, /* tp_reserved */
@ -1573,7 +1577,7 @@ static PyTypeObject TFE_Py_VariableWatcher_Type = {
sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */ sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
0, /* tp_itemsize */ 0, /* tp_itemsize */
&TFE_Py_VariableWatcher_Delete, /* tp_dealloc */ &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
0, /* tp_print */ nullptr, /* tp_print */
nullptr, /* tp_getattr */ nullptr, /* tp_getattr */
nullptr, /* tp_setattr */ nullptr, /* tp_setattr */
nullptr, /* tp_reserved */ nullptr, /* tp_reserved */
@ -1990,21 +1994,22 @@ bool ListContainsNone(PyObject* list) {
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) { if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor); tensorflow::AbstractTensorHandleInterface* handle =
tensorflow::unwrap(EagerTensor_Handle(tensor));
tensorflow::int64 id = PyEagerTensor_ID(tensor); tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::DataType dtype = tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(t->handle->DataType()); static_cast<tensorflow::DataType>(handle->DataType());
if (dtype == tensorflow::DT_VARIANT) { if (dtype == tensorflow::DT_VARIANT) {
return PyTapeTensor(id, dtype, tensor); return PyTapeTensor(id, dtype, tensor);
} }
tensorflow::TensorShape tensor_shape; tensorflow::TensorShape tensor_shape;
int num_dims; int num_dims;
tensorflow::Status status = t->handle->NumDims(&num_dims); tensorflow::Status status = handle->NumDims(&num_dims);
if (status.ok()) { if (status.ok()) {
for (int i = 0; i < num_dims; ++i) { for (int i = 0; i < num_dims; ++i) {
tensorflow::int64 dim_size; tensorflow::int64 dim_size;
status = t->handle->Dim(i, &dim_size); status = handle->Dim(i, &dim_size);
if (!status.ok()) break; if (!status.ok()) break;
tensor_shape.AddDim(dim_size); tensor_shape.AddDim(dim_size);
} }
@ -3511,7 +3516,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
return nullptr; return nullptr;
} }
const tensorflow::OpDef* op_def = op->operation->OpDef(); const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
if (op_def == nullptr) return nullptr; if (op_def == nullptr) return nullptr;
if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) { if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
@ -3850,14 +3855,15 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
bool include_tensor_ranks_only, bool include_tensor_ranks_only,
EncodeResult* result) { EncodeResult* result) {
if (EagerTensor_CheckExact(arg)) { if (EagerTensor_CheckExact(arg)) {
TFE_TensorHandle* t = EagerTensor_Handle(arg); tensorflow::AbstractTensorHandleInterface* handle =
tensorflow::unwrap(EagerTensor_Handle(arg));
absl::StrAppend(&result->str, kDType, absl::StrAppend(&result->str, kDType,
static_cast<tensorflow::DataType>(t->handle->DataType())); static_cast<tensorflow::DataType>(handle->DataType()));
absl::StrAppend(&result->str, kShape); absl::StrAppend(&result->str, kShape);
int num_dims; int num_dims;
tensorflow::Status status = t->handle->NumDims(&num_dims); tensorflow::Status status = handle->NumDims(&num_dims);
if (!status.ok()) return status; if (!status.ok()) return status;
if (include_tensor_ranks_only) { if (include_tensor_ranks_only) {
@ -3865,7 +3871,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
} else { } else {
for (int i = 0; i < num_dims; ++i) { for (int i = 0; i < num_dims; ++i) {
tensorflow::int64 dim_size; tensorflow::int64 dim_size;
status = t->handle->Dim(i, &dim_size); status = handle->Dim(i, &dim_size);
if (!status.ok()) return status; if (!status.ok()) return status;
absl::StrAppend(&result->str, dim_size, kShapeDelim); absl::StrAppend(&result->str, dim_size, kShapeDelim);
} }

View File

@ -26,7 +26,8 @@ limitations under the License.
#include "numpy/arrayobject.h" #include "numpy/arrayobject.h"
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
@ -95,8 +96,8 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
if (call->eager) { if (call->eager) {
Tensor t = call->ins[i]; Tensor t = call->ins[i];
arg = EagerTensorFromHandle( arg = EagerTensorFromHandle(
new TFE_TensorHandle{TensorHandle::CreateLocalHandle( tensorflow::wrap(TensorHandle::CreateLocalHandle(
std::move(t), ctx->CanonicalDevice(device), nullptr, ctx)}); std::move(t), ctx->CanonicalDevice(device), nullptr, ctx)));
if (arg == nullptr) { if (arg == nullptr) {
Py_DECREF(lst); Py_DECREF(lst);
return errors::Internal("Unable to procure EagerTensor from Tensor."); return errors::Internal("Unable to procure EagerTensor from Tensor.");
@ -146,7 +147,7 @@ tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
const Device* expected_device, const Device* expected_device,
const Tensor** output_tensor) { const Tensor** output_tensor) {
tensorflow::TensorHandle* handle = tensorflow::TensorHandleFromInterface( tensorflow::TensorHandle* handle = tensorflow::TensorHandleFromInterface(
EagerTensor_Handle(eager_tensor)->handle); tensorflow::unwrap(EagerTensor_Handle(eager_tensor)));
if (VariantDeviceIsCustom(handle->device())) { if (VariantDeviceIsCustom(handle->device())) {
return errors::Unimplemented( return errors::Unimplemented(
"Custom devices are currently not supported with PyFuncs."); "Custom devices are currently not supported with PyFuncs.");
@ -196,7 +197,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer( TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
PyObject_GetAttrString(trampoline, "_ctx"), nullptr)); PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
CHECK_NE(ctx, nullptr); CHECK_NE(ctx, nullptr);
EagerContext* context = ContextFromInterface(ctx->context); EagerContext* context = ContextFromInterface(tensorflow::unwrap(ctx));
TF_RETURN_IF_ERROR(MakeArgTuple(call, context, &args)); TF_RETURN_IF_ERROR(MakeArgTuple(call, context, &args));
new_executor.reset(new EagerExecutor(call->eager_async)); new_executor.reset(new EagerExecutor(call->eager_async));
old_executor = &context->Executor(); old_executor = &context->Executor();
@ -236,7 +237,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
if (new_executor != nullptr) { if (new_executor != nullptr) {
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer( TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
PyObject_GetAttrString(trampoline, "_ctx"), nullptr)); PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
EagerContext* context = ContextFromInterface(ctx->context); EagerContext* context = ContextFromInterface(tensorflow::unwrap(ctx));
s.Update(new_executor->WaitForAllPendingNodes()); s.Update(new_executor->WaitForAllPendingNodes());
context->SetExecutorForThread(old_executor); context->SetExecutorForThread(old_executor);
} }

View File

@ -15,7 +15,8 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_seq_tensor.h" #include "tensorflow/python/lib/core/py_seq_tensor.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tensor_interface.h" #include "tensorflow/c/tensor_interface.h"
#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
@ -305,7 +306,7 @@ struct Converter {
} }
} }
} }
*h = new TFE_TensorHandle{ctx->context->CreateLocalHandle(t)}; *h = tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(t));
t->Release(); t->Release();
return Status::OK(); return Status::OK();
} }
@ -316,12 +317,12 @@ struct Converter {
template <> template <>
struct ConverterTraits<int64> { struct ConverterTraits<int64> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, int64 value) { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, int64 value) {
return ctx->context->CreateInt64Scalar(value); return tensorflow::unwrap(ctx)->CreateInt64Scalar(value);
} }
static AbstractTensorInterface* CreateTensor( static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) { TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_INT64, dim_sizes); return tensorflow::unwrap(ctx)->CreateTensor(DT_INT64, dim_sizes);
} }
static const char* ConvertScalar(PyObject* v, int64* out) { static const char* ConvertScalar(PyObject* v, int64* out) {
@ -356,12 +357,12 @@ typedef Converter<int64> Int64Converter;
template <> template <>
struct ConverterTraits<uint64> { struct ConverterTraits<uint64> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, uint64 value) { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, uint64 value) {
return ctx->context->CreateUint64Scalar(value); return tensorflow::unwrap(ctx)->CreateUint64Scalar(value);
} }
static AbstractTensorInterface* CreateTensor( static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) { TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_UINT64, dim_sizes); return tensorflow::unwrap(ctx)->CreateTensor(DT_UINT64, dim_sizes);
} }
static const char* ConvertScalar(PyObject* v, uint64* out) { static const char* ConvertScalar(PyObject* v, uint64* out) {
@ -393,12 +394,12 @@ typedef Converter<uint64> UInt64Converter;
template <> template <>
struct ConverterTraits<int32> { struct ConverterTraits<int32> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, int32 value) { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, int32 value) {
return ctx->context->CreateInt32Scalar(value); return tensorflow::unwrap(ctx)->CreateInt32Scalar(value);
} }
static AbstractTensorInterface* CreateTensor( static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) { TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_INT32, dim_sizes); return tensorflow::unwrap(ctx)->CreateTensor(DT_INT32, dim_sizes);
} }
static const char* ConvertScalar(PyObject* v, int32* out) { static const char* ConvertScalar(PyObject* v, int32* out) {
@ -500,12 +501,12 @@ static const char* ConvertOneFloat(PyObject* v, T* out) {
template <> template <>
struct ConverterTraits<float> { struct ConverterTraits<float> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, float value) { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, float value) {
return ctx->context->CreateFloatScalar(value); return tensorflow::unwrap(ctx)->CreateFloatScalar(value);
} }
static AbstractTensorInterface* CreateTensor( static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) { TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_FLOAT, dim_sizes); return tensorflow::unwrap(ctx)->CreateTensor(DT_FLOAT, dim_sizes);
} }
static const char* ConvertScalar(PyObject* v, float* out) { static const char* ConvertScalar(PyObject* v, float* out) {
@ -516,12 +517,12 @@ struct ConverterTraits<float> {
template <> template <>
struct ConverterTraits<double> { struct ConverterTraits<double> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, double value) { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, double value) {
return ctx->context->CreateDoubleScalar(value); return tensorflow::unwrap(ctx)->CreateDoubleScalar(value);
} }
static AbstractTensorInterface* CreateTensor( static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) { TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_DOUBLE, dim_sizes); return tensorflow::unwrap(ctx)->CreateTensor(DT_DOUBLE, dim_sizes);
} }
static const char* ConvertScalar(PyObject* v, double* out) { static const char* ConvertScalar(PyObject* v, double* out) {
@ -536,12 +537,12 @@ template <>
struct ConverterTraits<Eigen::half> { struct ConverterTraits<Eigen::half> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, static AbstractTensorInterface* CreateScalar(TFE_Context* ctx,
Eigen::half value) { Eigen::half value) {
return ctx->context->CreateHalfScalar(value); return tensorflow::unwrap(ctx)->CreateHalfScalar(value);
} }
static AbstractTensorInterface* CreateTensor( static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) { TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_HALF, dim_sizes); return tensorflow::unwrap(ctx)->CreateTensor(DT_HALF, dim_sizes);
} }
static const char* ConvertScalar(PyObject* v, Eigen::half* out) { static const char* ConvertScalar(PyObject* v, Eigen::half* out) {
@ -557,12 +558,12 @@ template <>
struct ConverterTraits<tstring> { struct ConverterTraits<tstring> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, static AbstractTensorInterface* CreateScalar(TFE_Context* ctx,
tstring value) { tstring value) {
return ctx->context->CreateStringScalar(value); return tensorflow::unwrap(ctx)->CreateStringScalar(value);
} }
static AbstractTensorInterface* CreateTensor( static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) { TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_STRING, dim_sizes); return tensorflow::unwrap(ctx)->CreateTensor(DT_STRING, dim_sizes);
} }
static const char* ConvertScalar(PyObject* v, tstring* out) { static const char* ConvertScalar(PyObject* v, tstring* out) {
@ -624,12 +625,12 @@ template <>
struct ConverterTraits<complex128> { struct ConverterTraits<complex128> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, static AbstractTensorInterface* CreateScalar(TFE_Context* ctx,
complex128 value) { complex128 value) {
return ctx->context->CreateComplex128Scalar(value); return tensorflow::unwrap(ctx)->CreateComplex128Scalar(value);
} }
static AbstractTensorInterface* CreateTensor( static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) { TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_COMPLEX128, dim_sizes); return tensorflow::unwrap(ctx)->CreateTensor(DT_COMPLEX128, dim_sizes);
} }
static const char* ConvertScalar(PyObject* v, complex128* out) { static const char* ConvertScalar(PyObject* v, complex128* out) {
@ -652,12 +653,12 @@ typedef Converter<complex128> Complex128Converter;
template <> template <>
struct ConverterTraits<bool> { struct ConverterTraits<bool> {
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, bool value) { static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, bool value) {
return ctx->context->CreateBoolScalar(value); return tensorflow::unwrap(ctx)->CreateBoolScalar(value);
} }
static AbstractTensorInterface* CreateTensor( static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) { TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateTensor(DT_BOOL, dim_sizes); return tensorflow::unwrap(ctx)->CreateTensor(DT_BOOL, dim_sizes);
} }
static const char* ConvertScalar(PyObject* v, bool* out) { static const char* ConvertScalar(PyObject* v, bool* out) {
@ -692,7 +693,7 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
} }
TensorInterface t(std::move(tensor)); TensorInterface t(std::move(tensor));
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(&t)}; return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t));
} }
} // namespace } // namespace
@ -877,7 +878,7 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
TensorShape(state.inferred_shape)); TensorShape(state.inferred_shape));
TensorInterface t(std::move(tensor)); TensorInterface t(std::move(tensor));
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(&t)}; return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t));
} }
default: default:

View File

@ -270,7 +270,6 @@ static py::object TFE_ClearScalarCache() {
// are only assigning this to functions that return opaque types. // are only assigning this to functions that return opaque types.
PYBIND11_MODULE(_pywrap_tfe, m) { PYBIND11_MODULE(_pywrap_tfe, m) {
py::class_<TFE_Context> TFE_Context_class(m, "TFE_Context");
py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor"); py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m, py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
"TFE_ContextOptions"); "TFE_ContextOptions");
@ -760,7 +759,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def("TFE_ContextStartStep", [](py::handle& o) { m.def("TFE_ContextStartStep", [](py::handle& o) {
TFE_ContextStartStep(tensorflow::InputTFE_Context(o.ptr())); TFE_ContextStartStep(tensorflow::InputTFE_Context(o.ptr()));
}); });
m.def("TFE_ContextEndStep", &TFE_ContextEndStep); m.def("TFE_ContextEndStep", [](py::handle& o) {
TFE_ContextEndStep(tensorflow::InputTFE_Context(o.ptr()));
});
m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) { m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr())); return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr()));
}); });