Add layer of indirection for Tensor & TensorHandle

We add the TensorInterface & TensorHandleInterface classes and keep them
as the sole member of TF_Tensor and TFE_TensorHandle structs to keep
those structs simple. This allows us to keep most of the C API functions
as simple wrappers around C++ classes.

PiperOrigin-RevId: 288903948
Change-Id: I9f4d8914c447145df63c8518bcde60656f7098f9
This commit is contained in:
Gaurav Jain 2020-01-09 08:40:14 -08:00 committed by TensorFlower Gardener
parent 243685515f
commit 96f40ae009
18 changed files with 427 additions and 155 deletions

View File

@ -28,6 +28,7 @@ tf_cuda_library(
"c_api_experimental.h",
"c_api_internal.cc",
"c_api_internal.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api.h"],
copts = tf_copts() + tfe_xla_copts(),
@ -93,6 +94,7 @@ filegroup(
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"tensor_handle_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
@ -102,7 +104,10 @@ filegroup(
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api_experimental.h"],
srcs = [
"c_api_experimental.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api_internal.h"],
visibility = [
"//learning/deepmind/courier:__subpackages__",

View File

@ -630,7 +630,8 @@ tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
}
const std::string& type_attr = input_def.type_attr();
if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
op->operation.MutableAttrs()->Set(
type_attr, static_cast<tensorflow::DataType>(input->handle.DataType()));
ictx->attrs.insert(type_attr);
}
return tensorflow::Status::OK();
@ -671,13 +672,16 @@ tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
if (!input_def.type_list_attr().empty()) {
std::vector<tensorflow::DataType> dtypes(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs[i]->handle->dtype;
dtypes[i] =
static_cast<const tensorflow::DataType>(inputs[i]->handle.DataType());
}
OpInferMixedTypeInputListAttrs(op, input_def, dtypes);
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
OpInferSingleTypeInputListAttrs(op, input_def, inputs[0]->handle->dtype,
num_inputs);
OpInferSingleTypeInputListAttrs(
op, input_def,
static_cast<const tensorflow::DataType>(inputs[0]->handle.DataType()),
num_inputs);
} else {
return tensorflow::errors::InvalidArgument("Invalid input list definition");
}
@ -745,12 +749,9 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context->remote_device_mgr()) {
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
}
return list;
TF_DeviceList* l = new TF_DeviceList;
ctx->context->ListDevices(&l->response);
return l;
}
void TFE_ContextClearCaches(TFE_Context* ctx) {
@ -886,138 +887,209 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
if (h == nullptr) return;
tensorflow::profiler::TraceMe activity(
"TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
<< h->handle;
if (h->handle) {
h->handle->Unref();
}
delete h;
}
tensorflow::TensorHandleInterface::~TensorHandleInterface() {
VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
<< handle_;
if (handle_) {
handle_->Unref();
}
}
bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
if (handle_ == nullptr) {
*status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return false;
}
return true;
}
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->handle->dtype);
return h->handle.DataType();
}
TF_DataType tensorflow::TensorHandleInterface::DataType() const {
return static_cast<TF_DataType>(handle_->dtype);
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
return h->handle.NumDims(&status->status);
}
int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
if (!IsValid(status)) {
return -1;
}
int result;
status->status = h->handle->NumDims(&result);
*status = handle_->NumDims(&result);
return result;
}
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
return h->handle.NumElements(&status->status);
}
int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
if (!IsValid(status)) {
return -1;
}
tensorflow::int64 result;
status->status = h->handle->NumElements(&result);
*status = handle_->NumElements(&result);
return result;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
return h->handle.Dim(dim_index, &status->status);
}
int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
Status* status) const {
if (!IsValid(status)) {
return -1;
}
tensorflow::int64 result;
status->status = h->handle->Dim(dim_index, &result);
*status = handle_->Dim(dim_index, &result);
return result;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->op_device();
return h->handle.DeviceName(&status->status);
}
const char* tensorflow::TensorHandleInterface::DeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->device();
return h->handle.BackingDeviceName(&status->status);
}
const char* tensorflow::TensorHandleInterface::BackingDeviceName(
Status* status) const {
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle.IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
h->handle->Ref();
return h->handle.Copy();
}
return new TFE_TensorHandle(h->handle);
TFE_TensorHandle* tensorflow::TensorHandleInterface::Copy() {
handle_->Ref();
return new TFE_TensorHandle{TensorHandleInterface(handle_)};
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;
return h->handle.Resolve(&status->status);
}
TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
return nullptr;
}
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle->IsRemote()) {
if (handle_->IsRemote()) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
handle, handle->Context(), &handle->Context()->Executor(),
handle->Context()->HostCPU(), false, &h_cpu);
if (!status->status.ok()) {
*status = EagerCopyToDevice(handle_, handle_->Context(),
&handle_->Context()->Executor(),
handle_->Context()->HostCPU(), false, &h_cpu);
if (!status->ok()) {
return nullptr;
}
status->status = h_cpu->Tensor(&t);
if (!status->status.ok()) {
*status = h_cpu->Tensor(&t);
if (!status->ok()) {
h_cpu->Unref();
return nullptr;
}
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, &status->status);
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
h_cpu->Unref();
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle->device())) {
if (IsCPU(handle_->device())) {
const tensorflow::Tensor* src = nullptr;
status->status = handle->Tensor(&src);
if (!status->status.ok()) return nullptr;
*status = handle_->Tensor(&src);
if (!status->ok()) return nullptr;
tensor = *src;
} else {
tensorflow::EagerContext* ctx = handle->Context();
tensorflow::EagerContext* ctx = handle_->Context();
CHECK_NE(ctx, nullptr);
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
if (!status->status.ok()) return nullptr;
*status = handle_->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
if (!status->ok()) return nullptr;
}
return tensorflow::TF_TensorFromTensor(tensor, &status->status);
return tensorflow::TF_TensorFromTensor(tensor, status);
}
}
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle.IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;
tensorflow::TensorHandle* handle = h->handle.Handle();
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@ -1078,7 +1150,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(ret_handle);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(ret_handle)};
}
// This function will block till the operation that produces `h` has
@ -1086,12 +1158,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
if (h == nullptr || !h->handle.IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return 0;
}
tensorflow::TensorHandle* handle = h->handle;
tensorflow::TensorHandle* handle = h->handle.Handle();
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@ -1135,16 +1207,20 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
op->operation.AddInput(input->handle);
if (op->inference_ctx) {
status->status = OpInferSingleInputAttrs(op, input);
return op->AddInput(input, status);
}
void TFE_Op::AddInput(TFE_TensorHandle* input, TF_Status* status) {
operation.AddInput(input->handle.Handle());
if (inference_ctx) {
status->status = OpInferSingleInputAttrs(this, input);
}
}
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
for (int i = 0; i < num_inputs; ++i) {
op->operation.AddInput(inputs[i]->handle);
op->operation.AddInput(inputs[i]->handle.Handle());
}
if (op->inference_ctx) {
status->status = OpInferInputListAttrs(op, inputs, num_inputs);
@ -1382,14 +1458,20 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
VLOG(1) << "Calling TFE_Execute() on op " << op;
op->Execute(retvals, num_retvals, status);
}
void TFE_Op::Execute(TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
status->status =
tensorflow::EagerExecute(&operation, handle_retvals.data(), num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
retvals[i] = new TFE_TensorHandle{
tensorflow::TensorHandleInterface(handle_retvals[i])};
}
}
@ -1403,11 +1485,11 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
if (!status->status.ok()) {
return nullptr;
}
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
&ctx->context->Executor(),
device, false, &handle);
status->status = tensorflow::EagerCopyToDevice(
h->handle.Handle(), ctx->context, &ctx->context->Executor(), device,
false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)};
}
return nullptr;
}

