add python api and test
This commit is contained in:
parent
88d46f6184
commit
883dcc553a
@ -80,6 +80,9 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
|
||||
case TF_DataType::TF_UINT8:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
case TF_DataType::TF_INT8:
|
||||
dtype.code = DLDataTypeCode::kDLInt;
|
||||
break;
|
||||
case TF_DataType::TF_INT16:
|
||||
dtype.code = DLDataTypeCode::kDLInt;
|
||||
break;
|
||||
@ -119,6 +122,9 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_QUINT16 is not supported by dlpack");
|
||||
break;
|
||||
case TF_DataType::TF_UINT16:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
case TF_DataType::TF_COMPLEX128:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_COMPLEX128 is not supported by dlpack");
|
||||
|
22
tensorflow/python/dlpack/BUILD
Normal file
22
tensorflow/python/dlpack/BUILD
Normal file
@ -0,0 +1,22 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_library(
|
||||
name = "dlpack",
|
||||
srcs = ["dlpack.py"],
|
||||
deps = [
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "dlpack_test",
|
||||
srcs = ["dlpack_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":dlpack",
|
||||
"//tensorflow/python/eager:test",
|
||||
"@absl_py//absl/testing:absltest",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
]
|
||||
)
|
25
tensorflow/python/dlpack/dlpack.py
Normal file
25
tensorflow/python/dlpack/dlpack.py
Normal file
@ -0,0 +1,25 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@tf_export("dlpack.to_dlpack")
|
||||
def to_dlpack(tf_tensor):
|
||||
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
|
||||
|
||||
@tf_export("dlpack.from_dlpack")
|
||||
def from_dlpack(dlcapsule):
|
||||
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
|
68
tensorflow/python/dlpack/dlpack_test.py
Normal file
68
tensorflow/python/dlpack/dlpack_test.py
Normal file
@ -0,0 +1,68 @@
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.dlpack.dlpack import from_dlpack, to_dlpack
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
int_dtypes = [
|
||||
np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
|
||||
np.uint64
|
||||
]
|
||||
float_dtypes = [np.float16, np.float32, np.float64]
|
||||
complex_dtypes = [np.complex64, np.complex128]
|
||||
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]
|
||||
standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
|
||||
|
||||
|
||||
testcase_shapes = [
|
||||
(),
|
||||
(1,),
|
||||
(2, 3),
|
||||
(2, 0),
|
||||
(0, 7),
|
||||
(4, 1, 2)
|
||||
]
|
||||
|
||||
|
||||
def FormatShapeAndDtype(shape, dtype):
|
||||
return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))
|
||||
|
||||
|
||||
class DLPackTest(parameterized.TestCase, test.TestCase):
|
||||
|
||||
@parameterized.named_parameters({
|
||||
"testcase_name": FormatShapeAndDtype(shape, dtype),
|
||||
"dtype": dtype,
|
||||
"shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes)
|
||||
def testRoundTrip(self, dtype, shape):
|
||||
np.random.seed(42)
|
||||
np_array = np.random.randint(0, 10, shape)
|
||||
tf_tensor = constant_op.constant(np_array, dtype=dtype)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
del tf_tensor # should still work
|
||||
tf_tensor2 = from_dlpack(dlcapsule)
|
||||
self.assertAllClose(np_array, tf_tensor2)
|
||||
|
||||
def testTensorsCanBeConsumedOnceOnly(self):
|
||||
np.random.seed(42)
|
||||
np_array = np.random.randint(0, 10, (2, 3, 4))
|
||||
tf_tensor = constant_op.constant(np_array, dtype=np.float32)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
del tf_tensor # should still work
|
||||
tf_tensor2 = from_dlpack(dlcapsule)
|
||||
|
||||
def ConsumeDLPackTensor():
|
||||
from_dlpack(dlcapsule) # Should can be consumed only once
|
||||
self.assertRaisesRegex(Exception,
|
||||
".*a DLPack tensor may be consumed at most once.*",
|
||||
ConsumeDLPackTensor)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
@ -1068,6 +1068,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
|
||||
"Note that a DLPack tensor may be consumed at most once.",
|
||||
absl::string_view(pycapsule.name()));
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
}
|
||||
TFE_TensorHandle* thandle =
|
||||
tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
|
||||
|
Loading…
Reference in New Issue
Block a user