address comment
This commit is contained in:
parent
89c73caf12
commit
88d46f6184
@ -1,3 +1,18 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/c/eager/dlpack.h"
|
#include "tensorflow/c/eager/dlpack.h"
|
||||||
#include "include/dlpack/dlpack.h" // TF:dlpack
|
#include "include/dlpack/dlpack.h" // TF:dlpack
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
@ -10,18 +25,15 @@
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
using tensorflow::Tensor;
|
|
||||||
using tensorflow::TensorHandleInterface;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct TFDLMTensor {
|
struct TFDLManagedTensorCtx {
|
||||||
TensorReference* handle;
|
TensorReference* handle;
|
||||||
DLManagedTensor tensor;
|
DLManagedTensor tensor;
|
||||||
};
|
};
|
||||||
|
|
||||||
TensorHandle* GetTensorHandleFromTFEHandle(TFE_TensorHandle* h,
|
|
||||||
TF_Status* status) {
|
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
"The passed in handle is a nullptr");
|
"The passed in handle is a nullptr");
|
||||||
@ -37,25 +49,6 @@ TensorHandle* GetTensorHandleFromTFEHandle(TFE_TensorHandle* h,
|
|||||||
"handle.");
|
"handle.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return handle;
|
|
||||||
}
|
|
||||||
|
|
||||||
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
|
||||||
TensorHandle* handle = GetTensorHandleFromTFEHandle(h, status);
|
|
||||||
|
|
||||||
if (handle->IsRemote()) {
|
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
|
||||||
"TFE_TensorHandleDevicePointer may not be called on a remote tensor "
|
|
||||||
"handle.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
|
|
||||||
if (device != nullptr) {
|
|
||||||
status->status = device->Sync();
|
|
||||||
if (!status->status.ok()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const tensorflow::Tensor* tensor;
|
const tensorflow::Tensor* tensor;
|
||||||
status->status = handle->Tensor(&tensor);
|
status->status = handle->Tensor(&tensor);
|
||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
@ -64,13 +57,13 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
|||||||
return tensor;
|
return tensor;
|
||||||
};
|
};
|
||||||
|
|
||||||
void deleter(DLManagedTensor* arg) {
|
void DLManagedTensorDeleter(DLManagedTensor* arg) {
|
||||||
TFDLMTensor* owner = static_cast<TFDLMTensor*>(arg->manager_ctx);
|
TFDLManagedTensorCtx* owner = static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
|
||||||
owner->handle->Unref();
|
owner->handle->Unref();
|
||||||
delete owner;
|
delete owner;
|
||||||
}
|
}
|
||||||
|
|
||||||
DLDataType getDLDataType(TF_DataType data_type, TF_Status* status) {
|
DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
|
||||||
DLDataType dtype;
|
DLDataType dtype;
|
||||||
dtype.lanes = 1;
|
dtype.lanes = 1;
|
||||||
dtype.bits = TF_DataTypeSize(data_type) * 8;
|
dtype.bits = TF_DataTypeSize(data_type) * 8;
|
||||||
@ -155,7 +148,7 @@ DLDataType getDLDataType(TF_DataType data_type, TF_Status* status) {
|
|||||||
return dtype;
|
return dtype;
|
||||||
}
|
}
|
||||||
|
|
||||||
DLContext getDLContext(TFE_TensorHandle* h, TF_Status* status) {
|
DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
DLContext ctx;
|
DLContext ctx;
|
||||||
const char* device_name = h->handle->DeviceName(&status->status);
|
const char* device_name = h->handle->DeviceName(&status->status);
|
||||||
DeviceNameUtils::ParsedName parsed_name;
|
DeviceNameUtils::ParsedName parsed_name;
|
||||||
@ -180,21 +173,21 @@ DLContext getDLContext(TFE_TensorHandle* h, TF_Status* status) {
|
|||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
DLManagedTensor* TFEHandleToTFDLMTensor(TFE_TensorHandle* h,
|
DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
const Tensor* tensor = GetTensorFromHandle(h, status);
|
const Tensor* tensor = GetTensorFromHandle(h, status);
|
||||||
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
||||||
TFDLMTensor* tfDLMTensor(new TFDLMTensor);
|
TFDLManagedTensorCtx* tfDLMTensor(new TFDLManagedTensorCtx);
|
||||||
TensorReference* tensor_ref =
|
TensorReference* tensor_ref =
|
||||||
new TensorReference(*tensor); // This will call buf_->Ref()
|
new TensorReference(*tensor); // This will call buf_->Ref()
|
||||||
tfDLMTensor->handle = tensor_ref;
|
tfDLMTensor->handle = tensor_ref;
|
||||||
tfDLMTensor->tensor.manager_ctx = tfDLMTensor;
|
tfDLMTensor->tensor.manager_ctx = tfDLMTensor;
|
||||||
tfDLMTensor->tensor.deleter = &deleter;
|
tfDLMTensor->tensor.deleter = &DLManagedTensorDeleter;
|
||||||
tfDLMTensor->tensor.dl_tensor.ctx = getDLContext(h, status);
|
tfDLMTensor->tensor.dl_tensor.ctx = GetDLContext(h, status);
|
||||||
int ndim = tensor->dims();
|
int ndim = tensor->dims();
|
||||||
tfDLMTensor->tensor.dl_tensor.ndim = ndim;
|
tfDLMTensor->tensor.dl_tensor.ndim = ndim;
|
||||||
tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
||||||
tfDLMTensor->tensor.dl_tensor.dtype = getDLDataType(data_type, status);
|
tfDLMTensor->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
|
||||||
|
|
||||||
int64_t* shape_arr = new int64_t[ndim];
|
int64_t* shape_arr = new int64_t[ndim];
|
||||||
for (int i = 0; i < ndim; i++) {
|
for (int i = 0; i < ndim; i++) {
|
||||||
@ -313,7 +306,7 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
DLManagedTensor* tfdlmtensor = TFEHandleToTFDLMTensor(h, status);
|
DLManagedTensor* tfdlmtensor = TFEHandleToTFDLManagedTensorCtx(h, status);
|
||||||
return static_cast<void*>(tfdlmtensor);
|
return static_cast<void*>(tfdlmtensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,13 +1,25 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_C_DLPACK_H_
|
#ifndef TENSORFLOW_C_DLPACK_H_
|
||||||
#define TENSORFLOW_C_DLPACK_H_
|
#define TENSORFLOW_C_DLPACK_H_
|
||||||
|
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
const char* const kDlTensorCapsuleName = "dltensor";
|
const char* const kDlTensorCapsuleName = "dltensor";
|
||||||
@ -19,8 +31,5 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status);
|
|||||||
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
} /* end extern "C" */
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_DLPACK_H_
|
#endif // TENSORFLOW_C_DLPACK_H_
|
||||||
|
@ -1069,8 +1069,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
"Note that a DLPack tensor may be consumed at most once.",
|
"Note that a DLPack tensor may be consumed at most once.",
|
||||||
absl::string_view(pycapsule.name()));
|
absl::string_view(pycapsule.name()));
|
||||||
}
|
}
|
||||||
TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
|
TFE_TensorHandle* thandle =
|
||||||
static_cast<void*>(pycapsule), status.get());
|
tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
|
||||||
|
|
||||||
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
|
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
|
||||||
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
|
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
|
||||||
|
Loading…
Reference in New Issue
Block a user