From 883dcc553a24c8e5c09c84a5255f20aa18c7f1d7 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Wed, 19 Feb 2020 18:49:03 +0000 Subject: [PATCH] add python api and test --- tensorflow/c/eager/dlpack.cc | 6 +++ tensorflow/python/dlpack/BUILD | 22 ++++++++ tensorflow/python/dlpack/dlpack.py | 25 +++++++++ tensorflow/python/dlpack/dlpack_test.py | 68 +++++++++++++++++++++++++ tensorflow/python/tfe_wrapper.cc | 1 + 5 files changed, 122 insertions(+) create mode 100644 tensorflow/python/dlpack/BUILD create mode 100644 tensorflow/python/dlpack/dlpack.py create mode 100644 tensorflow/python/dlpack/dlpack_test.py diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 40186f39947..e0624ac4ca1 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -80,6 +80,9 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { case TF_DataType::TF_UINT8: dtype.code = DLDataTypeCode::kDLUInt; break; + case TF_DataType::TF_INT8: + dtype.code = DLDataTypeCode::kDLInt; + break; case TF_DataType::TF_INT16: dtype.code = DLDataTypeCode::kDLInt; break; @@ -119,6 +122,9 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { status->status = tensorflow::errors::InvalidArgument( "TF_QUINT16 is not supported by dlpack"); break; + case TF_DataType::TF_UINT16: + dtype.code = DLDataTypeCode::kDLUInt; + break; case TF_DataType::TF_COMPLEX128: status->status = tensorflow::errors::InvalidArgument( "TF_COMPLEX128 is not supported by dlpack"); diff --git a/tensorflow/python/dlpack/BUILD b/tensorflow/python/dlpack/BUILD new file mode 100644 index 00000000000..3c890ec8b8f --- /dev/null +++ b/tensorflow/python/dlpack/BUILD @@ -0,0 +1,22 @@ +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_library( + name = "dlpack", + srcs = ["dlpack.py"], + deps = [ + ], + srcs_version = "PY3", +) + +cuda_py_test( + name = "dlpack_test", + srcs = ["dlpack_test.py"], + python_version = "PY3", + deps = [ + ":dlpack", + "//tensorflow/python/eager:test", + "@absl_py//absl/testing:absltest", + "@absl_py//absl/testing:parameterized", + ] +) \ No newline at end of file diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py new file mode 100644 index 00000000000..00be73f3670 --- /dev/null +++ b/tensorflow/python/dlpack/dlpack.py @@ -0,0 +1,25 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from tensorflow.python import pywrap_tfe +from tensorflow.python.util.tf_export import tf_export + +@tf_export("dlpack.to_dlpack") +def to_dlpack(tf_tensor): + return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) + +@tf_export("dlpack.from_dlpack") +def from_dlpack(dlcapsule): + return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) \ No newline at end of file diff --git a/tensorflow/python/dlpack/dlpack_test.py b/tensorflow/python/dlpack/dlpack_test.py new file mode 100644 index 00000000000..8384dfeadea --- /dev/null +++ b/tensorflow/python/dlpack/dlpack_test.py @@ -0,0 +1,68 @@ +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.framework import dtypes +from tensorflow.python.dlpack.dlpack import from_dlpack, to_dlpack + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np + +int_dtypes = [ + np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, + np.uint64 +] +float_dtypes = [np.float16, np.float32, np.float64] +complex_dtypes = [np.complex64, np.complex128] +dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16] +standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] + + +testcase_shapes = [ + (), + (1,), + (2, 3), + (2, 0), + (0, 7), + (4, 1, 2) +] + + +def FormatShapeAndDtype(shape, dtype): + return "_{}[{}]".format(str(dtype), ",".join(map(str, shape))) + + +class DLPackTest(parameterized.TestCase, test.TestCase): + + @parameterized.named_parameters({ + "testcase_name": FormatShapeAndDtype(shape, dtype), + "dtype": dtype, + "shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes) + def testRoundTrip(self, dtype, shape): + np.random.seed(42) + np_array = np.random.randint(0, 10, shape) + tf_tensor = constant_op.constant(np_array, dtype=dtype) + dlcapsule = to_dlpack(tf_tensor) + del tf_tensor # should still work + tf_tensor2 = from_dlpack(dlcapsule) + self.assertAllClose(np_array, tf_tensor2) + + def testTensorsCanBeConsumedOnceOnly(self): + np.random.seed(42) + np_array = np.random.randint(0, 10, (2, 3, 4)) + tf_tensor = constant_op.constant(np_array, dtype=np.float32) + dlcapsule = to_dlpack(tf_tensor) + del tf_tensor # should still work + tf_tensor2 = from_dlpack(dlcapsule) + + def ConsumeDLPackTensor(): + from_dlpack(dlcapsule) # Should can be consumed only once + self.assertRaisesRegex(Exception, + ".*a DLPack tensor may be consumed at most once.*", + ConsumeDLPackTensor) + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 870693c8190..0b801b4d51e 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1068,6 +1068,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) { "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " "Note that a DLPack tensor may be consumed at most once.", absl::string_view(pycapsule.name())); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); } TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());