Fix multiple vulnerabilities in tf.experimental.dlpack.to_dlpack
.
We have a use after free caused by memory coruption, a segmentation fault caused by memory corruption, several memory leaks and an undefined behavior when taking the reference of a nullptr. PiperOrigin-RevId: 332568894 Change-Id: Ife0fc05e103b35325094ae5d822ee5fdea764572
This commit is contained in:
parent
156872df9b
commit
d8c69c287f
tensorflow
@ -248,21 +248,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
|
||||
}
|
||||
|
||||
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
||||
auto tf_dlm_context = GetDlContext(h, status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const Tensor* tensor = GetTensorFromHandle(h, status);
|
||||
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
||||
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
|
||||
|
||||
auto tf_dlm_type = GetDlDataType(data_type, status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
|
||||
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
|
||||
tf_dlm_tensor_ctx->reference = tensor_ref;
|
||||
|
||||
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
|
||||
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
|
||||
dlm_tensor->deleter = &DLManagedTensorDeleter;
|
||||
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
|
||||
dlm_tensor->dl_tensor.ctx = tf_dlm_context;
|
||||
int ndim = tensor->dims();
|
||||
dlm_tensor->dl_tensor.ndim = ndim;
|
||||
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
||||
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
|
||||
dlm_tensor->dl_tensor.data = tf_dlm_data;
|
||||
dlm_tensor->dl_tensor.dtype = tf_dlm_type;
|
||||
|
||||
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
||||
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
||||
@ -275,13 +290,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
||||
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
|
||||
}
|
||||
|
||||
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
|
||||
dlm_tensor->dl_tensor.shape = shape_arr->data();
|
||||
// There are two ways to represent compact row-major data
|
||||
// 1) nullptr indicates tensor is compact and row-majored.
|
||||
// 2) fill in the strides array as the real case for compact row-major data.
|
||||
// Here we choose option 2, since some frameworks didn't handle the strides
|
||||
// argument properly.
|
||||
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
|
||||
dlm_tensor->dl_tensor.strides = stride_arr->data();
|
||||
|
||||
dlm_tensor->dl_tensor.byte_offset =
|
||||
0; // TF doesn't handle the strides and byte_offsets here
|
||||
return static_cast<void*>(dlm_tensor);
|
||||
|
@ -20,9 +20,11 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
|
||||
from tensorflow.python.dlpack import dlpack
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -95,6 +97,12 @@ class DLPackTest(parameterized.TestCase, test.TestCase):
|
||||
self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
|
||||
UnsupportedComplex64)
|
||||
|
||||
def testMustPassTensorArgumentToDLPack(self):
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"The argument to `to_dlpack` must be a TF tensor, not Python object"):
|
||||
dlpack.to_dlpack([1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
|
@ -1129,9 +1129,16 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
// DLPack functions
|
||||
m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
|
||||
PyObject* eager_tensor_pyobject_ptr = o.ptr();
|
||||
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
|
||||
if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The argument to `to_dlpack` must be a TF tensor, not Python object");
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
}
|
||||
|
||||
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
|
||||
void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user