View File

@ -28,19 +28,22 @@ using tensorflow::string;
namespace {
std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
TF_Status* status) {
std::vector<int64> TensorShapeAsVector(const tensorflow::TensorHandle& handle,
tensorflow::Status* status) {
std::vector<int64> shape;
int rank = TFE_TensorHandleNumDims(handle, status);
if (TF_GetCode(status) != TF_OK) {
int rank = -1;
*status = handle.NumDims(&rank);
if (!status->ok()) {
return shape;
}
shape.reserve(rank);
for (int i = 0; i < rank; ++i) {
shape.push_back(TFE_TensorHandleDim(handle, i, status));
if (TF_GetCode(status) != TF_OK) {
tensorflow::int64 dim;
*status = handle.Dim(i, &dim);
if (!status->ok()) {
return shape;
}
shape.push_back(dim);
}
return shape;
}
@ -51,14 +54,19 @@ extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status) {
return h->handle.TensorDebugInfo(&status->status);
}
TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
Status* status) {
const tensorflow::Tensor* tensor;
status->status = h->handle->Tensor(&tensor);
if (TF_GetCode(status) != TF_OK) {
*status = handle_->Tensor(&tensor);
if (!status->ok()) {
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = h->handle->device();
tensorflow::Device* device = handle_->device();
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =
@ -67,15 +75,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn();
xla::Shape padded_shape;
status->status = shape_fn(*tensor, &padded_shape);
if (!status->status.ok()) {
*status = shape_fn(*tensor, &padded_shape);
if (!status->ok()) {
return nullptr;
}
if (VLOG_IS_ON(3)) {
std::vector<int64> shape_to_log = TensorShapeAsVector(h, status);
if (!status->status.ok()) {
std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
// Ignore the status here as we are simply logging.
status->status = tensorflow::Status::OK();
*status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
<< absl::StrJoin(shape_to_log, ", ") << "] is "
@ -88,7 +96,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
// Currently, the only case of XlaTensor containing a tuple shape is to
// represent 64 bit ints, doubles, and complex numbers (we don't support
// 64bit complex numbers).
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"XlaTensors should only contain tuples of size 2. Shape: ",
padded_shape.DebugString());
return nullptr;
@ -100,13 +108,13 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
const xla::Shape& shape1 =
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
if (shape0.IsTuple() || shape1.IsTuple()) {
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"XlaTensors should not contain nested tuples. Shape: ",
padded_shape.DebugString());
return nullptr;
}
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
status->status = tensorflow::errors::InvalidArgument(
*status = tensorflow::errors::InvalidArgument(
"Subshapes of XlaTensors should be the same. Shape: ",
padded_shape.DebugString());
return nullptr;
@ -131,15 +139,15 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
dev_dims.push_back(padded_shape.dimensions(dim_index));
}
}
status->status = tensorflow::Status::OK();
*status = tensorflow::Status::OK();
return new TFE_TensorDebugInfo(dev_dims);
}
#endif // TENSORFLOW_EAGER_USE_XLA
// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(h, status);
if (TF_GetCode(status) != TF_OK) {
std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
if (!status->ok()) {
return nullptr;
}
return new TFE_TensorDebugInfo(dev_dims);

View File

@ -41,7 +41,7 @@ void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
}
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(h->handle);
op->operation.ConsumeInput(h->handle.Handle());
}
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -91,7 +92,6 @@ struct TFE_Context {
};
struct TFE_TensorHandle {
explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {}
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
TF_Status* s) {
tensorflow::TensorHandle* handle;
@ -99,10 +99,10 @@ struct TFE_TensorHandle {
if (!s->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)};
}
tensorflow::TensorHandle* handle;
tensorflow::TensorHandleInterface handle;
};
struct TFE_TensorDebugInfo {
@ -144,6 +144,9 @@ struct TFE_Op {
nullptr);
}
void AddInput(TFE_TensorHandle* input, TF_Status* status);
void Execute(TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status);
TFE_Context* ctx;
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;

View File

@ -0,0 +1,52 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
namespace tensorflow {
class TensorHandleInterface {
public:
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
~TensorHandleInterface();
bool IsValid(Status* status) const;
TF_DataType DataType() const;
int NumDims(Status* status) const;
int64_t NumElements(Status* status) const;
int64_t Dim(int dim_index, Status* status) const;
const char* DeviceName(Status* status) const;
const char* BackingDeviceName(Status* status) const;
TFE_TensorHandle* Copy();
TF_Tensor* Resolve(Status* status);
TFE_TensorDebugInfo* TensorDebugInfo(Status* status);
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
TensorHandle* Handle() { return handle_; }
private:
TensorHandle* handle_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_

View File

@ -103,9 +103,9 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
}
TF_Tensor* ret =
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf)};
TF_Tensor* ret = new TF_Tensor{tensorflow::TensorInterface(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf))};
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) {
@ -115,37 +115,23 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
return ret;
}
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return tensor;
}
return nullptr;
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {
return t->tensor.CanMove() ? t : nullptr;
}
void TF_DeleteTensor(TF_Tensor* t) { delete t; }
TF_DataType TF_TensorType(const TF_Tensor* t) {
return static_cast<TF_DataType>(t->tensor.dtype());
}
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->tensor.Type(); }
int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); }
int TF_NumDims(const TF_Tensor* t) { return t->tensor.NumDims(); }
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
return static_cast<int64_t>(t->tensor.dim_size(dim_index));
return t->tensor.Dim(dim_index);
}
size_t TF_TensorByteSize(const TF_Tensor* t) {
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
}
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->tensor.ByteSize(); }
void* TF_TensorData(const TF_Tensor* t) {
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
}
void* TF_TensorData(const TF_Tensor* t) { return t->tensor.Data(); }
int64_t TF_TensorElementCount(const TF_Tensor* t) {
int64_t result = 1;
@ -160,15 +146,60 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
TF_Tensor* to, const int64_t* new_dims,
int num_new_dims, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
Status cc_status(
to->tensor.BitcastFrom(from->tensor, type, new_dims, num_new_dims));
Set_TF_Status_from_Status(status, cc_status);
}
namespace tensorflow {
bool TensorInterface::CanMove() const {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor_);
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
buf->OwnsMemory()) {
return true;
}
return false;
}
TF_DataType TensorInterface::Type() const {
return static_cast<TF_DataType>(tensor_.dtype());
}
int TensorInterface::NumDims() const { return tensor_.dims(); }
int64_t TensorInterface::Dim(int dim_index) const {
return static_cast<int64_t>(tensor_.dim_size(dim_index));
}
int64_t TensorInterface::NumElements() const {
return static_cast<int64_t>(tensor_.NumElements());
}
size_t TensorInterface::ByteSize() const {
return tensorflow::TensorCApi::Buffer(tensor_)->size();
}
void* TensorInterface::Data() const {
return tensorflow::TensorCApi::Buffer(tensor_)->data();
}
Status TensorInterface::BitcastFrom(const TensorInterface& from,
TF_DataType type, const int64_t* new_dims,
int num_new_dims) {
tensorflow::TensorShape s;
for (int i = 0; i < num_new_dims; ++i) {
s.AddDim(new_dims[i]);
}
Status cc_status(to->tensor.BitcastFrom(
from->tensor, static_cast<tensorflow::DataType>(type), s));
Set_TF_Status_from_Status(status, cc_status);
return tensor_.BitcastFrom(from.tensor_,
static_cast<tensorflow::DataType>(type), s);
}
} // namespace tensorflow
// --------------------------------------------------------------------------
void StringEncode(const char* src, size_t src_len, char* dst) {
dst = tensorflow::core::EncodeVarint64(dst, src_len);
@ -332,31 +363,34 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
if (src->tensor.dtype() == DT_RESOURCE) {
if (src->tensor.dims() != 0) {
return src->tensor.ToTensor(dst);
}
Status TensorInterface::ToTensor(Tensor* dst) const {
if (tensor_.dtype() == DT_RESOURCE) {
if (tensor_.dims() != 0) {
return InvalidArgument(
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
"shape ",
src->tensor.shape().DebugString());
tensor_.shape().DebugString());
}
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape());
*dst = Tensor(tensorflow::DT_RESOURCE, tensor_.shape());
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
string(static_cast<const char*>(TF_TensorData(src)),
TF_TensorByteSize(src)))) {
string(static_cast<const char*>(Data()), ByteSize()))) {
return InvalidArgument(
"Malformed TF_RESOUCE tensor: unable to parse resource handle");
}
return Status::OK();
}
if (src->tensor.dtype() != DT_STRING) {
*dst = src->tensor;
if (tensor_.dtype() != DT_STRING) {
*dst = tensor_;
return Status::OK();
}
// TF_STRING tensors require copying since Tensor class expects a sequence of
// string objects.
const tensorflow::int64 num_elements = src->tensor.NumElements();
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
const size_t src_size = TF_TensorByteSize(src);
const tensorflow::int64 num_elements = tensor_.NumElements();
const char* input = reinterpret_cast<const char*>(Data());
const size_t src_size = ByteSize();
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
num_elements) {
return InvalidArgument(
@ -365,7 +399,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size;
*dst = Tensor(src->tensor.dtype(), src->tensor.shape());
*dst = Tensor(tensor_.dtype(), tensor_.shape());
auto dstarray = dst->flat<tstring>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset =
@ -384,8 +418,12 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
return Status::OK();
}
bool TensorInterface::CopyFrom(const Tensor& other, const TensorShape& shape) {
return tensor_.CopyFrom(other, shape);
}
bool TensorInterface::IsAligned() const { return tensor_.IsAligned(); }
} // namespace tensorflow
bool TF_TensorIsAligned(const TF_Tensor* tensor) {
return tensor->tensor.IsAligned();
}
bool TF_TensorIsAligned(const TF_Tensor* t) { return t->tensor.IsAligned(); }

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/framework/tensor_shape.h"
// Internal structures used by the C API. These are likely to change and should
@ -28,7 +29,7 @@ limitations under the License.
// passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface.
typedef struct TF_Tensor {
::tensorflow::Tensor tensor;
tensorflow::TensorInterface tensor;
} TF_Tensor;
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
@ -83,4 +84,5 @@ void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
// a different Allocator as `arg`.
void deallocate_buffer(void* data, size_t len, void* arg);
} // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_

