fix sanity check

This commit is contained in:
VoVAllen 2020-02-21 12:03:11 +00:00
parent cae654ea99
commit 6ef713fec6
6 changed files with 79 additions and 50 deletions

View File

@ -92,8 +92,8 @@ filegroup(
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"tensor_handle_interface.h",
"dlpack.h",
"tensor_handle_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
@ -327,7 +327,6 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
cc_library(
name = "dlpack",
srcs = ["dlpack.cc"],
@ -346,7 +345,6 @@ cc_library(
],
)
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
# right now, remove this public rule when no longer needed (it should be
# replaced by TF Lite)

View File

@ -4,9 +4,11 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "dlpack",
srcs = ["dlpack.py"],
deps = [
],
srcs_version = "PY3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
],
)
cuda_py_test(
@ -14,9 +16,9 @@ cuda_py_test(
srcs = ["dlpack_test.py"],
python_version = "PY3",
deps = [
":dlpack",
":dlpack",
"//tensorflow/python/eager:test",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
]
],
)

View File

@ -12,14 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DLPack modules for Tensorflow"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tfe
from tensorflow.python.util.tf_export import tf_export
@tf_export("dlpack.to_dlpack")
def to_dlpack(tf_tensor):
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
@tf_export("dlpack.from_dlpack")
def from_dlpack(dlcapsule):
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)

View File

@ -1,3 +1,23 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DLPack functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
@ -29,55 +49,55 @@ testcase_shapes = [
def FormatShapeAndDtype(shape, dtype):
return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))
return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))
class DLPackTest(parameterized.TestCase, test.TestCase):
@parameterized.named_parameters({
"testcase_name": FormatShapeAndDtype(shape, dtype),
"dtype": dtype,
"shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes)
def testRoundTrip(self, dtype, shape):
np.random.seed(42)
np_array = np.random.randint(0, 10, shape)
tf_tensor = constant_op.constant(np_array, dtype=dtype)
dlcapsule = to_dlpack(tf_tensor)
del tf_tensor # should still work
tf_tensor2 = from_dlpack(dlcapsule)
self.assertAllClose(np_array, tf_tensor2)
@parameterized.named_parameters({
"testcase_name": FormatShapeAndDtype(shape, dtype),
"dtype": dtype,
"shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes)
def testRoundTrip(self, dtype, shape):
np.random.seed(42)
np_array = np.random.randint(0, 10, shape)
tf_tensor = constant_op.constant(np_array, dtype=dtype)
dlcapsule = to_dlpack(tf_tensor)
del tf_tensor # should still work
tf_tensor2 = from_dlpack(dlcapsule)
self.assertAllClose(np_array, tf_tensor2)
def testTensorsCanBeConsumedOnceOnly(self):
np.random.seed(42)
np_array = np.random.randint(0, 10, (2, 3, 4))
tf_tensor = constant_op.constant(np_array, dtype=np.float32)
dlcapsule = to_dlpack(tf_tensor)
del tf_tensor # should still work
tf_tensor2 = from_dlpack(dlcapsule)
def testTensorsCanBeConsumedOnceOnly(self):
np.random.seed(42)
np_array = np.random.randint(0, 10, (2, 3, 4))
tf_tensor = constant_op.constant(np_array, dtype=np.float32)
dlcapsule = to_dlpack(tf_tensor)
del tf_tensor # should still work
tf_tensor2 = from_dlpack(dlcapsule)
def ConsumeDLPackTensor():
from_dlpack(dlcapsule) # Should can be consumed only once
self.assertRaisesRegex(Exception,
".*a DLPack tensor may be consumed at most once.*",
ConsumeDLPackTensor)
def ConsumeDLPackTensor():
from_dlpack(dlcapsule) # Should can be consumed only once
self.assertRaisesRegex(Exception,
".*a DLPack tensor may be consumed at most once.*",
ConsumeDLPackTensor)
def testUnsupportedType(self):
def case1():
tf_tensor = constant_op.constant(
[[1, 4], [5, 2]], dtype=dtypes.qint16)
dlcapsule = to_dlpack(tf_tensor)
def testUnsupportedType(self):
def case1():
tf_tensor = constant_op.constant(
[[1, 4], [5, 2]], dtype=dtypes.qint16)
dlcapsule = to_dlpack(tf_tensor)
def case2():
tf_tensor = constant_op.constant(
[[1, 4], [5, 2]], dtype=dtypes.complex64)
dlcapsule = to_dlpack(tf_tensor)
def case2():
tf_tensor = constant_op.constant(
[[1, 4], [5, 2]], dtype=dtypes.complex64)
dlcapsule = to_dlpack(tf_tensor)
self.assertRaisesRegex(
Exception, ".* is not supported by dlpack", case1)
self.assertRaisesRegex(
Exception, ".* is not supported by dlpack", case2)
self.assertRaisesRegex(
Exception, ".* is not supported by dlpack", case1)
self.assertRaisesRegex(
Exception, ".* is not supported by dlpack", case2)
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()
ops.enable_eager_execution()
test.main()

View File

@ -35,8 +35,8 @@ cc_library(
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tape",
"//tensorflow/c/eager:dlpack",
"//tensorflow/c/eager:tape",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -94,6 +94,7 @@ py_library(
":test",
":wrap_function",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python/dlpack",
"//tensorflow/python/eager/memory_tests:memory_test_util",
],
)

View File

@ -159,6 +159,7 @@ filegroup(
"@com_google_protobuf//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
"@curl//:COPYING",
"@dlpack//:LICENSE",
"@double_conversion//:LICENSE",
"@eigen_archive//:COPYING.MPL2",
"@enum34_archive//:LICENSE",