From 6ef713fec6100c59d1c1f16c825a841ea5c34625 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Fri, 21 Feb 2020 12:03:11 +0000 Subject: [PATCH] fix sanity check --- tensorflow/c/eager/BUILD | 4 +- tensorflow/python/dlpack/BUILD | 10 ++- tensorflow/python/dlpack/dlpack.py | 11 ++- tensorflow/python/dlpack/dlpack_test.py | 100 ++++++++++++++---------- tensorflow/python/eager/BUILD | 3 +- tensorflow/tools/pip_package/BUILD | 1 + 6 files changed, 79 insertions(+), 50 deletions(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 224b36a170c..509a6205274 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -92,8 +92,8 @@ filegroup( srcs = [ "c_api_experimental.h", "c_api_internal.h", - "tensor_handle_interface.h", "dlpack.h", + "tensor_handle_interface.h", ], visibility = [ "//tensorflow/core:__pkg__", @@ -327,7 +327,6 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) - cc_library( name = "dlpack", srcs = ["dlpack.cc"], @@ -346,7 +345,6 @@ cc_library( ], ) - # TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime # right now, remove this public rule when no longer needed (it should be # replaced by TF Lite) diff --git a/tensorflow/python/dlpack/BUILD b/tensorflow/python/dlpack/BUILD index c5347c020e5..4e1b3c47070 100644 --- a/tensorflow/python/dlpack/BUILD +++ b/tensorflow/python/dlpack/BUILD @@ -4,9 +4,11 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "dlpack", srcs = ["dlpack.py"], - deps = [ - ], srcs_version = "PY3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/python:pywrap_tensorflow", + ], ) cuda_py_test( @@ -14,9 +16,9 @@ cuda_py_test( srcs = ["dlpack_test.py"], python_version = "PY3", deps = [ - ":dlpack", + ":dlpack", "//tensorflow/python/eager:test", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", - ] + ], ) diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py index 5e364fdc593..601dffad847 100644 --- a/tensorflow/python/dlpack/dlpack.py +++ b/tensorflow/python/dlpack/dlpack.py @@ -12,14 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""DLPack modules for Tensorflow""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function from tensorflow.python import pywrap_tfe from tensorflow.python.util.tf_export import tf_export + @tf_export("dlpack.to_dlpack") def to_dlpack(tf_tensor): - return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) + return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) + @tf_export("dlpack.from_dlpack") def from_dlpack(dlcapsule): - return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) + return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) diff --git a/tensorflow/python/dlpack/dlpack_test.py b/tensorflow/python/dlpack/dlpack_test.py index 8a4f1788446..8b47c71dc6b 100644 --- a/tensorflow/python/dlpack/dlpack_test.py +++ b/tensorflow/python/dlpack/dlpack_test.py @@ -1,3 +1,23 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for DLPack functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -29,55 +49,55 @@ testcase_shapes = [ def FormatShapeAndDtype(shape, dtype): - return "_{}[{}]".format(str(dtype), ",".join(map(str, shape))) + return "_{}[{}]".format(str(dtype), ",".join(map(str, shape))) class DLPackTest(parameterized.TestCase, test.TestCase): - @parameterized.named_parameters({ - "testcase_name": FormatShapeAndDtype(shape, dtype), - "dtype": dtype, - "shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes) - def testRoundTrip(self, dtype, shape): - np.random.seed(42) - np_array = np.random.randint(0, 10, shape) - tf_tensor = constant_op.constant(np_array, dtype=dtype) - dlcapsule = to_dlpack(tf_tensor) - del tf_tensor # should still work - tf_tensor2 = from_dlpack(dlcapsule) - self.assertAllClose(np_array, tf_tensor2) + @parameterized.named_parameters({ + "testcase_name": FormatShapeAndDtype(shape, dtype), + "dtype": dtype, + "shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes) + def testRoundTrip(self, dtype, shape): + np.random.seed(42) + np_array = np.random.randint(0, 10, shape) + tf_tensor = constant_op.constant(np_array, dtype=dtype) + dlcapsule = to_dlpack(tf_tensor) + del tf_tensor # should still work + tf_tensor2 = from_dlpack(dlcapsule) + self.assertAllClose(np_array, tf_tensor2) - def testTensorsCanBeConsumedOnceOnly(self): - np.random.seed(42) - np_array = np.random.randint(0, 10, (2, 3, 4)) - tf_tensor = constant_op.constant(np_array, dtype=np.float32) - dlcapsule = to_dlpack(tf_tensor) - del tf_tensor # should still work - tf_tensor2 = from_dlpack(dlcapsule) + def testTensorsCanBeConsumedOnceOnly(self): + np.random.seed(42) + np_array = np.random.randint(0, 10, (2, 3, 4)) + tf_tensor = constant_op.constant(np_array, dtype=np.float32) + dlcapsule = to_dlpack(tf_tensor) + del tf_tensor # should still work + tf_tensor2 = from_dlpack(dlcapsule) - def ConsumeDLPackTensor(): - from_dlpack(dlcapsule) # Should can be consumed only once - self.assertRaisesRegex(Exception, - ".*a DLPack tensor may be consumed at most once.*", - ConsumeDLPackTensor) + def ConsumeDLPackTensor(): + from_dlpack(dlcapsule) # Should can be consumed only once + self.assertRaisesRegex(Exception, + ".*a DLPack tensor may be consumed at most once.*", + ConsumeDLPackTensor) - def testUnsupportedType(self): - def case1(): - tf_tensor = constant_op.constant( - [[1, 4], [5, 2]], dtype=dtypes.qint16) - dlcapsule = to_dlpack(tf_tensor) + def testUnsupportedType(self): + def case1(): + tf_tensor = constant_op.constant( + [[1, 4], [5, 2]], dtype=dtypes.qint16) + dlcapsule = to_dlpack(tf_tensor) - def case2(): - tf_tensor = constant_op.constant( - [[1, 4], [5, 2]], dtype=dtypes.complex64) - dlcapsule = to_dlpack(tf_tensor) + def case2(): + tf_tensor = constant_op.constant( + [[1, 4], [5, 2]], dtype=dtypes.complex64) + dlcapsule = to_dlpack(tf_tensor) - self.assertRaisesRegex( - Exception, ".* is not supported by dlpack", case1) - self.assertRaisesRegex( - Exception, ".* is not supported by dlpack", case2) + self.assertRaisesRegex( + Exception, ".* is not supported by dlpack", case1) + self.assertRaisesRegex( + Exception, ".* is not supported by dlpack", case2) if __name__ == '__main__': - ops.enable_eager_execution() - test.main() + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 0a792bb2747..d6e1e07b6e2 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -35,8 +35,8 @@ cc_library( "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/eager:c_api_internal", - "//tensorflow/c/eager:tape", "//tensorflow/c/eager:dlpack", + "//tensorflow/c/eager:tape", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -94,6 +94,7 @@ py_library( ":test", ":wrap_function", "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python/dlpack", "//tensorflow/python/eager/memory_tests:memory_test_util", ], ) diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index f6e17a6e46c..0e2ba08d1a7 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -159,6 +159,7 @@ filegroup( "@com_google_protobuf//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@curl//:COPYING", + "@dlpack//:LICENSE", "@double_conversion//:LICENSE", "@eigen_archive//:COPYING.MPL2", "@enum34_archive//:LICENSE",