address comment

This commit is contained in:
VoVAllen 2020-02-19 11:59:16 +00:00
parent 89c73caf12
commit 88d46f6184
3 changed files with 47 additions and 45 deletions

View File

@ -1,3 +1,18 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/dlpack.h" #include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // TF:dlpack #include "include/dlpack/dlpack.h" // TF:dlpack
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
@ -10,18 +25,15 @@
namespace tensorflow { namespace tensorflow {
using tensorflow::Tensor;
using tensorflow::TensorHandleInterface;
namespace { namespace {
struct TFDLMTensor { struct TFDLManagedTensorCtx {
TensorReference* handle; TensorReference* handle;
DLManagedTensor tensor; DLManagedTensor tensor;
}; };
TensorHandle* GetTensorHandleFromTFEHandle(TFE_TensorHandle* h,
TF_Status* status) { const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) { if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr"); "The passed in handle is a nullptr");
@ -37,25 +49,6 @@ TensorHandle* GetTensorHandleFromTFEHandle(TFE_TensorHandle* h,
"handle."); "handle.");
return nullptr; return nullptr;
} }
return handle;
}
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
TensorHandle* handle = GetTensorHandleFromTFEHandle(h, status);
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDevicePointer may not be called on a remote tensor "
"handle.");
return nullptr;
}
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
if (device != nullptr) {
status->status = device->Sync();
if (!status->status.ok()) {
return nullptr;
}
}
const tensorflow::Tensor* tensor; const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor); status->status = handle->Tensor(&tensor);
if (!status->status.ok()) { if (!status->status.ok()) {
@ -64,13 +57,13 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
return tensor; return tensor;
}; };
void deleter(DLManagedTensor* arg) { void DLManagedTensorDeleter(DLManagedTensor* arg) {
TFDLMTensor* owner = static_cast<TFDLMTensor*>(arg->manager_ctx); TFDLManagedTensorCtx* owner = static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
owner->handle->Unref(); owner->handle->Unref();
delete owner; delete owner;
} }
DLDataType getDLDataType(TF_DataType data_type, TF_Status* status) { DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
DLDataType dtype; DLDataType dtype;
dtype.lanes = 1; dtype.lanes = 1;
dtype.bits = TF_DataTypeSize(data_type) * 8; dtype.bits = TF_DataTypeSize(data_type) * 8;
@ -155,7 +148,7 @@ DLDataType getDLDataType(TF_DataType data_type, TF_Status* status) {
return dtype; return dtype;
} }
DLContext getDLContext(TFE_TensorHandle* h, TF_Status* status) { DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx; DLContext ctx;
const char* device_name = h->handle->DeviceName(&status->status); const char* device_name = h->handle->DeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name; DeviceNameUtils::ParsedName parsed_name;
@ -180,21 +173,21 @@ DLContext getDLContext(TFE_TensorHandle* h, TF_Status* status) {
return ctx; return ctx;
} }
DLManagedTensor* TFEHandleToTFDLMTensor(TFE_TensorHandle* h, DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h,
TF_Status* status) { TF_Status* status) {
const Tensor* tensor = GetTensorFromHandle(h, status); const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype()); TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TFDLMTensor* tfDLMTensor(new TFDLMTensor); TFDLManagedTensorCtx* tfDLMTensor(new TFDLManagedTensorCtx);
TensorReference* tensor_ref = TensorReference* tensor_ref =
new TensorReference(*tensor); // This will call buf_->Ref() new TensorReference(*tensor); // This will call buf_->Ref()
tfDLMTensor->handle = tensor_ref; tfDLMTensor->handle = tensor_ref;
tfDLMTensor->tensor.manager_ctx = tfDLMTensor; tfDLMTensor->tensor.manager_ctx = tfDLMTensor;
tfDLMTensor->tensor.deleter = &deleter; tfDLMTensor->tensor.deleter = &DLManagedTensorDeleter;
tfDLMTensor->tensor.dl_tensor.ctx = getDLContext(h, status); tfDLMTensor->tensor.dl_tensor.ctx = GetDLContext(h, status);
int ndim = tensor->dims(); int ndim = tensor->dims();
tfDLMTensor->tensor.dl_tensor.ndim = ndim; tfDLMTensor->tensor.dl_tensor.ndim = ndim;
tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
tfDLMTensor->tensor.dl_tensor.dtype = getDLDataType(data_type, status); tfDLMTensor->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
int64_t* shape_arr = new int64_t[ndim]; int64_t* shape_arr = new int64_t[ndim];
for (int i = 0; i < ndim; i++) { for (int i = 0; i < ndim; i++) {
@ -313,7 +306,7 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
} }
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
DLManagedTensor* tfdlmtensor = TFEHandleToTFDLMTensor(h, status); DLManagedTensor* tfdlmtensor = TFEHandleToTFDLManagedTensorCtx(h, status);
return static_cast<void*>(tfdlmtensor); return static_cast<void*>(tfdlmtensor);
} }

View File

@ -1,13 +1,25 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_DLPACK_H_ #ifndef TENSORFLOW_C_DLPACK_H_
#define TENSORFLOW_C_DLPACK_H_ #define TENSORFLOW_C_DLPACK_H_
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#ifdef __cplusplus
extern "C" {
#endif
namespace tensorflow { namespace tensorflow {
const char* const kDlTensorCapsuleName = "dltensor"; const char* const kDlTensorCapsuleName = "dltensor";
@ -19,8 +31,5 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status);
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
} // namespace tensorflow } // namespace tensorflow
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_DLPACK_H_ #endif // TENSORFLOW_C_DLPACK_H_

View File

@ -1069,9 +1069,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
"Note that a DLPack tensor may be consumed at most once.", "Note that a DLPack tensor may be consumed at most once.",
absl::string_view(pycapsule.name())); absl::string_view(pycapsule.name()));
} }
TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack( TFE_TensorHandle* thandle =
static_cast<void*>(pycapsule), status.get()); tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor"); PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr); PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());