fix sanity check
This commit is contained in:
parent
cae654ea99
commit
6ef713fec6
@ -92,8 +92,8 @@ filegroup(
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"tensor_handle_interface.h",
|
||||
"dlpack.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
@ -327,7 +327,6 @@ filegroup(
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
|
||||
cc_library(
|
||||
name = "dlpack",
|
||||
srcs = ["dlpack.cc"],
|
||||
@ -346,7 +345,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
|
||||
# right now, remove this public rule when no longer needed (it should be
|
||||
# replaced by TF Lite)
|
||||
|
@ -4,9 +4,11 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
py_library(
|
||||
name = "dlpack",
|
||||
srcs = ["dlpack.py"],
|
||||
deps = [
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
@ -14,9 +16,9 @@ cuda_py_test(
|
||||
srcs = ["dlpack_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":dlpack",
|
||||
":dlpack",
|
||||
"//tensorflow/python/eager:test",
|
||||
"@absl_py//absl/testing:absltest",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
]
|
||||
],
|
||||
)
|
||||
|
@ -12,14 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""DLPack modules for Tensorflow"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("dlpack.to_dlpack")
|
||||
def to_dlpack(tf_tensor):
|
||||
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
|
||||
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
|
||||
|
||||
|
||||
@tf_export("dlpack.from_dlpack")
|
||||
def from_dlpack(dlcapsule):
|
||||
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
|
||||
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
|
||||
|
@ -1,3 +1,23 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for DLPack functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -29,55 +49,55 @@ testcase_shapes = [
|
||||
|
||||
|
||||
def FormatShapeAndDtype(shape, dtype):
|
||||
return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))
|
||||
return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))
|
||||
|
||||
|
||||
class DLPackTest(parameterized.TestCase, test.TestCase):
|
||||
|
||||
@parameterized.named_parameters({
|
||||
"testcase_name": FormatShapeAndDtype(shape, dtype),
|
||||
"dtype": dtype,
|
||||
"shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes)
|
||||
def testRoundTrip(self, dtype, shape):
|
||||
np.random.seed(42)
|
||||
np_array = np.random.randint(0, 10, shape)
|
||||
tf_tensor = constant_op.constant(np_array, dtype=dtype)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
del tf_tensor # should still work
|
||||
tf_tensor2 = from_dlpack(dlcapsule)
|
||||
self.assertAllClose(np_array, tf_tensor2)
|
||||
@parameterized.named_parameters({
|
||||
"testcase_name": FormatShapeAndDtype(shape, dtype),
|
||||
"dtype": dtype,
|
||||
"shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes)
|
||||
def testRoundTrip(self, dtype, shape):
|
||||
np.random.seed(42)
|
||||
np_array = np.random.randint(0, 10, shape)
|
||||
tf_tensor = constant_op.constant(np_array, dtype=dtype)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
del tf_tensor # should still work
|
||||
tf_tensor2 = from_dlpack(dlcapsule)
|
||||
self.assertAllClose(np_array, tf_tensor2)
|
||||
|
||||
def testTensorsCanBeConsumedOnceOnly(self):
|
||||
np.random.seed(42)
|
||||
np_array = np.random.randint(0, 10, (2, 3, 4))
|
||||
tf_tensor = constant_op.constant(np_array, dtype=np.float32)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
del tf_tensor # should still work
|
||||
tf_tensor2 = from_dlpack(dlcapsule)
|
||||
def testTensorsCanBeConsumedOnceOnly(self):
|
||||
np.random.seed(42)
|
||||
np_array = np.random.randint(0, 10, (2, 3, 4))
|
||||
tf_tensor = constant_op.constant(np_array, dtype=np.float32)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
del tf_tensor # should still work
|
||||
tf_tensor2 = from_dlpack(dlcapsule)
|
||||
|
||||
def ConsumeDLPackTensor():
|
||||
from_dlpack(dlcapsule) # Should can be consumed only once
|
||||
self.assertRaisesRegex(Exception,
|
||||
".*a DLPack tensor may be consumed at most once.*",
|
||||
ConsumeDLPackTensor)
|
||||
def ConsumeDLPackTensor():
|
||||
from_dlpack(dlcapsule) # Should can be consumed only once
|
||||
self.assertRaisesRegex(Exception,
|
||||
".*a DLPack tensor may be consumed at most once.*",
|
||||
ConsumeDLPackTensor)
|
||||
|
||||
def testUnsupportedType(self):
|
||||
def case1():
|
||||
tf_tensor = constant_op.constant(
|
||||
[[1, 4], [5, 2]], dtype=dtypes.qint16)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
def testUnsupportedType(self):
|
||||
def case1():
|
||||
tf_tensor = constant_op.constant(
|
||||
[[1, 4], [5, 2]], dtype=dtypes.qint16)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
|
||||
def case2():
|
||||
tf_tensor = constant_op.constant(
|
||||
[[1, 4], [5, 2]], dtype=dtypes.complex64)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
def case2():
|
||||
tf_tensor = constant_op.constant(
|
||||
[[1, 4], [5, 2]], dtype=dtypes.complex64)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
|
||||
self.assertRaisesRegex(
|
||||
Exception, ".* is not supported by dlpack", case1)
|
||||
self.assertRaisesRegex(
|
||||
Exception, ".* is not supported by dlpack", case2)
|
||||
self.assertRaisesRegex(
|
||||
Exception, ".* is not supported by dlpack", case1)
|
||||
self.assertRaisesRegex(
|
||||
Exception, ".* is not supported by dlpack", case2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
@ -35,8 +35,8 @@ cc_library(
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tape",
|
||||
"//tensorflow/c/eager:dlpack",
|
||||
"//tensorflow/c/eager:tape",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -94,6 +94,7 @@ py_library(
|
||||
":test",
|
||||
":wrap_function",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python/dlpack",
|
||||
"//tensorflow/python/eager/memory_tests:memory_test_util",
|
||||
],
|
||||
)
|
||||
|
@ -159,6 +159,7 @@ filegroup(
|
||||
"@com_google_protobuf//:LICENSE",
|
||||
"@com_googlesource_code_re2//:LICENSE",
|
||||
"@curl//:COPYING",
|
||||
"@dlpack//:LICENSE",
|
||||
"@double_conversion//:LICENSE",
|
||||
"@eigen_archive//:COPYING.MPL2",
|
||||
"@enum34_archive//:LICENSE",
|
||||
|
Loading…
Reference in New Issue
Block a user