add python api and test

This commit is contained in:
VoVAllen 2020-02-19 18:49:03 +00:00
parent 88d46f6184
commit 883dcc553a
5 changed files with 122 additions and 0 deletions

View File

@ -80,6 +80,9 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
case TF_DataType::TF_UINT8:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case TF_DataType::TF_INT8:
dtype.code = DLDataTypeCode::kDLInt;
break;
case TF_DataType::TF_INT16:
dtype.code = DLDataTypeCode::kDLInt;
break;
@ -119,6 +122,9 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
status->status = tensorflow::errors::InvalidArgument(
"TF_QUINT16 is not supported by dlpack");
break;
case TF_DataType::TF_UINT16:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case TF_DataType::TF_COMPLEX128:
status->status = tensorflow::errors::InvalidArgument(
"TF_COMPLEX128 is not supported by dlpack");

View File

@ -0,0 +1,22 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "dlpack",
srcs = ["dlpack.py"],
deps = [
],
srcs_version = "PY3",
)
cuda_py_test(
name = "dlpack_test",
srcs = ["dlpack_test.py"],
python_version = "PY3",
deps = [
":dlpack",
"//tensorflow/python/eager:test",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
]
)

View File

@ -0,0 +1,25 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tensorflow.python import pywrap_tfe
from tensorflow.python.util.tf_export import tf_export
@tf_export("dlpack.to_dlpack")
def to_dlpack(tf_tensor):
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
@tf_export("dlpack.from_dlpack")
def from_dlpack(dlcapsule):
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)

View File

@ -0,0 +1,68 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.framework import dtypes
from tensorflow.python.dlpack.dlpack import from_dlpack, to_dlpack
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
int_dtypes = [
np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
np.uint64
]
float_dtypes = [np.float16, np.float32, np.float64]
complex_dtypes = [np.complex64, np.complex128]
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]
standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
testcase_shapes = [
(),
(1,),
(2, 3),
(2, 0),
(0, 7),
(4, 1, 2)
]
def FormatShapeAndDtype(shape, dtype):
return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))
class DLPackTest(parameterized.TestCase, test.TestCase):
@parameterized.named_parameters({
"testcase_name": FormatShapeAndDtype(shape, dtype),
"dtype": dtype,
"shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes)
def testRoundTrip(self, dtype, shape):
np.random.seed(42)
np_array = np.random.randint(0, 10, shape)
tf_tensor = constant_op.constant(np_array, dtype=dtype)
dlcapsule = to_dlpack(tf_tensor)
del tf_tensor # should still work
tf_tensor2 = from_dlpack(dlcapsule)
self.assertAllClose(np_array, tf_tensor2)
def testTensorsCanBeConsumedOnceOnly(self):
np.random.seed(42)
np_array = np.random.randint(0, 10, (2, 3, 4))
tf_tensor = constant_op.constant(np_array, dtype=np.float32)
dlcapsule = to_dlpack(tf_tensor)
del tf_tensor # should still work
tf_tensor2 = from_dlpack(dlcapsule)
def ConsumeDLPackTensor():
from_dlpack(dlcapsule) # Should can be consumed only once
self.assertRaisesRegex(Exception,
".*a DLPack tensor may be consumed at most once.*",
ConsumeDLPackTensor)
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@ -1068,6 +1068,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
"Note that a DLPack tensor may be consumed at most once.",
absl::string_view(pycapsule.name()));
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
}
TFE_TensorHandle* thandle =
tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());