View File

@ -637,6 +637,7 @@ tf_cuda_library(
"//tensorflow/core/framework:shared_ptr_variant.h",
"//tensorflow/core/framework:stats_aggregator.h",
"//tensorflow/core/framework:tensor.h",
"//tensorflow/core/framework:tensor_interface.h",
"//tensorflow/core/framework:tensor_shape.h",
"//tensorflow/core/framework:tensor_slice.h",
"//tensorflow/core/framework:tensor_types.h",

View File

@ -388,6 +388,14 @@ std::vector<const FunctionDef*> EagerContext::ListRegisteredFunctions() {
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
void EagerContext::ListDevices(
std::vector<tensorflow::DeviceAttributes>* devices) {
local_device_mgr()->ListDeviceAttributes(devices);
if (remote_device_mgr()) {
remote_device_mgr()->ListDeviceAttributes(devices);
}
}
void EagerContext::StartStep() {
mutex_lock ml(metadata_mu_);
num_active_steps_++;

View File

@ -251,6 +251,8 @@ class EagerContext : public core::RefCounted {
RunMetadata* RunMetadataProto() { return &run_metadata_; }
void ClearRunMetadata() EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);
void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices);
void StartStep();
void EndStep();
ScopedStepContainer* StepContainer();

View File

@ -345,6 +345,7 @@ filegroup(
"stats_aggregator.h",
"tensor.cc",
"tensor.h",
"tensor_interface.h",
"tensor_reference.h",
"tensor_shape.cc",
"tensor_shape.h",
@ -902,6 +903,7 @@ exports_files(
"resource_handle.h",
"shape_inference_testutil.h",
"tensor.h",
"tensor_interface.h",
"tensor_shape.h",
"tensor_testutil.h",
"tensor_types.h",

View File

@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
namespace tensorflow {
class TensorInterface {
public:
TensorInterface() {}
explicit TensorInterface(Tensor t) : tensor_(std::move(t)) {}
TF_DataType Type() const;
int NumDims() const;
int64_t Dim(int dim_index) const;
int64_t NumElements() const;
size_t ByteSize() const;
void* Data() const;
bool IsAligned() const;
Status ToTensor(Tensor* dst) const;
bool CopyFrom(const Tensor& other, const TensorShape& shape);
Status BitcastFrom(const TensorInterface& from, TF_DataType type,
const int64_t* new_dims, int num_new_dims);
bool CanMove() const;
private:
Tensor tensor_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_

View File

@ -90,7 +90,7 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
.c_str());
return nullptr;
}
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)};
}
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
@ -268,7 +268,7 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* value,
return nullptr;
}
CHECK_NE(handle, nullptr);
return new TFE_TensorHandle(handle);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)};
}
TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,

