Add layer of indirection for Tensor & TensorHandle
We add the TensorInterface & TensorHandleInterface classes and keep them as the sole member of the TF_Tensor and TFE_TensorHandle structs to keep those structs simple. This allows us to keep most of the C API functions as simple wrappers around C++ classes.

PiperOrigin-RevId: 288903948
Change-Id: I9f4d8914c447145df63c8518bcde60656f7098f9
Parent: 243685515f
Commit: 96f40ae009
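The pattern, reduced to its essentials: the C structs own a single interface object, and each C API entry point forwards to it. Below is a minimal, self-contained sketch of that shape; "CoreHandle" is a stand-in type invented for illustration (the real code wraps tensorflow::TensorHandle and also threads a TF_Status through most calls).

// Sketch only: simplified types, not the actual TensorFlow classes.
#include <cstdint>

namespace tensorflow {

struct CoreHandle {            // stand-in for the internal TensorHandle
  int num_dims = 0;
};

class TensorHandleInterface {  // mirrors the header added in this change
 public:
  explicit TensorHandleInterface(CoreHandle* h) : handle_(h) {}
  int NumDims() const { return handle_->num_dims; }
  CoreHandle* Handle() { return handle_; }  // escape hatch for internal code

 private:
  CoreHandle* handle_;
};

}  // namespace tensorflow

// The C struct keeps the interface as its sole member...
struct TFE_TensorHandle {
  tensorflow::TensorHandleInterface handle;
};

// ...so a C API function body becomes a one-line forwarder, e.g.
// constructed elsewhere as: new TFE_TensorHandle{TensorHandleInterface(raw)}.
int TFE_TensorHandleNumDims(TFE_TensorHandle* h) {
  return h->handle.NumDims();
}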
@@ -28,6 +28,7 @@ tf_cuda_library(
"c_api_experimental.h",
"c_api_internal.cc",
"c_api_internal.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api.h"],
copts = tf_copts() + tfe_xla_copts(),
@@ -93,6 +94,7 @@ filegroup(
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"tensor_handle_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
@@ -102,7 +104,10 @@ filegroup(

tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api_experimental.h"],
srcs = [
"c_api_experimental.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api_internal.h"],
visibility = [
"//learning/deepmind/courier:__subpackages__",
@@ -630,7 +630,8 @@ tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
}
const std::string& type_attr = input_def.type_attr();
if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
op->operation.MutableAttrs()->Set(
type_attr, static_cast<tensorflow::DataType>(input->handle.DataType()));
ictx->attrs.insert(type_attr);
}
return tensorflow::Status::OK();
@@ -671,13 +672,16 @@ tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
if (!input_def.type_list_attr().empty()) {
std::vector<tensorflow::DataType> dtypes(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs[i]->handle->dtype;
dtypes[i] =
static_cast<const tensorflow::DataType>(inputs[i]->handle.DataType());
}
OpInferMixedTypeInputListAttrs(op, input_def, dtypes);
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
OpInferSingleTypeInputListAttrs(op, input_def, inputs[0]->handle->dtype,
num_inputs);
OpInferSingleTypeInputListAttrs(
op, input_def,
static_cast<const tensorflow::DataType>(inputs[0]->handle.DataType()),
num_inputs);
} else {
return tensorflow::errors::InvalidArgument("Invalid input list definition");
}
@@ -745,12 +749,9 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }

TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context->remote_device_mgr()) {
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
}
return list;
TF_DeviceList* l = new TF_DeviceList;
ctx->context->ListDevices(&l->response);
return l;
}

void TFE_ContextClearCaches(TFE_Context* ctx) {
@@ -886,138 +887,209 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
if (h == nullptr) return;
tensorflow::profiler::TraceMe activity(
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
<< h->handle;
if (h->handle) {
h->handle->Unref();
}
delete h;
}

tensorflow::TensorHandleInterface::~TensorHandleInterface() {
VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
<< handle_;
if (handle_) {
handle_->Unref();
}
}

bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
if (handle_ == nullptr) {
*status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return false;
}

return true;
}

TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->handle->dtype);
return h->handle.DataType();
}

TF_DataType tensorflow::TensorHandleInterface::DataType() const {
return static_cast<TF_DataType>(handle_->dtype);
}

int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}

return h->handle.NumDims(&status->status);
}

int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
if (!IsValid(status)) {
return -1;
}

int result;
status->status = h->handle->NumDims(&result);
*status = handle_->NumDims(&result);
return result;
}

int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}

return h->handle.NumElements(&status->status);
}

int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
if (!IsValid(status)) {
return -1;
}

tensorflow::int64 result;
status->status = h->handle->NumElements(&result);
*status = handle_->NumElements(&result);
return result;
}

int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}

return h->handle.Dim(dim_index, &status->status);
}

int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
Status* status) const {
if (!IsValid(status)) {
return -1;
}

tensorflow::int64 result;
status->status = h->handle->Dim(dim_index, &result);
*status = handle_->Dim(dim_index, &result);
return result;
}

const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->op_device();
return h->handle.DeviceName(&status->status);
}

const char* tensorflow::TensorHandleInterface::DeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}

const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->device();
return h->handle.BackingDeviceName(&status->status);
}

const char* tensorflow::TensorHandleInterface::BackingDeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}

TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle.IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}

h->handle->Ref();
return h->handle.Copy();
}

return new TFE_TensorHandle(h->handle);
TFE_TensorHandle* tensorflow::TensorHandleInterface::Copy() {
handle_->Ref();
return new TFE_TensorHandle{TensorHandleInterface(handle_)};
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;

return h->handle.Resolve(&status->status);
}

TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
return nullptr;
}

// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle->IsRemote()) {
if (handle_->IsRemote()) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
handle, handle->Context(), &handle->Context()->Executor(),
handle->Context()->HostCPU(), false, &h_cpu);
if (!status->status.ok()) {
*status = EagerCopyToDevice(handle_, handle_->Context(),
&handle_->Context()->Executor(),
handle_->Context()->HostCPU(), false, &h_cpu);
if (!status->ok()) {
return nullptr;
}
status->status = h_cpu->Tensor(&t);
if (!status->status.ok()) {
*status = h_cpu->Tensor(&t);
if (!status->ok()) {
h_cpu->Unref();
return nullptr;
}
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, &status->status);
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
h_cpu->Unref();
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle->device())) {
if (IsCPU(handle_->device())) {
const tensorflow::Tensor* src = nullptr;
status->status = handle->Tensor(&src);
if (!status->status.ok()) return nullptr;
*status = handle_->Tensor(&src);
if (!status->ok()) return nullptr;
tensor = *src;
} else {
tensorflow::EagerContext* ctx = handle->Context();
tensorflow::EagerContext* ctx = handle_->Context();
CHECK_NE(ctx, nullptr);
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
if (!status->status.ok()) return nullptr;
*status = handle_->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
if (!status->ok()) return nullptr;
}
return tensorflow::TF_TensorFromTensor(tensor, &status->status);
return tensorflow::TF_TensorFromTensor(tensor, status);
}
}

void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle.IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;
tensorflow::TensorHandle* handle = h->handle.Handle();

if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@@ -1078,7 +1150,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(ret_handle);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(ret_handle)};
}

// This function will block till the operation that produces `h` has
@@ -1086,12 +1158,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle.IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return 0;
}
tensorflow::TensorHandle* handle = h->handle;
tensorflow::TensorHandle* handle = h->handle.Handle();

if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@@ -1135,16 +1207,20 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
op->operation.AddInput(input->handle);
if (op->inference_ctx) {
status->status = OpInferSingleInputAttrs(op, input);
return op->AddInput(input, status);
}

void TFE_Op::AddInput(TFE_TensorHandle* input, TF_Status* status) {
operation.AddInput(input->handle.Handle());
if (inference_ctx) {
status->status = OpInferSingleInputAttrs(this, input);
}
}

void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
for (int i = 0; i < num_inputs; ++i) {
op->operation.AddInput(inputs[i]->handle);
op->operation.AddInput(inputs[i]->handle.Handle());
}
if (op->inference_ctx) {
status->status = OpInferInputListAttrs(op, inputs, num_inputs);
@@ -1382,14 +1458,20 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
VLOG(1) << "Calling TFE_Execute() on op " << op;
op->Execute(retvals, num_retvals, status);
}

void TFE_Op::Execute(TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
status->status =
tensorflow::EagerExecute(&operation, handle_retvals.data(), num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
retvals[i] = new TFE_TensorHandle{
tensorflow::TensorHandleInterface(handle_retvals[i])};
}
}

@@ -1403,11 +1485,11 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
if (!status->status.ok()) {
return nullptr;
}
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
&ctx->context->Executor(),
device, false, &handle);
status->status = tensorflow::EagerCopyToDevice(
h->handle.Handle(), ctx->context, &ctx->context->Executor(), device,
false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)};
}
return nullptr;
}
@@ -28,19 +28,22 @@ using tensorflow::string;

namespace {

std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
TF_Status* status) {
std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle,
tensorflow::Status* status) {
std::vector<int64> shape;
int rank = TFE_TensorHandleNumDims(handle, status);
if (TF_GetCode(status) != TF_OK) {
int rank = -1;
*status = handle.NumDims(&rank);
if (!status->ok()) {
return shape;
}
shape.reserve(rank);
for (int i = 0; i < rank; ++i) {
shape.push_back(TFE_TensorHandleDim(handle, i, status));
if (TF_GetCode(status) != TF_OK) {
tensorflow::int64 dim;
*status = handle.Dim(i, &dim);
if (!status->ok()) {
return shape;
}
shape.push_back(dim);
}
return shape;
}
@@ -51,14 +54,19 @@ extern "C" {

TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status) {
return h->handle.TensorDebugInfo(&status->status);
}

TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
Status* status) {
const tensorflow::Tensor* tensor;
status->status = h->handle->Tensor(&tensor);
if (TF_GetCode(status) != TF_OK) {
*status = handle_->Tensor(&tensor);
if (!status->ok()) {
return nullptr;
}

#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = h->handle->device();
tensorflow::Device* device = handle_->device();

// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =
@@ -67,15 +75,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn();
xla::Shape padded_shape;
status->status = shape_fn(*tensor, &padded_shape);
if (!status->status.ok()) {
*status = shape_fn(*tensor, &padded_shape);
if (!status->ok()) {
return nullptr;
}
if (VLOG_IS_ON(3)) {
std::vector<int64> shape_to_log = TensorShapeAsVector(h, status);
if (!status->status.ok()) {
std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
// Ignore the status here as we are simply logging.
status->status = tensorflow::Status::OK();
*status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
<< absl::StrJoin(shape_to_log, ", ") << "] is "
@@ -88,7 +96,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
// Currently, the only case of XlaTensor containing a tuple shape is to
// represent 64 bit ints, doubles, and complex numbers (we don't support
// 64bit complex numbers).
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"XlaTensors should only contain tuples of size 2. Shape: ",
padded_shape.DebugString());
return nullptr;
@@ -100,13 +108,13 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
const xla::Shape& shape1 =
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
if (shape0.IsTuple() || shape1.IsTuple()) {
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"XlaTensors should not contain nested tuples. Shape: ",
padded_shape.DebugString());
return nullptr;
}
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"Subshapes of XlaTensors should be the same. Shape: ",
padded_shape.DebugString());
return nullptr;
@@ -131,15 +139,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
dev_dims.push_back(padded_shape.dimensions(dim_index));
}
}
status->status = tensorflow::Status::OK();
*status = tensorflow::Status::OK();
return new TFE_TensorDebugInfo(dev_dims);
}
#endif  // TENSORFLOW_EAGER_USE_XLA

// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(h, status);
if (TF_GetCode(status) != TF_OK) {
std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
return nullptr;
}
return new TFE_TensorDebugInfo(dev_dims);
@@ -41,7 +41,7 @@ void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
}

void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(h->handle);
op->operation.ConsumeInput(h->handle.Handle());
}

TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@@ -91,7 +92,6 @@ struct TFE_Context {
};

struct TFE_TensorHandle {
explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
TF_Status* s) {
tensorflow::TensorHandle* handle;
@@ -99,10 +99,10 @@ struct TFE_TensorHandle {
if (!s->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)};
}

tensorflow::TensorHandle* handle;
tensorflow::TensorHandleInterface handle;
};

struct TFE_TensorDebugInfo {
@@ -144,6 +144,9 @@ struct TFE_Op {
nullptr);
}

void AddInput(TFE_TensorHandle* input, TF_Status* status);
void Execute(TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status);

TFE_Context* ctx;
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
tensorflow/c/eager/tensor_handle_interface.h (new file, 52 lines)
@@ -0,0 +1,52 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"

namespace tensorflow {

class TensorHandleInterface {
public:
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
~TensorHandleInterface();

bool IsValid(Status* status) const;
TF_DataType DataType() const;
int NumDims(Status* status) const;
int64_t NumElements(Status* status) const;
int64_t Dim(int dim_index, Status* status) const;

const char* DeviceName(Status* status) const;
const char* BackingDeviceName(Status* status) const;
TFE_TensorHandle* Copy();
TF_Tensor* Resolve(Status* status);
TFE_TensorDebugInfo* TensorDebugInfo(Status* status);

// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
TensorHandle* Handle() { return handle_; }

private:
TensorHandle* handle_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
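For callers of the C API nothing changes: the existing entry points keep their signatures and now simply forward through this class, as the c_api.cc hunks above show. A hedged usage sketch of that unchanged surface (assuming `h` is a valid TFE_TensorHandle* obtained elsewhere, e.g. from TFE_Execute; the helper name InspectHandle is made up for illustration):

#include "tensorflow/c/eager/c_api.h"

// Query and resolve a handle through the public C API; internally each call
// now routes through TensorHandleInterface.
void InspectHandle(TFE_TensorHandle* h) {
  TF_Status* status = TF_NewStatus();

  int ndims = TFE_TensorHandleNumDims(h, status);      // -> handle.NumDims()
  if (TF_GetCode(status) == TF_OK) {
    for (int i = 0; i < ndims; ++i) {
      int64_t dim = TFE_TensorHandleDim(h, i, status);  // -> handle.Dim()
      (void)dim;
    }
  }

  TF_Tensor* t = TFE_TensorHandleResolve(h, status);    // -> handle.Resolve()
  if (t != nullptr) TF_DeleteTensor(t);

  TF_DeleteStatus(status);
}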
@@ -103,9 +103,9 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
}

TF_Tensor* ret =
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf)};
TF_Tensor* ret = new TF_Tensor{tensorflow::TensorInterface(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf))};
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) {
@@ -115,37 +115,23 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
return ret;
}

TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return tensor;
}
return nullptr;
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
return t->tensor.CanMove() ? t : nullptr;
}

void TF_DeleteTensor(TF_Tensor* t) { delete t; }

TF_DataType TF_TensorType(const TF_Tensor* t) {
return static_cast<TF_DataType>(t->tensor.dtype());
}
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->tensor.Type(); }

int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); }
int TF_NumDims(const TF_Tensor* t) { return t->tensor.NumDims(); }

int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
return static_cast<int64_t>(t->tensor.dim_size(dim_index));
return t->tensor.Dim(dim_index);
}

size_t TF_TensorByteSize(const TF_Tensor* t) {
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
}
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->tensor.ByteSize(); }

void* TF_TensorData(const TF_Tensor* t) {
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
}
void* TF_TensorData(const TF_Tensor* t) { return t->tensor.Data(); }

int64_t TF_TensorElementCount(const TF_Tensor* t) {
int64_t result = 1;
@@ -160,15 +146,60 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
TF_Tensor* to, const int64_t* new_dims,
int num_new_dims, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
Status cc_status(
to->tensor.BitcastFrom(from->tensor, type, new_dims, num_new_dims));
Set_TF_Status_from_Status(status, cc_status);
}

namespace tensorflow {

bool TensorInterface::CanMove() const {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor_);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return true;
}
return false;
}

TF_DataType TensorInterface::Type() const {
return static_cast<TF_DataType>(tensor_.dtype());
}

int TensorInterface::NumDims() const { return tensor_.dims(); }

int64_t TensorInterface::Dim(int dim_index) const {
return static_cast<int64_t>(tensor_.dim_size(dim_index));
}

int64_t TensorInterface::NumElements() const {
return static_cast<int64_t>(tensor_.NumElements());
}

size_t TensorInterface::ByteSize() const {
return tensorflow::TensorCApi::Buffer(tensor_)->size();
}

void* TensorInterface::Data() const {
return tensorflow::TensorCApi::Buffer(tensor_)->data();
}

Status TensorInterface::BitcastFrom(const TensorInterface& from,
TF_DataType type, const int64_t* new_dims,
int num_new_dims) {
tensorflow::TensorShape s;
for (int i = 0; i < num_new_dims; ++i) {
s.AddDim(new_dims[i]);
}
Status cc_status(to->tensor.BitcastFrom(
from->tensor, static_cast<tensorflow::DataType>(type), s));
Set_TF_Status_from_Status(status, cc_status);
return tensor_.BitcastFrom(from.tensor_,
static_cast<tensorflow::DataType>(type), s);
}

}  // namespace tensorflow

// --------------------------------------------------------------------------
void StringEncode(const char* src, size_t src_len, char* dst) {
dst = tensorflow::core::EncodeVarint64(dst, src_len);
@@ -332,31 +363,34 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
}

Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
if (src->tensor.dtype() == DT_RESOURCE) {
if (src->tensor.dims() != 0) {
return src->tensor.ToTensor(dst);
}

Status TensorInterface::ToTensor(Tensor* dst) const {
if (tensor_.dtype() == DT_RESOURCE) {
if (tensor_.dims() != 0) {
return InvalidArgument(
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
"shape ",
src->tensor.shape().DebugString());
tensor_.shape().DebugString());
}
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape());
*dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
string(static_cast<const char*>(TF_TensorData(src)),
TF_TensorByteSize(src)))) {
string(static_cast<const char*>(Data()), ByteSize()))) {
return InvalidArgument(
"Malformed TF_RESOUCE tensor: unable to parse resource handle");
}
return Status::OK();
}
if (src->tensor.dtype() != DT_STRING) {
*dst = src->tensor;
if (tensor_.dtype() != DT_STRING) {
*dst = tensor_;
return Status::OK();
}
// TF_STRING tensors require copying since Tensor class expects a sequence of
// string objects.
const tensorflow::int64 num_elements = src->tensor.NumElements();
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
const size_t src_size = TF_TensorByteSize(src);
const tensorflow::int64 num_elements = tensor_.NumElements();
const char* input = reinterpret_cast<const char*>(Data());
const size_t src_size = ByteSize();
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
num_elements) {
return InvalidArgument(
@@ -365,7 +399,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size;

*dst = Tensor(src->tensor.dtype(), src->tensor.shape());
*dst = Tensor(tensor_.dtype(), tensor_.shape());
auto dstarray = dst->flat<tstring>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset =
@@ -384,8 +418,12 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
return Status::OK();
}

bool TensorInterface::CopyFrom(const Tensor& other, const TensorShape& shape) {
return tensor_.CopyFrom(other, shape);
}

bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }

}  // namespace tensorflow

bool TF_TensorIsAligned(const TF_Tensor* tensor) {
return tensor->tensor.IsAligned();
}
bool TF_TensorIsAligned(const TF_Tensor* t) { return t->tensor.IsAligned(); }
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/framework/tensor_shape.h"

// Internal structures used by the C API. These are likely to change and should
@@ -28,7 +29,7 @@ limitations under the License.
// passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface.
typedef struct TF_Tensor {
::tensorflow::Tensor tensor;
tensorflow::TensorInterface tensor;
} TF_Tensor;

class TF_ManagedBuffer : public tensorflow::TensorBuffer {
@@ -83,4 +84,5 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
// a different Allocator as `arg`.
void deallocate_buffer(void* data, size_t len, void* arg);
}  // namespace tensorflow

#endif  // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
@@ -637,6 +637,7 @@ tf_cuda_library(
"//tensorflow/core/framework:shared_ptr_variant.h",
"//tensorflow/core/framework:stats_aggregator.h",
"//tensorflow/core/framework:tensor.h",
"//tensorflow/core/framework:tensor_interface.h",
"//tensorflow/core/framework:tensor_shape.h",
"//tensorflow/core/framework:tensor_slice.h",
"//tensorflow/core/framework:tensor_types.h",
@@ -388,6 +388,14 @@ std::vector<const FunctionDef*> EagerContext::ListRegisteredFunctions() {

void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }

void EagerContext::ListDevices(
std::vector<tensorflow::DeviceAttributes>* devices) {
local_device_mgr()->ListDeviceAttributes(devices);
if (remote_device_mgr()) {
remote_device_mgr()->ListDeviceAttributes(devices);
}
}

void EagerContext::StartStep() {
mutex_lock ml(metadata_mu_);
num_active_steps_++;
@@ -251,6 +251,8 @@ class EagerContext : public core::RefCounted {
RunMetadata* RunMetadataProto() { return &run_metadata_; }
void ClearRunMetadata() EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);

void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices);

void StartStep();
void EndStep();
ScopedStepContainer* StepContainer();
@@ -345,6 +345,7 @@ filegroup(
"stats_aggregator.h",
"tensor.cc",
"tensor.h",
"tensor_interface.h",
"tensor_reference.h",
"tensor_shape.cc",
"tensor_shape.h",
@@ -902,6 +903,7 @@ exports_files(
"resource_handle.h",
"shape_inference_testutil.h",
"tensor.h",
"tensor_interface.h",
"tensor_shape.h",
"tensor_testutil.h",
"tensor_types.h",
tensorflow/core/framework/tensor_interface.h (new file, 54 lines)
@@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_

#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor.h"

// Internal structures used by the C API. These are likely to change and should
// not be depended on.

namespace tensorflow {

class TensorInterface {
public:
TensorInterface() {}
explicit TensorInterface(Tensor t) : tensor_(std::move(t)) {}

TF_DataType Type() const;
int NumDims() const;
int64_t Dim(int dim_index) const;
int64_t NumElements() const;
size_t ByteSize() const;
void* Data() const;
bool IsAligned() const;

Status ToTensor(Tensor* dst) const;
bool CopyFrom(const Tensor& other, const TensorShape& shape);
Status BitcastFrom(const TensorInterface& from, TF_DataType type,
const int64_t* new_dims, int num_new_dims);

bool CanMove() const;

private:
Tensor tensor_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
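As on the handle side, the TF_Tensor accessors keep their C signatures and become one-line forwarders to this class (see the tf_tensor.cc hunks above). A small usage sketch of the unchanged client-facing surface; DescribeTensor is a hypothetical helper, and the shape/dtype here are arbitrary:

#include <cstdint>
#include "tensorflow/c/c_api.h"

// Allocate a 2x3 float tensor and read back its metadata; each accessor now
// forwards to TensorInterface (Type/NumDims/Dim/ByteSize/Data) internally.
void DescribeTensor() {
  int64_t dims[2] = {2, 3};
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, 2 * 3 * sizeof(float));

  TF_DataType dtype = TF_TensorType(t);    // -> tensor.Type()
  int ndims = TF_NumDims(t);               // -> tensor.NumDims()
  int64_t rows = TF_Dim(t, 0);             // -> tensor.Dim(0)
  size_t bytes = TF_TensorByteSize(t);     // -> tensor.ByteSize()
  float* data = static_cast<float*>(TF_TensorData(t));  // -> tensor.Data()

  (void)dtype; (void)ndims; (void)rows; (void)bytes; (void)data;
  TF_DeleteTensor(t);
}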
@@ -90,7 +90,7 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
.c_str());
return nullptr;
}
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)};
}

// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
@@ -268,7 +268,7 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* value,
return nullptr;
}
CHECK_NE(handle, nullptr);
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)};
}

TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
@@ -48,17 +48,15 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
}

scalar_cache_hits->GetCell()->IncrementBy(1);
auto* handle = it->second;
handle->Ref();
return new TFE_TensorHandle(handle);
auto* h = it->second;
return h->handle.Copy();
}

void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
absl::string_view device_name,
TFE_TensorHandle* handle) {
TFE_TensorHandle* h) {
Py_INCREF(value);
handle->handle->Ref();
cache.emplace(Key{PyObjectPtr{value}, dtype, device_name}, handle->handle);
cache.emplace(Key{PyObjectPtr{value}, dtype, device_name}, h->handle.Copy());
}

void TFE_TensorHandleCache::Clear() {
@@ -76,7 +76,7 @@ struct TFE_TensorHandleCache {
absl::string_view device_name) const;

void Insert(PyObject* value, tensorflow::DataType dtype,
absl::string_view device_name, TFE_TensorHandle* handle);
absl::string_view device_name, TFE_TensorHandle* h);

void Clear();

@@ -87,13 +87,13 @@ struct TFE_TensorHandleCache {
void DecrefUnrefAll() {
for (const auto& p : cache) {
Py_DECREF(static_cast<PyObject*>(std::get<0>(p.first)));
p.second->Unref();
TFE_DeleteTensorHandle(p.second);
}
}

// Not guarded by a mutex because the code is only used while the
// GIL is held.
absl::flat_hash_map<Key, tensorflow::TensorHandle*> cache;
absl::flat_hash_map<Key, TFE_TensorHandle*> cache;
};

}  // namespace tensorflow
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
@@ -1903,18 +1904,28 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(t->handle.DataType());
if (dtype == tensorflow::DT_VARIANT) {
return PyTapeTensor(id, dtype, tensor);
}

tensorflow::Status status;
tensorflow::TensorShape tensor_shape;
const tensorflow::Status status = t->handle->Shape(&tensor_shape);
int num_dims = t->handle.NumDims(&status);
if (status.ok()) {
for (int i = 0; i < num_dims; ++i) {
tensorflow::int64 dim_size = t->handle.Dim(i, &status);
if (!status.ok()) break;
tensor_shape.AddDim(dim_size);
}
}

if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
tensorflow::TensorShape({}));
} else {
if (t->handle->dtype == tensorflow::DT_VARIANT) {
return PyTapeTensor(id, t->handle->dtype, tensor);
} else {
return PyTapeTensor(id, t->handle->dtype, tensor_shape);
}
return PyTapeTensor(id, dtype, tensor_shape);
}
}
tensorflow::int64 id = FastTensorId(tensor);
@@ -3857,16 +3868,21 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
EncodeResult* result) {
if (EagerTensor_CheckExact(arg)) {
TFE_TensorHandle* t = EagerTensor_Handle(arg);
tensorflow::TensorShape tensor_shape;
TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape));

absl::StrAppend(&result->str, kDType, t->handle->dtype);

absl::StrAppend(&result->str, kDType,
static_cast<tensorflow::DataType>(t->handle.DataType()));
absl::StrAppend(&result->str, kShape);

tensorflow::Status status;
int num_dims = t->handle.NumDims(&status);
if (!status.ok()) return status;

if (include_tensor_ranks_only) {
absl::StrAppend(&result->str, tensor_shape.dim_sizes().size());
absl::StrAppend(&result->str, num_dims);
} else {
for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
for (int i = 0; i < num_dims; ++i) {
tensorflow::int64 dim_size = t->handle.Dim(i, &status);
if (!status.ok()) return status;
absl::StrAppend(&result->str, dim_size, kShapeDelim);
}
}
@@ -95,7 +95,8 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
TensorHandle* handle;
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
t, ctx->CanonicalDevice(device), ctx, &handle));
arg = EagerTensorFromHandle(new TFE_TensorHandle(handle));
arg = EagerTensorFromHandle(
new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)});
if (arg == nullptr) {
Py_DECREF(lst);
return errors::Internal("Unable to procure EagerTensor from Tensor.");
@@ -144,7 +145,7 @@ bool IsSingleNone(PyObject* obj) {
tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
const Device* expected_device,
const Tensor** output_tensor) {
auto handle = EagerTensor_Handle(eager_tensor)->handle;
auto handle = EagerTensor_Handle(eager_tensor)->handle.Handle();
Device* actual_device = handle->device();
TF_RETURN_IF_ERROR(handle->Tensor(output_tensor));
// actual_device may be nullptr, which implies local CPU.