fix sanity check
This commit is contained in:
parent
cae654ea99
commit
6ef713fec6
@ -92,8 +92,8 @@ filegroup(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"c_api_experimental.h",
|
"c_api_experimental.h",
|
||||||
"c_api_internal.h",
|
"c_api_internal.h",
|
||||||
"tensor_handle_interface.h",
|
|
||||||
"dlpack.h",
|
"dlpack.h",
|
||||||
|
"tensor_handle_interface.h",
|
||||||
],
|
],
|
||||||
visibility = [
|
visibility = [
|
||||||
"//tensorflow/core:__pkg__",
|
"//tensorflow/core:__pkg__",
|
||||||
@ -327,7 +327,6 @@ filegroup(
|
|||||||
visibility = ["//tensorflow:__subpackages__"],
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "dlpack",
|
name = "dlpack",
|
||||||
srcs = ["dlpack.cc"],
|
srcs = ["dlpack.cc"],
|
||||||
@ -346,7 +345,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
|
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
|
||||||
# right now, remove this public rule when no longer needed (it should be
|
# right now, remove this public rule when no longer needed (it should be
|
||||||
# replaced by TF Lite)
|
# replaced by TF Lite)
|
||||||
|
@ -4,9 +4,11 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
|||||||
py_library(
|
py_library(
|
||||||
name = "dlpack",
|
name = "dlpack",
|
||||||
srcs = ["dlpack.py"],
|
srcs = ["dlpack.py"],
|
||||||
deps = [
|
|
||||||
],
|
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:pywrap_tensorflow",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
@ -18,5 +20,5 @@ cuda_py_test(
|
|||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
"@absl_py//absl/testing:absltest",
|
"@absl_py//absl/testing:absltest",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
@ -12,14 +12,21 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
"""DLPack modules for Tensorflow"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tfe
|
from tensorflow.python import pywrap_tfe
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@tf_export("dlpack.to_dlpack")
|
@tf_export("dlpack.to_dlpack")
|
||||||
def to_dlpack(tf_tensor):
|
def to_dlpack(tf_tensor):
|
||||||
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
|
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("dlpack.from_dlpack")
|
@tf_export("dlpack.from_dlpack")
|
||||||
def from_dlpack(dlcapsule):
|
def from_dlpack(dlcapsule):
|
||||||
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
|
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
|
||||||
|
@ -1,3 +1,23 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tests for DLPack functions."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -29,55 +49,55 @@ testcase_shapes = [
|
|||||||
|
|
||||||
|
|
||||||
def FormatShapeAndDtype(shape, dtype):
|
def FormatShapeAndDtype(shape, dtype):
|
||||||
return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))
|
return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))
|
||||||
|
|
||||||
|
|
||||||
class DLPackTest(parameterized.TestCase, test.TestCase):
|
class DLPackTest(parameterized.TestCase, test.TestCase):
|
||||||
|
|
||||||
@parameterized.named_parameters({
|
@parameterized.named_parameters({
|
||||||
"testcase_name": FormatShapeAndDtype(shape, dtype),
|
"testcase_name": FormatShapeAndDtype(shape, dtype),
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
"shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes)
|
"shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes)
|
||||||
def testRoundTrip(self, dtype, shape):
|
def testRoundTrip(self, dtype, shape):
|
||||||
np.random.seed(42)
|
np.random.seed(42)
|
||||||
np_array = np.random.randint(0, 10, shape)
|
np_array = np.random.randint(0, 10, shape)
|
||||||
tf_tensor = constant_op.constant(np_array, dtype=dtype)
|
tf_tensor = constant_op.constant(np_array, dtype=dtype)
|
||||||
dlcapsule = to_dlpack(tf_tensor)
|
dlcapsule = to_dlpack(tf_tensor)
|
||||||
del tf_tensor # should still work
|
del tf_tensor # should still work
|
||||||
tf_tensor2 = from_dlpack(dlcapsule)
|
tf_tensor2 = from_dlpack(dlcapsule)
|
||||||
self.assertAllClose(np_array, tf_tensor2)
|
self.assertAllClose(np_array, tf_tensor2)
|
||||||
|
|
||||||
def testTensorsCanBeConsumedOnceOnly(self):
|
def testTensorsCanBeConsumedOnceOnly(self):
|
||||||
np.random.seed(42)
|
np.random.seed(42)
|
||||||
np_array = np.random.randint(0, 10, (2, 3, 4))
|
np_array = np.random.randint(0, 10, (2, 3, 4))
|
||||||
tf_tensor = constant_op.constant(np_array, dtype=np.float32)
|
tf_tensor = constant_op.constant(np_array, dtype=np.float32)
|
||||||
dlcapsule = to_dlpack(tf_tensor)
|
dlcapsule = to_dlpack(tf_tensor)
|
||||||
del tf_tensor # should still work
|
del tf_tensor # should still work
|
||||||
tf_tensor2 = from_dlpack(dlcapsule)
|
tf_tensor2 = from_dlpack(dlcapsule)
|
||||||
|
|
||||||
def ConsumeDLPackTensor():
|
def ConsumeDLPackTensor():
|
||||||
from_dlpack(dlcapsule) # Should can be consumed only once
|
from_dlpack(dlcapsule) # Should can be consumed only once
|
||||||
self.assertRaisesRegex(Exception,
|
self.assertRaisesRegex(Exception,
|
||||||
".*a DLPack tensor may be consumed at most once.*",
|
".*a DLPack tensor may be consumed at most once.*",
|
||||||
ConsumeDLPackTensor)
|
ConsumeDLPackTensor)
|
||||||
|
|
||||||
def testUnsupportedType(self):
|
def testUnsupportedType(self):
|
||||||
def case1():
|
def case1():
|
||||||
tf_tensor = constant_op.constant(
|
tf_tensor = constant_op.constant(
|
||||||
[[1, 4], [5, 2]], dtype=dtypes.qint16)
|
[[1, 4], [5, 2]], dtype=dtypes.qint16)
|
||||||
dlcapsule = to_dlpack(tf_tensor)
|
dlcapsule = to_dlpack(tf_tensor)
|
||||||
|
|
||||||
def case2():
|
def case2():
|
||||||
tf_tensor = constant_op.constant(
|
tf_tensor = constant_op.constant(
|
||||||
[[1, 4], [5, 2]], dtype=dtypes.complex64)
|
[[1, 4], [5, 2]], dtype=dtypes.complex64)
|
||||||
dlcapsule = to_dlpack(tf_tensor)
|
dlcapsule = to_dlpack(tf_tensor)
|
||||||
|
|
||||||
self.assertRaisesRegex(
|
self.assertRaisesRegex(
|
||||||
Exception, ".* is not supported by dlpack", case1)
|
Exception, ".* is not supported by dlpack", case1)
|
||||||
self.assertRaisesRegex(
|
self.assertRaisesRegex(
|
||||||
Exception, ".* is not supported by dlpack", case2)
|
Exception, ".* is not supported by dlpack", case2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -35,8 +35,8 @@ cc_library(
|
|||||||
"//tensorflow/c/eager:c_api",
|
"//tensorflow/c/eager:c_api",
|
||||||
"//tensorflow/c/eager:c_api_experimental",
|
"//tensorflow/c/eager:c_api_experimental",
|
||||||
"//tensorflow/c/eager:c_api_internal",
|
"//tensorflow/c/eager:c_api_internal",
|
||||||
"//tensorflow/c/eager:tape",
|
|
||||||
"//tensorflow/c/eager:dlpack",
|
"//tensorflow/c/eager:dlpack",
|
||||||
|
"//tensorflow/c/eager:tape",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
@ -94,6 +94,7 @@ py_library(
|
|||||||
":test",
|
":test",
|
||||||
":wrap_function",
|
":wrap_function",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tensorflow",
|
||||||
|
"//tensorflow/python/dlpack",
|
||||||
"//tensorflow/python/eager/memory_tests:memory_test_util",
|
"//tensorflow/python/eager/memory_tests:memory_test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -159,6 +159,7 @@ filegroup(
|
|||||||
"@com_google_protobuf//:LICENSE",
|
"@com_google_protobuf//:LICENSE",
|
||||||
"@com_googlesource_code_re2//:LICENSE",
|
"@com_googlesource_code_re2//:LICENSE",
|
||||||
"@curl//:COPYING",
|
"@curl//:COPYING",
|
||||||
|
"@dlpack//:LICENSE",
|
||||||
"@double_conversion//:LICENSE",
|
"@double_conversion//:LICENSE",
|
||||||
"@eigen_archive//:COPYING.MPL2",
|
"@eigen_archive//:COPYING.MPL2",
|
||||||
"@enum34_archive//:LICENSE",
|
"@enum34_archive//:LICENSE",
|
||||||
|
Loading…
Reference in New Issue
Block a user