View File

@ -48,17 +48,15 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
}
scalar_cache_hits->GetCell()->IncrementBy(1);
auto* handle = it->second;
handle->Ref();
return new TFE_TensorHandle(handle);
auto* h = it->second;
return h->handle.Copy();
}
void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
absl::string_view device_name,
TFE_TensorHandle* handle) {
TFE_TensorHandle* h) {
Py_INCREF(value);
handle->handle->Ref();
cache.emplace(Key{PyObjectPtr{value}, dtype, device_name}, handle->handle);
cache.emplace(Key{PyObjectPtr{value}, dtype, device_name}, h->handle.Copy());
}
void TFE_TensorHandleCache::Clear() {

View File

@ -76,7 +76,7 @@ struct TFE_TensorHandleCache {
absl::string_view device_name) const;
void Insert(PyObject* value, tensorflow::DataType dtype,
absl::string_view device_name, TFE_TensorHandle* handle);
absl::string_view device_name, TFE_TensorHandle* h);
void Clear();
@ -87,13 +87,13 @@ struct TFE_TensorHandleCache {
void DecrefUnrefAll() {
for (const auto& p : cache) {
Py_DECREF(static_cast<PyObject*>(std::get<0>(p.first)));
p.second->Unref();
TFE_DeleteTensorHandle(p.second);
}
}
// Not guarded by a mutex because the code is only used while the
// GIL is held.
absl::flat_hash_map<Key, tensorflow::TensorHandle*> cache;
absl::flat_hash_map<Key, TFE_TensorHandle*> cache;
};
} // namespace tensorflow

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
@ -1903,18 +1904,28 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(t->handle.DataType());
if (dtype == tensorflow::DT_VARIANT) {
return PyTapeTensor(id, dtype, tensor);
}
tensorflow::Status status;
tensorflow::TensorShape tensor_shape;
const tensorflow::Status status = t->handle->Shape(&tensor_shape);
int num_dims = t->handle.NumDims(&status);
if (status.ok()) {
for (int i = 0; i < num_dims; ++i) {
tensorflow::int64 dim_size = t->handle.Dim(i, &status);
if (!status.ok()) break;
tensor_shape.AddDim(dim_size);
}
}
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
tensorflow::TensorShape({}));
} else {
if (t->handle->dtype == tensorflow::DT_VARIANT) {
return PyTapeTensor(id, t->handle->dtype, tensor);
} else {
return PyTapeTensor(id, t->handle->dtype, tensor_shape);
}
return PyTapeTensor(id, dtype, tensor_shape);
}
}
tensorflow::int64 id = FastTensorId(tensor);
@ -3857,16 +3868,21 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
EncodeResult* result) {
if (EagerTensor_CheckExact(arg)) {
TFE_TensorHandle* t = EagerTensor_Handle(arg);
tensorflow::TensorShape tensor_shape;
TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape));
absl::StrAppend(&result->str, kDType, t->handle->dtype);
absl::StrAppend(&result->str, kDType,
static_cast<tensorflow::DataType>(t->handle.DataType()));
absl::StrAppend(&result->str, kShape);
tensorflow::Status status;
int num_dims = t->handle.NumDims(&status);
if (!status.ok()) return status;
if (include_tensor_ranks_only) {
absl::StrAppend(&result->str, tensor_shape.dim_sizes().size());
absl::StrAppend(&result->str, num_dims);
} else {
for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
for (int i = 0; i < num_dims; ++i) {
tensorflow::int64 dim_size = t->handle.Dim(i, &status);
if (!status.ok()) return status;
absl::StrAppend(&result->str, dim_size, kShapeDelim);
}
}

View File

@ -95,7 +95,8 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
TensorHandle* handle;
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
t, ctx->CanonicalDevice(device), ctx, &handle));
arg = EagerTensorFromHandle(new TFE_TensorHandle(handle));
arg = EagerTensorFromHandle(
new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)});
if (arg == nullptr) {
Py_DECREF(lst);
return errors::Internal("Unable to procure EagerTensor from Tensor.");
@ -144,7 +145,7 @@ bool IsSingleNone(PyObject* obj) {
tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
const Device* expected_device,
const Tensor** output_tensor) {
auto handle = EagerTensor_Handle(eager_tensor)->handle;
auto handle = EagerTensor_Handle(eager_tensor)->handle.Handle();
Device* actual_device = handle->device();
TF_RETURN_IF_ERROR(handle->Tensor(output_tensor));
// actual_device may be nullptr, which implies local CPU.