Avoid pointer indirection in handle, context & op
Since the struct lifetime is bound to the wrapped pointer this is fine. PiperOrigin-RevId: 308941521 Change-Id: I0604fff4fcba6a03cc4a2242ab9f182fbfbf8bae
This commit is contained in:
parent
510f0f9a6c
commit
9ec6997dfb
|
@ -334,6 +334,9 @@ tf_cuda_library(
|
|||
":checkpoint_reader",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/eager:tfe_op_internal",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -729,3 +732,11 @@ tf_cuda_library(
|
|||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "conversion_macros",
|
||||
hdrs = [
|
||||
"conversion_macros.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
|
|
@ -21,6 +21,9 @@ limitations under the License.
|
|||
#include "tensorflow/c/checkpoint_reader.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
|
@ -686,8 +689,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
|
|||
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
|
||||
|
||||
status->status = tensorflow::Status::OK();
|
||||
return new TFE_TensorHandle{
|
||||
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -708,7 +710,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
|||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
|
@ -822,14 +824,13 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
|||
|
||||
const int num_inputs = input_shapes->num_items;
|
||||
NodeDef node_def;
|
||||
node_def.set_name(tfe_op->operation->Name());
|
||||
node_def.set_op(tfe_op->operation->Name());
|
||||
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
|
||||
node_def.set_name(op->Name());
|
||||
node_def.set_op(op->Name());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
node_def.add_input("dummy_input");
|
||||
}
|
||||
OperationFromInterface(tfe_op->operation)
|
||||
->Attrs()
|
||||
.FillAttrValueMap(node_def.mutable_attr());
|
||||
OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());
|
||||
|
||||
const tensorflow::OpRegistrationData* op_reg_data;
|
||||
status->status =
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
#define TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
|
||||
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
|
||||
inline cpp_impl *unwrap(wrapper *w) { \
|
||||
return reinterpret_cast<cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline const cpp_impl *unwrap(const wrapper *w) { \
|
||||
return reinterpret_cast<const cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); }
|
||||
|
||||
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_
|
|
@ -50,7 +50,6 @@ tf_cuda_library(
|
|||
":tfe_tensor_debug_info_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"//tensorflow/c:c_api",
|
||||
|
@ -110,13 +109,10 @@ filegroup(
|
|||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
"tfe_cancellation_manager_internal.h",
|
||||
"tfe_context_internal.h",
|
||||
"tfe_executor_internal.h",
|
||||
"tfe_monitoring_internal.h",
|
||||
"tfe_op_attrs_internal.h",
|
||||
"tfe_op_internal.h",
|
||||
"tfe_tensor_debug_info_internal.h",
|
||||
"tfe_tensorhandle_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
|
@ -205,6 +201,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":context_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -249,8 +246,6 @@ cc_library(
|
|||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tfe_context_internal",
|
||||
":tfe_op_internal",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
|
@ -265,6 +260,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -287,6 +283,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -327,6 +324,8 @@ tf_cuda_cc_test(
|
|||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
|
@ -351,6 +350,7 @@ tf_cuda_cc_test(
|
|||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -384,6 +384,9 @@ tf_cuda_library(
|
|||
"//conditions:default": [
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
":tfe_context_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
|
@ -548,8 +551,9 @@ cc_library(
|
|||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
|
|
|
@ -26,7 +26,6 @@ limitations under the License.
|
|||
// clang-format on
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
|
@ -34,6 +33,9 @@ limitations under the License.
|
|||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
#include "tensorflow/c/eager/c_api_tfrt.h"
|
||||
|
@ -298,7 +300,7 @@ tensorflow::Status CreateRemoteContexts(
|
|||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
|
@ -383,7 +385,7 @@ tensorflow::Status UpdateRemoteContexts(
|
|||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
|
@ -464,7 +466,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||
// New server created for new server_def. Unused if updating server_def.
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
|
@ -684,7 +686,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
|||
if (opts->use_tfrt) {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
status->status = tensorflow::Status::OK();
|
||||
return new TFE_Context{new tfrt::ContextInterface()};
|
||||
return tensorflow::wrap(new tfrt::ContextInterface());
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
return nullptr;
|
||||
|
@ -701,14 +703,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
|||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return new TFE_Context{new tensorflow::EagerContext(
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator())};
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
|
@ -719,14 +721,14 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
|||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
|
||||
return new TFE_Context{new tensorflow::EagerContext(
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||
/*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator())};
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
|
@ -734,22 +736,19 @@ void TFE_DeleteContext(TFE_Context* ctx) {
|
|||
return;
|
||||
}
|
||||
|
||||
// context->RefCountIsOne() should be true here.
|
||||
// TODO(iga): Remove EagerContext refcounting.
|
||||
ctx->context->Release();
|
||||
|
||||
delete ctx;
|
||||
// ctx->RefCountIsOne() should be true here.
|
||||
tensorflow::unwrap(ctx)->Release();
|
||||
}
|
||||
|
||||
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||
TF_DeviceList* l = new TF_DeviceList;
|
||||
ctx->context->ListDevices(&l->response);
|
||||
tensorflow::unwrap(ctx)->ListDevices(&l->response);
|
||||
return l;
|
||||
}
|
||||
|
||||
void TFE_ContextClearCaches(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->ClearCachesAndThreadExecutors();
|
||||
}
|
||||
|
||||
|
@ -772,7 +771,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
|||
if (server_def.has_cluster_device_filters()) {
|
||||
const auto& cdf = server_def.cluster_device_filters();
|
||||
for (const auto& jdf : cdf.jobs()) {
|
||||
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
const string remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
for (const auto& tdf : jdf.tasks()) {
|
||||
const int32_t task_index = tdf.first;
|
||||
std::vector<string> device_filters(tdf.second.device_filters_size());
|
||||
|
@ -781,7 +780,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
|||
}
|
||||
const string remote_worker = remote_prefix + std::to_string(task_index);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status =
|
||||
context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
||||
}
|
||||
|
@ -803,7 +802,7 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
|
|||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::ServerDef server_def;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
if (!server_def.ParseFromArray(proto, proto_len)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
|
@ -833,7 +832,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||
return false;
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
static_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
|
||||
|
@ -889,7 +888,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
|||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->SyncExecutors();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
@ -897,7 +896,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
|||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetThreadLocalDevicePlacementPolicy(
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||
}
|
||||
|
@ -908,7 +907,7 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
|||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||
context->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
@ -918,8 +917,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
|||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
return new TFE_TensorHandle{
|
||||
tensorflow::TensorHandle::CreateLocalHandle(tensor)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
|
||||
}
|
||||
|
||||
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
|
||||
|
@ -927,84 +925,84 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
|
|||
|
||||
tensorflow::profiler::TraceMe activity(
|
||||
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
if (h->handle) {
|
||||
h->handle->Release();
|
||||
if (h) {
|
||||
tensorflow::unwrap(h)->Release();
|
||||
}
|
||||
delete h;
|
||||
}
|
||||
|
||||
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
|
||||
return static_cast<TF_DataType>(h->handle->DataType());
|
||||
return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
|
||||
}
|
||||
|
||||
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int num_dims = -1;
|
||||
status->status = h->handle->NumDims(&num_dims);
|
||||
status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
|
||||
return num_dims;
|
||||
}
|
||||
|
||||
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int64 num_elements = -1;
|
||||
status->status = h->handle->NumElements(&num_elements);
|
||||
status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
|
||||
return num_elements;
|
||||
}
|
||||
|
||||
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int64 dim = -1;
|
||||
status->status = h->handle->Dim(dim_index, &dim);
|
||||
status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
|
||||
return dim;
|
||||
}
|
||||
|
||||
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
return h->handle->DeviceName(&status->status);
|
||||
return tensorflow::unwrap(h)->DeviceName(&status->status);
|
||||
}
|
||||
|
||||
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
return h->handle->BackingDeviceName(&status->status);
|
||||
return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
|
||||
TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return new TFE_TensorHandle{h->handle->Copy()};
|
||||
return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
|
||||
}
|
||||
|
||||
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::AbstractTensorInterface* t = h->handle->Resolve(&status->status);
|
||||
tensorflow::AbstractTensorInterface* t =
|
||||
tensorflow::unwrap(h)->Resolve(&status->status);
|
||||
if (t == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1013,12 +1011,12 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
|||
}
|
||||
|
||||
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
if (VariantDeviceIsCustom(handle->device())) {
|
||||
const tensorflow::Tensor* t;
|
||||
status->status = handle->Tensor(&t);
|
||||
|
@ -1054,7 +1052,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||
void* deallocator_arg, TF_Status* status) {
|
||||
tensorflow::Device* device = nullptr;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
tensorflow::CustomDevice* custom_device = nullptr;
|
||||
if (!status->status.ok()) {
|
||||
|
@ -1080,11 +1078,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||
tensorflow::TensorShape(dimvec), buf);
|
||||
buf->Unref();
|
||||
if (custom_device == nullptr) {
|
||||
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), device, device, context)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), device, device, context));
|
||||
} else {
|
||||
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), custom_device, context)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
|
||||
std::move(t), custom_device, context));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1093,12 +1091,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||
// bytes of the memory pointed to by the device pointer returned above.
|
||||
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return 0;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDeviceMemorySize may not be called on a ",
|
||||
|
@ -1115,12 +1113,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
|||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op> new_op(new TFE_Op{ctx->context->CreateOperation()});
|
||||
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
|
||||
tensorflow::AbstractOperationInterface* new_op =
|
||||
tensorflow::unwrap(ctx)->CreateOperation();
|
||||
status->status = new_op->Reset(op_or_function_name, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
new_op.reset();
|
||||
new_op->Release();
|
||||
new_op = nullptr;
|
||||
}
|
||||
return new_op.release();
|
||||
return tensorflow::wrap(new_op);
|
||||
}
|
||||
|
||||
void TFE_DeleteOp(TFE_Op* op) {
|
||||
|
@ -1128,24 +1128,20 @@ void TFE_DeleteOp(TFE_Op* op) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (op->operation) {
|
||||
op->operation->Release();
|
||||
}
|
||||
|
||||
delete op;
|
||||
tensorflow::unwrap(op)->Release();
|
||||
}
|
||||
|
||||
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
|
||||
status->status = op->operation->SetDeviceName(device_name);
|
||||
status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
|
||||
}
|
||||
|
||||
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
|
||||
return op->operation->DeviceName().c_str();
|
||||
return tensorflow::unwrap(op)->DeviceName().c_str();
|
||||
}
|
||||
|
||||
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Status s = op->operation->SetUseXla(enable);
|
||||
tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
|
||||
}
|
||||
|
@ -1156,18 +1152,13 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
|||
}
|
||||
|
||||
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||
status->status = op->operation->AddInput(input->handle);
|
||||
status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
|
||||
}
|
||||
|
||||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
|
||||
num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
handles[i] = inputs[i]->handle;
|
||||
}
|
||||
status->status =
|
||||
op->operation->AddInputList({handles.data(), handles.size()});
|
||||
status->status = tensorflow::unwrap(op)->AddInputList(
|
||||
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
|
||||
}
|
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
|
@ -1175,8 +1166,8 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
|||
TF_AttrType ret = TF_ATTR_INT;
|
||||
const tensorflow::AttrTypeMap* attr_types_;
|
||||
bool is_function;
|
||||
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
|
||||
&attr_types_, &is_function);
|
||||
status->status = tensorflow::AttrTypeMapForOp(
|
||||
tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
|
||||
if (!status->status.ok()) {
|
||||
return ret;
|
||||
}
|
||||
|
@ -1202,7 +1193,7 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
|
|||
|
||||
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
|
||||
size_t length) {
|
||||
auto s = op->operation->SetAttrString(
|
||||
auto s = tensorflow::unwrap(op)->SetAttrString(
|
||||
attr_name, static_cast<const char*>(value), length);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
|
@ -1210,29 +1201,30 @@ void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
|
|||
}
|
||||
|
||||
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
|
||||
auto s = op->operation->SetAttrInt(attr_name, value);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
|
||||
auto s = op->operation->SetAttrFloat(attr_name, value);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
|
||||
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
|
||||
(value == 0) ? false : true);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
|
||||
auto s = op->operation->SetAttrType(attr_name,
|
||||
static_cast<tensorflow::DataType>(value));
|
||||
auto s = tensorflow::unwrap(op)->SetAttrType(
|
||||
attr_name, static_cast<tensorflow::DataType>(value));
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
|
@ -1240,12 +1232,14 @@ void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
|
|||
|
||||
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
|
||||
const int num_dims, TF_Status* out_status) {
|
||||
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
|
||||
out_status->status =
|
||||
tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op* value) {
|
||||
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunction(
|
||||
attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
|
@ -1253,7 +1247,7 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
|
|||
|
||||
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
|
||||
const char* data, size_t length) {
|
||||
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
|
@ -1264,14 +1258,14 @@ void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
|
|||
tensorflow::Tensor t;
|
||||
status->status = TF_TensorToTensor(tensor, &t);
|
||||
tensorflow::TensorInterface interface(t);
|
||||
status->status = op->operation->SetAttrTensor(attr_name, &interface);
|
||||
status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
|
||||
const void* const* values, const size_t* lengths,
|
||||
int num_values) {
|
||||
auto s =
|
||||
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
|
||||
auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
|
||||
num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
|
@ -1279,7 +1273,8 @@ void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
|
|||
|
||||
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
|
||||
const float* values, int num_values) {
|
||||
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
|
||||
auto s =
|
||||
tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
|
@ -1287,7 +1282,8 @@ void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
|
|||
|
||||
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
|
||||
const int64_t* values, int num_values) {
|
||||
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
|
||||
auto s =
|
||||
tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
|
@ -1295,7 +1291,7 @@ void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
|
|||
|
||||
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
|
||||
const TF_DataType* values, int num_values) {
|
||||
auto s = op->operation->SetAttrTypeList(
|
||||
auto s = tensorflow::unwrap(op)->SetAttrTypeList(
|
||||
attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
|
||||
num_values);
|
||||
if (!s.ok()) {
|
||||
|
@ -1305,7 +1301,8 @@ void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
|
|||
|
||||
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
|
||||
const unsigned char* values, int num_values) {
|
||||
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
|
||||
auto s =
|
||||
tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
|
@ -1314,19 +1311,14 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
|
|||
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
|
||||
const int64_t** dims, const int* num_dims,
|
||||
int num_values, TF_Status* out_status) {
|
||||
out_status->status =
|
||||
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
|
||||
out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
|
||||
attr_name, dims, num_dims, num_values);
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op** value, int num_values) {
|
||||
absl::FixedArray<const tensorflow::AbstractOperationInterface*> values(
|
||||
num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
values[i] = value[i]->operation;
|
||||
}
|
||||
auto s = op->operation->SetAttrFunctionList(attr_name,
|
||||
{values.data(), values.size()});
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
|
||||
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
|
@ -1341,12 +1333,13 @@ void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
|
|||
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
|
||||
return;
|
||||
}
|
||||
if (op == nullptr || op->operation == nullptr) {
|
||||
if (op == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Got a null or uninitialized `op` argument");
|
||||
return;
|
||||
}
|
||||
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
|
||||
operation->MutableAttrs()->Set(attr_name, attr_value);
|
||||
}
|
||||
|
||||
|
@ -1354,7 +1347,7 @@ TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
|
|||
const char* input_name,
|
||||
TF_Status* status) {
|
||||
int ret = -1;
|
||||
status->status = op->operation->InputLength(input_name, &ret);
|
||||
status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -1362,36 +1355,29 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
|||
const char* output_name,
|
||||
TF_Status* status) {
|
||||
int ret = -1;
|
||||
status->status = op->operation->OutputLength(output_name, &ret);
|
||||
status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
absl::FixedArray<tensorflow::AbstractTensorHandleInterface*> handles(
|
||||
*num_retvals);
|
||||
status->status = op->operation->Execute(absl::MakeSpan(handles), num_retvals);
|
||||
if (!status->status.ok()) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = new TFE_TensorHandle{handles[i]};
|
||||
}
|
||||
status->status = tensorflow::unwrap(op)->Execute(
|
||||
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
TFE_Context* ctx,
|
||||
const char* device_name,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* result = ctx->context->CopyTensorHandleToDevice(h->handle, device_name,
|
||||
&status->status);
|
||||
auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice(
|
||||
tensorflow::unwrap(h), device_name, &status->status);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{result};
|
||||
return tensorflow::wrap(result);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1406,39 +1392,39 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
|
|||
return;
|
||||
}
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function_def);
|
||||
}
|
||||
|
||||
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function->fdef);
|
||||
}
|
||||
|
||||
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->RemoveFunction(name);
|
||||
}
|
||||
|
||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return context->FindFunctionDef(name) != nullptr;
|
||||
}
|
||||
|
||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
|
@ -1446,13 +1432,13 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
|||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
TF_Status* status) {
|
||||
return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(t)};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
|
||||
}
|
||||
|
||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->Executor().WaitForAllPendingNodes();
|
||||
if (!status->status.ok()) return;
|
||||
tensorflow::mutex_lock ml(*context->MetadataMu());
|
||||
|
@ -1475,20 +1461,21 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
|
|||
|
||||
void TFE_ContextStartStep(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->StartStep();
|
||||
}
|
||||
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->EndStep();
|
||||
}
|
||||
|
||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
tensorflow::AttrValueMap m;
|
||||
attrs->attributes->FillAttrValueMap(&m);
|
||||
tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(op));
|
||||
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
|
||||
for (const auto& attribute : m) {
|
||||
destination->Set(attribute.first, attribute.second);
|
||||
|
@ -1576,33 +1563,34 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
|||
const string& name() override { return name_; }
|
||||
|
||||
tensorflow::Status CopyTensorToDevice(
|
||||
tensorflow::TensorHandle* tensor,
|
||||
tensorflow::TensorHandle* handle,
|
||||
tensorflow::TensorHandle** result) override {
|
||||
tensor->Ref();
|
||||
TFE_TensorHandle tensor_handle{tensor};
|
||||
handle->Ref();
|
||||
TF_Status status;
|
||||
TFE_TensorHandle* result_handle =
|
||||
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
|
||||
tensor_handle.handle->Release();
|
||||
TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
|
||||
context_, tensorflow::wrap(handle), &status, info_);
|
||||
handle->Release();
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
|
||||
*result = tensorflow::TensorHandleFromInterface(
|
||||
tensorflow::unwrap(result_handle));
|
||||
(*result)->Ref();
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
return status.status;
|
||||
}
|
||||
|
||||
tensorflow::Status CopyTensorFromDevice(
|
||||
tensorflow::TensorHandle* tensor,
|
||||
tensorflow::TensorHandle* handle,
|
||||
const tensorflow::string& target_device_name,
|
||||
tensorflow::TensorHandle** result) override {
|
||||
TF_Status status;
|
||||
tensor->Ref();
|
||||
TFE_TensorHandle tensor_handle{tensor};
|
||||
handle->Ref();
|
||||
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
|
||||
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
|
||||
tensor_handle.handle->Release();
|
||||
context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
|
||||
info_);
|
||||
handle->Release();
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
|
||||
*result = tensorflow::TensorHandleFromInterface(
|
||||
tensorflow::unwrap(result_handle));
|
||||
(*result)->Ref();
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
return status.status;
|
||||
|
@ -1615,7 +1603,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
|||
inputs.reserve(op->Inputs().size());
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
op->Inputs()[i]->Ref();
|
||||
inputs.push_back(new TFE_TensorHandle{op->Inputs()[i]});
|
||||
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
TF_Status status;
|
||||
|
@ -1624,7 +1612,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
|||
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle);
|
||||
retvals[i] = tensorflow::TensorHandleFromInterface(
|
||||
tensorflow::unwrap(outputs[i]));
|
||||
retvals[i]->Ref();
|
||||
TFE_DeleteTensorHandle(outputs[i]);
|
||||
}
|
||||
|
@ -1652,7 +1641,7 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
|||
auto custom_device =
|
||||
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status =
|
||||
context->RegisterCustomDevice(device_name, std::move(custom_device));
|
||||
}
|
||||
|
|
|
@ -57,7 +57,8 @@ extern "C" {
|
|||
|
||||
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
TFE_TensorHandle* h, TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
|
||||
tensorflow::TensorHandle* handle =
|
||||
TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
const tensorflow::Tensor* tensor;
|
||||
status->status = handle->Tensor(&tensor);
|
||||
if (!status->status.ok()) {
|
||||
|
|
|
@ -19,6 +19,9 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
|
@ -34,9 +37,10 @@ using tensorflow::string;
|
|||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
op_to_reset->operation->Clear();
|
||||
status->status =
|
||||
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
|
||||
tensorflow::AbstractOperationInterface* op =
|
||||
tensorflow::unwrap(op_to_reset);
|
||||
op->Clear();
|
||||
status->status = op->Reset(op_or_function_name, raw_device_name);
|
||||
} else {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"op_to_reset should not be nullptr");
|
||||
|
@ -45,13 +49,13 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
|||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
|
@ -483,7 +487,7 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
|
|||
void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetThreadLocalMirroringPolicy(
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(policy));
|
||||
}
|
||||
|
@ -494,7 +498,7 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
|
|||
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
|
||||
TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
|
||||
}
|
||||
|
||||
|
@ -530,7 +534,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op,
|
|||
TFE_CancellationManager* cancellation_manager,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerOperation* operation =
|
||||
tensorflow::OperationFromInterface(op->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
|
||||
operation->SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
status->status = tensorflow::Status::OK();
|
||||
|
@ -557,19 +561,19 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
|
|||
|
||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetExecutorForThread(executor->executor());
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return new TFE_Executor(&context->Executor());
|
||||
}
|
||||
|
||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||
context->HostCPU()->parsed_name());
|
||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||
|
@ -585,7 +589,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
|||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto* function_def = context->FindFunctionDef(function_name);
|
||||
if (function_def == nullptr) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
|
@ -611,12 +615,13 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
|
|||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
if (ctx == nullptr || ctx->context == nullptr) {
|
||||
if (ctx == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid Context");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor(
|
||||
tensorflow::AbstractTensorInterface* t =
|
||||
tensorflow::unwrap(ctx)->CreateTensor(
|
||||
static_cast<tensorflow::DataType>(dtype), dimvec);
|
||||
|
||||
if (t == nullptr) {
|
||||
|
@ -630,5 +635,6 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
|
|||
|
||||
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
|
||||
TF_Status* status) {
|
||||
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)};
|
||||
return tensorflow::wrap(
|
||||
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
|
||||
}
|
||||
|
|
|
@ -19,13 +19,10 @@ limitations under the License.
|
|||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" // IWYU pragma: export
|
||||
|
||||
// TODO(b/154564140): Move this to its own header. This requires splitting
|
||||
// c_api_experimental.h
|
||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
|
@ -233,7 +234,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
|
|||
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
|
@ -245,7 +247,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
|
|||
|
||||
// TODO(gjn): Add support for waiting on async local mirrors
|
||||
if (!remote && !async) {
|
||||
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
|
|
|
@ -27,6 +27,8 @@ limitations under the License.
|
|||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
|
@ -416,8 +418,10 @@ void TensorHandleSilentCopy(bool async,
|
|||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
auto cpu_arg = tensorflow::TensorHandleFromInterface(hcpu->handle);
|
||||
auto gpu_arg = tensorflow::TensorHandleFromInterface(hgpu->handle);
|
||||
auto cpu_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
|
||||
auto gpu_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
|
||||
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
|
||||
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
|
||||
|
||||
|
@ -1346,7 +1350,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
|||
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::EagerOperation* operation =
|
||||
tensorflow::OperationFromInterface(op->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
|
||||
operation->Attrs().FillAttrValueMap(&attr_values);
|
||||
return attr_values;
|
||||
}
|
||||
|
@ -1482,10 +1486,10 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
|||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
TFE_OpAddInput(concatOp, dim, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CHECK(concatOp->operation->OpDef());
|
||||
CHECK(tensorflow::unwrap(concatOp)->OpDef());
|
||||
TFE_OpAddInput(concatOp, inputs[0], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_FALSE(concatOp->operation->OpDef())
|
||||
EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef())
|
||||
<< "Inference context is still present";
|
||||
TFE_OpAddInput(concatOp, inputs[1], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
@ -1590,7 +1594,7 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
|
|||
// There is currently no API to fetch attributes from an operation, fetching
|
||||
// happens only as an implementation detail of custom devices.
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(var_op->operation);
|
||||
OperationFromInterface(tensorflow::unwrap(var_op));
|
||||
TFE_OpAttrs attributes{&operation->Attrs()};
|
||||
|
||||
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
|
@ -1606,7 +1610,7 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
|
|||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(copy_op->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
|
||||
|
||||
|
@ -1630,7 +1634,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
|||
// There is currently no API to fetch attributes from an operation, fetching
|
||||
// happens only as an implementation detail of custom devices.
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(var_op->operation);
|
||||
OperationFromInterface(tensorflow::unwrap(var_op));
|
||||
TFE_OpAttrs attributes{&operation->Attrs()};
|
||||
|
||||
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||
|
@ -1657,7 +1661,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
|||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(var_op_2->operation);
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2));
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
|
||||
|
||||
|
|
|
@ -16,8 +16,10 @@ limitations under the License.
|
|||
#include "tensorflow/c/eager/dlpack.h"
|
||||
|
||||
#include "include/dlpack/dlpack.h" // from @dlpack
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
|
@ -41,12 +43,12 @@ struct TfDlManagedTensorCtx {
|
|||
|
||||
// Gets tensor from eager tensor handle.
|
||||
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
if (handle->Type() != TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"DLPack doesn't support ", handle->TypeString(), " tensor");
|
||||
|
@ -107,7 +109,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
|
|||
// Gets DLPack's DLContext from eager tensor handle.
|
||||
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||
DLContext ctx;
|
||||
const char* device_name = h->handle->DeviceName(&status->status);
|
||||
const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status);
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
|
||||
std::string device_type = parsed_name.type;
|
||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
|
||||
// Wraps a pointer to a context implementation.
|
||||
|
@ -23,8 +24,12 @@ limitations under the License.
|
|||
// interface cannot destruct the underlying context object. Instead, call
|
||||
// TFE_DeleteContext who calls Release() on the context pointer and deletes
|
||||
// the TFE_Context structure.
|
||||
struct TFE_Context {
|
||||
tensorflow::AbstractContextInterface* context;
|
||||
};
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
|
|
|
@ -23,14 +23,15 @@ limitations under the License.
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
|
||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||
// that sometimes do not require serialization.
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
typedef struct TFE_Op TFE_Op;
|
||||
|
||||
struct TFE_OpAttrs {
|
||||
explicit TFE_OpAttrs() : attributes(nullptr) {}
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
|
||||
// Wraps a pointer to an operation implementation.
|
||||
|
@ -23,8 +24,13 @@ limitations under the License.
|
|||
// interface cannot destruct the underlying operation object. Instead, call
|
||||
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
|
||||
// the TFE_Op structure.
|
||||
struct TFE_Op {
|
||||
tensorflow::AbstractOperationInterface* operation;
|
||||
};
|
||||
typedef struct TFE_Op TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
|
||||
// Wraps a pointer to a tensor handle implementation.
|
||||
|
@ -23,8 +24,15 @@ limitations under the License.
|
|||
// interface cannot destruct the underlying handle object. Instead, call
|
||||
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
|
||||
// the TFE_TensorHandle structure.
|
||||
struct TFE_TensorHandle {
|
||||
tensorflow::AbstractTensorHandleInterface* handle;
|
||||
};
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
|
||||
TFE_TensorHandle);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
|
||||
TFE_TensorHandle*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
|
|
|
@ -22,13 +22,6 @@ package(
|
|||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "conversion_macros",
|
||||
hdrs = [
|
||||
"conversion_macros.h",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function",
|
||||
srcs = [
|
||||
|
@ -51,6 +44,7 @@ cc_library(
|
|||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tfe_op_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
|
@ -93,7 +87,7 @@ cc_library(
|
|||
"concrete_function_type.h",
|
||||
],
|
||||
deps = [
|
||||
":conversion_macros",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
],
|
||||
)
|
||||
|
@ -123,7 +117,7 @@ cc_library(
|
|||
"function_metadata_type.h",
|
||||
],
|
||||
deps = [
|
||||
":conversion_macros",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
|
@ -24,7 +25,8 @@ limitations under the License.
|
|||
extern "C" {
|
||||
|
||||
TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(&tensorflow::unwrap(func)->GetFunctionMetadata());
|
||||
return tensorflow::wrap(const_cast<tensorflow::FunctionMetadata*>(
|
||||
&tensorflow::unwrap(func)->GetFunctionMetadata()));
|
||||
}
|
||||
|
||||
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
|
||||
|
@ -34,7 +36,7 @@ TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
|
|||
}
|
||||
|
||||
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
|
||||
return new TFE_Op{tensorflow::unwrap(func)->GetCallOp()};
|
||||
return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp());
|
||||
}
|
||||
|
||||
} // end extern "C"
|
||||
|
|
|
@ -16,8 +16,8 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
|
||||
|
||||
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
|
||||
inline cpp_impl *unwrap(wrapper *w) { \
|
||||
return reinterpret_cast<cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline wrapper *wrap(const cpp_impl *i) { \
|
||||
return reinterpret_cast<wrapper *>(const_cast<cpp_impl *>(i)); \
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONVERSION_MACROS_H_
|
|
@ -16,8 +16,8 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/conversion_macros.h"
|
||||
|
||||
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
|
||||
|
||||
|
|
|
@ -36,7 +36,8 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
|
|||
std::string saved_model_dir(dirname);
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
ctx->context->LoadSavedModelAPI(dirname, absl::nullopt, &status->status);
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
|
||||
&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -54,7 +55,7 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
|||
}
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
ctx->context->LoadSavedModelAPI(dirname, std::move(tagset),
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
|
||||
&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
|
|
|
@ -903,7 +903,8 @@ cc_library(
|
|||
":safe_ptr",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -1007,7 +1008,8 @@ cc_library(
|
|||
":safe_ptr",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//third_party/python_runtime:headers",
|
||||
|
|
|
@ -40,6 +40,9 @@ cc_library(
|
|||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:dlpack",
|
||||
"//tensorflow/c/eager:tape",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/eager:tfe_op_internal",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
|
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/hash/hash.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
|
@ -41,7 +41,7 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
|
|||
PyObject* value, tensorflow::DataType dtype,
|
||||
absl::string_view device_name) const {
|
||||
CHECK_NOTNULL(value);
|
||||
const auto& it = cache.find(Key{PyObjectPtr{value}, dtype, device_name});
|
||||
const auto it = cache.find(Key{PyObjectPtr{value}, dtype, device_name});
|
||||
if (it == cache.end()) {
|
||||
scalar_cache_misses->GetCell()->IncrementBy(1);
|
||||
return nullptr;
|
||||
|
@ -49,7 +49,7 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
|
|||
|
||||
scalar_cache_hits->GetCell()->IncrementBy(1);
|
||||
auto* h = it->second;
|
||||
return new TFE_TensorHandle{h->handle->Copy()};
|
||||
return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
|
||||
}
|
||||
|
||||
void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
|
||||
|
@ -57,7 +57,7 @@ void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
|
|||
TFE_TensorHandle* h) {
|
||||
Py_INCREF(value);
|
||||
cache.emplace(Key{PyObjectPtr{value}, dtype, device_name},
|
||||
new TFE_TensorHandle{h->handle->Copy()});
|
||||
tensorflow::wrap(tensorflow::unwrap(h)->Copy()));
|
||||
}
|
||||
|
||||
void TFE_TensorHandleCache::Clear() {
|
||||
|
|
|
@ -379,7 +379,7 @@ void TFE_Py_EnableInteractivePythonLogging();
|
|||
// Py_None.
|
||||
//
|
||||
// This function is not thread-safe.
|
||||
PyObject* TFE_Py_SetEagerContext(PyObject* python_context);
|
||||
PyObject* TFE_Py_SetEagerContext(PyObject* py_context);
|
||||
|
||||
// Returns the current eager Context object (defined in eager/context.py)
|
||||
// that was last set using TFE_Py_SetEagerContext.
|
||||
|
|
|
@ -24,6 +24,9 @@ limitations under the License.
|
|||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tape.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
@ -80,9 +83,10 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
|
|||
const char* raw_device_name, TF_Status* status) {
|
||||
auto op = ReleaseThreadLocalOp(ctx);
|
||||
if (!op) {
|
||||
op.reset(new TFE_Op{ctx->context->CreateOperation()});
|
||||
op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
|
||||
}
|
||||
status->status = op->operation->Reset(op_or_function_name, raw_device_name);
|
||||
status->status =
|
||||
tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
|
||||
if (!status->status.ok()) {
|
||||
op.reset();
|
||||
}
|
||||
|
@ -91,7 +95,7 @@ TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
|
|||
|
||||
void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
|
||||
if (op) {
|
||||
op->operation->Clear();
|
||||
tensorflow::unwrap(op)->Clear();
|
||||
thread_local_eager_operation_map[ctx].reset(op);
|
||||
}
|
||||
}
|
||||
|
@ -1500,7 +1504,7 @@ static PyTypeObject TFE_Py_Tape_Type = {
|
|||
sizeof(TFE_Py_Tape), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
&TFE_Py_Tape_Delete, /* tp_dealloc */
|
||||
0, /* tp_print */
|
||||
nullptr, /* tp_print */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_reserved */
|
||||
|
@ -1538,7 +1542,7 @@ static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
|
|||
sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
&TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
|
||||
0, /* tp_print */
|
||||
nullptr, /* tp_print */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_reserved */
|
||||
|
@ -1573,7 +1577,7 @@ static PyTypeObject TFE_Py_VariableWatcher_Type = {
|
|||
sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
&TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
|
||||
0, /* tp_print */
|
||||
nullptr, /* tp_print */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_reserved */
|
||||
|
@ -1990,21 +1994,22 @@ bool ListContainsNone(PyObject* list) {
|
|||
|
||||
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
|
||||
if (EagerTensor_CheckExact(tensor)) {
|
||||
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
|
||||
tensorflow::AbstractTensorHandleInterface* handle =
|
||||
tensorflow::unwrap(EagerTensor_Handle(tensor));
|
||||
tensorflow::int64 id = PyEagerTensor_ID(tensor);
|
||||
tensorflow::DataType dtype =
|
||||
static_cast<tensorflow::DataType>(t->handle->DataType());
|
||||
static_cast<tensorflow::DataType>(handle->DataType());
|
||||
if (dtype == tensorflow::DT_VARIANT) {
|
||||
return PyTapeTensor(id, dtype, tensor);
|
||||
}
|
||||
|
||||
tensorflow::TensorShape tensor_shape;
|
||||
int num_dims;
|
||||
tensorflow::Status status = t->handle->NumDims(&num_dims);
|
||||
tensorflow::Status status = handle->NumDims(&num_dims);
|
||||
if (status.ok()) {
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
tensorflow::int64 dim_size;
|
||||
status = t->handle->Dim(i, &dim_size);
|
||||
status = handle->Dim(i, &dim_size);
|
||||
if (!status.ok()) break;
|
||||
tensor_shape.AddDim(dim_size);
|
||||
}
|
||||
|
@ -3511,7 +3516,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
const tensorflow::OpDef* op_def = op->operation->OpDef();
|
||||
const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
|
||||
if (op_def == nullptr) return nullptr;
|
||||
|
||||
if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
|
||||
|
@ -3850,14 +3855,15 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
|
|||
bool include_tensor_ranks_only,
|
||||
EncodeResult* result) {
|
||||
if (EagerTensor_CheckExact(arg)) {
|
||||
TFE_TensorHandle* t = EagerTensor_Handle(arg);
|
||||
tensorflow::AbstractTensorHandleInterface* handle =
|
||||
tensorflow::unwrap(EagerTensor_Handle(arg));
|
||||
|
||||
absl::StrAppend(&result->str, kDType,
|
||||
static_cast<tensorflow::DataType>(t->handle->DataType()));
|
||||
static_cast<tensorflow::DataType>(handle->DataType()));
|
||||
absl::StrAppend(&result->str, kShape);
|
||||
|
||||
int num_dims;
|
||||
tensorflow::Status status = t->handle->NumDims(&num_dims);
|
||||
tensorflow::Status status = handle->NumDims(&num_dims);
|
||||
if (!status.ok()) return status;
|
||||
|
||||
if (include_tensor_ranks_only) {
|
||||
|
@ -3865,7 +3871,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
|
|||
} else {
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
tensorflow::int64 dim_size;
|
||||
status = t->handle->Dim(i, &dim_size);
|
||||
status = handle->Dim(i, &dim_size);
|
||||
if (!status.ok()) return status;
|
||||
absl::StrAppend(&result->str, dim_size, kShapeDelim);
|
||||
}
|
||||
|
|
|
@ -26,7 +26,8 @@ limitations under the License.
|
|||
|
||||
#include "numpy/arrayobject.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
|
@ -95,8 +96,8 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
|
|||
if (call->eager) {
|
||||
Tensor t = call->ins[i];
|
||||
arg = EagerTensorFromHandle(
|
||||
new TFE_TensorHandle{TensorHandle::CreateLocalHandle(
|
||||
std::move(t), ctx->CanonicalDevice(device), nullptr, ctx)});
|
||||
tensorflow::wrap(TensorHandle::CreateLocalHandle(
|
||||
std::move(t), ctx->CanonicalDevice(device), nullptr, ctx)));
|
||||
if (arg == nullptr) {
|
||||
Py_DECREF(lst);
|
||||
return errors::Internal("Unable to procure EagerTensor from Tensor.");
|
||||
|
@ -146,7 +147,7 @@ tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
|
|||
const Device* expected_device,
|
||||
const Tensor** output_tensor) {
|
||||
tensorflow::TensorHandle* handle = tensorflow::TensorHandleFromInterface(
|
||||
EagerTensor_Handle(eager_tensor)->handle);
|
||||
tensorflow::unwrap(EagerTensor_Handle(eager_tensor)));
|
||||
if (VariantDeviceIsCustom(handle->device())) {
|
||||
return errors::Unimplemented(
|
||||
"Custom devices are currently not supported with PyFuncs.");
|
||||
|
@ -196,7 +197,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
|||
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
|
||||
PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
|
||||
CHECK_NE(ctx, nullptr);
|
||||
EagerContext* context = ContextFromInterface(ctx->context);
|
||||
EagerContext* context = ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
TF_RETURN_IF_ERROR(MakeArgTuple(call, context, &args));
|
||||
new_executor.reset(new EagerExecutor(call->eager_async));
|
||||
old_executor = &context->Executor();
|
||||
|
@ -236,7 +237,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
|||
if (new_executor != nullptr) {
|
||||
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(
|
||||
PyObject_GetAttrString(trampoline, "_ctx"), nullptr));
|
||||
EagerContext* context = ContextFromInterface(ctx->context);
|
||||
EagerContext* context = ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
s.Update(new_executor->WaitForAllPendingNodes());
|
||||
context->SetExecutorForThread(old_executor);
|
||||
}
|
||||
|
|
|
@ -15,7 +15,8 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/python/lib/core/py_seq_tensor.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
@ -305,7 +306,7 @@ struct Converter {
|
|||
}
|
||||
}
|
||||
}
|
||||
*h = new TFE_TensorHandle{ctx->context->CreateLocalHandle(t)};
|
||||
*h = tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(t));
|
||||
t->Release();
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -316,12 +317,12 @@ struct Converter {
|
|||
template <>
|
||||
struct ConverterTraits<int64> {
|
||||
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, int64 value) {
|
||||
return ctx->context->CreateInt64Scalar(value);
|
||||
return tensorflow::unwrap(ctx)->CreateInt64Scalar(value);
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateTensor(DT_INT64, dim_sizes);
|
||||
return tensorflow::unwrap(ctx)->CreateTensor(DT_INT64, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, int64* out) {
|
||||
|
@ -356,12 +357,12 @@ typedef Converter<int64> Int64Converter;
|
|||
template <>
|
||||
struct ConverterTraits<uint64> {
|
||||
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, uint64 value) {
|
||||
return ctx->context->CreateUint64Scalar(value);
|
||||
return tensorflow::unwrap(ctx)->CreateUint64Scalar(value);
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateTensor(DT_UINT64, dim_sizes);
|
||||
return tensorflow::unwrap(ctx)->CreateTensor(DT_UINT64, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, uint64* out) {
|
||||
|
@ -393,12 +394,12 @@ typedef Converter<uint64> UInt64Converter;
|
|||
template <>
|
||||
struct ConverterTraits<int32> {
|
||||
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, int32 value) {
|
||||
return ctx->context->CreateInt32Scalar(value);
|
||||
return tensorflow::unwrap(ctx)->CreateInt32Scalar(value);
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateTensor(DT_INT32, dim_sizes);
|
||||
return tensorflow::unwrap(ctx)->CreateTensor(DT_INT32, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, int32* out) {
|
||||
|
@ -500,12 +501,12 @@ static const char* ConvertOneFloat(PyObject* v, T* out) {
|
|||
template <>
|
||||
struct ConverterTraits<float> {
|
||||
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, float value) {
|
||||
return ctx->context->CreateFloatScalar(value);
|
||||
return tensorflow::unwrap(ctx)->CreateFloatScalar(value);
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateTensor(DT_FLOAT, dim_sizes);
|
||||
return tensorflow::unwrap(ctx)->CreateTensor(DT_FLOAT, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, float* out) {
|
||||
|
@ -516,12 +517,12 @@ struct ConverterTraits<float> {
|
|||
template <>
|
||||
struct ConverterTraits<double> {
|
||||
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, double value) {
|
||||
return ctx->context->CreateDoubleScalar(value);
|
||||
return tensorflow::unwrap(ctx)->CreateDoubleScalar(value);
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateTensor(DT_DOUBLE, dim_sizes);
|
||||
return tensorflow::unwrap(ctx)->CreateTensor(DT_DOUBLE, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, double* out) {
|
||||
|
@ -536,12 +537,12 @@ template <>
|
|||
struct ConverterTraits<Eigen::half> {
|
||||
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx,
|
||||
Eigen::half value) {
|
||||
return ctx->context->CreateHalfScalar(value);
|
||||
return tensorflow::unwrap(ctx)->CreateHalfScalar(value);
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateTensor(DT_HALF, dim_sizes);
|
||||
return tensorflow::unwrap(ctx)->CreateTensor(DT_HALF, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, Eigen::half* out) {
|
||||
|
@ -557,12 +558,12 @@ template <>
|
|||
struct ConverterTraits<tstring> {
|
||||
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx,
|
||||
tstring value) {
|
||||
return ctx->context->CreateStringScalar(value);
|
||||
return tensorflow::unwrap(ctx)->CreateStringScalar(value);
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateTensor(DT_STRING, dim_sizes);
|
||||
return tensorflow::unwrap(ctx)->CreateTensor(DT_STRING, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, tstring* out) {
|
||||
|
@ -624,12 +625,12 @@ template <>
|
|||
struct ConverterTraits<complex128> {
|
||||
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx,
|
||||
complex128 value) {
|
||||
return ctx->context->CreateComplex128Scalar(value);
|
||||
return tensorflow::unwrap(ctx)->CreateComplex128Scalar(value);
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateTensor(DT_COMPLEX128, dim_sizes);
|
||||
return tensorflow::unwrap(ctx)->CreateTensor(DT_COMPLEX128, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, complex128* out) {
|
||||
|
@ -652,12 +653,12 @@ typedef Converter<complex128> Complex128Converter;
|
|||
template <>
|
||||
struct ConverterTraits<bool> {
|
||||
static AbstractTensorInterface* CreateScalar(TFE_Context* ctx, bool value) {
|
||||
return ctx->context->CreateBoolScalar(value);
|
||||
return tensorflow::unwrap(ctx)->CreateBoolScalar(value);
|
||||
}
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateTensor(DT_BOOL, dim_sizes);
|
||||
return tensorflow::unwrap(ctx)->CreateTensor(DT_BOOL, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, bool* out) {
|
||||
|
@ -692,7 +693,7 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
|
|||
}
|
||||
|
||||
TensorInterface t(std::move(tensor));
|
||||
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(&t)};
|
||||
return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -877,7 +878,7 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
|
|||
Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
|
||||
TensorShape(state.inferred_shape));
|
||||
TensorInterface t(std::move(tensor));
|
||||
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(&t)};
|
||||
return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t));
|
||||
}
|
||||
|
||||
default:
|
||||
|
|
|
@ -270,7 +270,6 @@ static py::object TFE_ClearScalarCache() {
|
|||
// are only assigning this to functions that return opaque types.
|
||||
|
||||
PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
py::class_<TFE_Context> TFE_Context_class(m, "TFE_Context");
|
||||
py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
|
||||
py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
|
||||
"TFE_ContextOptions");
|
||||
|
@ -760,7 +759,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||
m.def("TFE_ContextStartStep", [](py::handle& o) {
|
||||
TFE_ContextStartStep(tensorflow::InputTFE_Context(o.ptr()));
|
||||
});
|
||||
m.def("TFE_ContextEndStep", &TFE_ContextEndStep);
|
||||
m.def("TFE_ContextEndStep", [](py::handle& o) {
|
||||
TFE_ContextEndStep(tensorflow::InputTFE_Context(o.ptr()));
|
||||
});
|
||||
m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
|
||||
return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr()));
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue