From 88d46f618496e294020b926d158dca59257dbb02 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Wed, 19 Feb 2020 11:59:16 +0000 Subject: [PATCH] address comment --- tensorflow/c/eager/dlpack.cc | 63 ++++++++++++++------------------ tensorflow/c/eager/dlpack.h | 23 ++++++++---- tensorflow/python/tfe_wrapper.cc | 6 +-- 3 files changed, 47 insertions(+), 45 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 6a927a09286..40186f39947 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -1,3 +1,18 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "tensorflow/c/eager/dlpack.h" #include "include/dlpack/dlpack.h" // TF:dlpack #include "tensorflow/c/eager/c_api_internal.h" @@ -10,18 +25,15 @@ namespace tensorflow { -using tensorflow::Tensor; -using tensorflow::TensorHandleInterface; - namespace { -struct TFDLMTensor { +struct TFDLManagedTensorCtx { TensorReference* handle; DLManagedTensor tensor; }; -TensorHandle* GetTensorHandleFromTFEHandle(TFE_TensorHandle* h, - TF_Status* status) { + +const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || !h->handle->IsValid(&status->status)) { status->status = tensorflow::errors::InvalidArgument( "The passed in handle is a nullptr"); @@ -37,25 +49,6 @@ TensorHandle* GetTensorHandleFromTFEHandle(TFE_TensorHandle* h, "handle."); return nullptr; } - return handle; -} - -const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { - TensorHandle* handle = GetTensorHandleFromTFEHandle(h, status); - - if (handle->IsRemote()) { - status->status = tensorflow::errors::InvalidArgument( - "TFE_TensorHandleDevicePointer may not be called on a remote tensor " - "handle."); - return nullptr; - } - tensorflow::Device* device(absl::get(handle->device())); - if (device != nullptr) { - status->status = device->Sync(); - if (!status->status.ok()) { - return nullptr; - } - } const tensorflow::Tensor* tensor; status->status = handle->Tensor(&tensor); if (!status->status.ok()) { @@ -64,13 +57,13 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { return tensor; }; -void deleter(DLManagedTensor* arg) { - TFDLMTensor* owner = static_cast(arg->manager_ctx); +void DLManagedTensorDeleter(DLManagedTensor* arg) { + TFDLManagedTensorCtx* owner = static_cast(arg->manager_ctx); owner->handle->Unref(); delete owner; } -DLDataType getDLDataType(TF_DataType data_type, TF_Status* status) { +DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { DLDataType dtype; dtype.lanes = 1; dtype.bits = TF_DataTypeSize(data_type) * 8; @@ -155,7 +148,7 @@ DLDataType getDLDataType(TF_DataType data_type, TF_Status* status) { return dtype; } -DLContext getDLContext(TFE_TensorHandle* h, TF_Status* status) { +DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { DLContext ctx; const char* device_name = h->handle->DeviceName(&status->status); DeviceNameUtils::ParsedName parsed_name; @@ -180,21 +173,21 @@ DLContext getDLContext(TFE_TensorHandle* h, TF_Status* status) { return ctx; } -DLManagedTensor* TFEHandleToTFDLMTensor(TFE_TensorHandle* h, +DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, TF_Status* status) { const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast(tensor->dtype()); - TFDLMTensor* tfDLMTensor(new TFDLMTensor); + TFDLManagedTensorCtx* tfDLMTensor(new TFDLManagedTensorCtx); TensorReference* tensor_ref = new TensorReference(*tensor); // This will call buf_->Ref() tfDLMTensor->handle = tensor_ref; tfDLMTensor->tensor.manager_ctx = tfDLMTensor; - tfDLMTensor->tensor.deleter = &deleter; - tfDLMTensor->tensor.dl_tensor.ctx = getDLContext(h, status); + tfDLMTensor->tensor.deleter = &DLManagedTensorDeleter; + tfDLMTensor->tensor.dl_tensor.ctx = GetDLContext(h, status); int ndim = tensor->dims(); tfDLMTensor->tensor.dl_tensor.ndim = ndim; tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); - tfDLMTensor->tensor.dl_tensor.dtype = getDLDataType(data_type, status); + tfDLMTensor->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); int64_t* shape_arr = new int64_t[ndim]; for (int i = 0; i < ndim; i++) { @@ -313,7 +306,7 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { } void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { - DLManagedTensor* tfdlmtensor = TFEHandleToTFDLMTensor(h, status); + DLManagedTensor* tfdlmtensor = TFEHandleToTFDLManagedTensorCtx(h, status); return static_cast(tfdlmtensor); } diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index 7993a6ef8e0..43c205e30c4 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -1,13 +1,25 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + #ifndef TENSORFLOW_C_DLPACK_H_ #define TENSORFLOW_C_DLPACK_H_ #include "tensorflow/c/eager/c_api.h" #include "tensorflow/core/framework/tensor.h" -#ifdef __cplusplus -extern "C" { -#endif - namespace tensorflow { const char* const kDlTensorCapsuleName = "dltensor"; @@ -19,8 +31,5 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status); void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); } // namespace tensorflow -#ifdef __cplusplus -} /* end extern "C" */ -#endif #endif // TENSORFLOW_C_DLPACK_H_ diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index a7545a02904..870693c8190 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1069,9 +1069,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) { "Note that a DLPack tensor may be consumed at most once.", absl::string_view(pycapsule.name())); } - TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack( - static_cast(pycapsule), status.get()); - + TFE_TensorHandle* thandle = + tensorflow::TFE_HandleFromDLPack(pycapsule, status.get()); + PyCapsule_SetName(pycapsule.ptr(), "used_dltensor"); PyCapsule_SetDestructor(pycapsule.ptr(), nullptr); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());