diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 6276371bd68..bb898ac1fed 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -305,6 +305,9 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { return nullptr; } TF_DataType dtype = TfDataTypeFormDlDataType(dl_tensor->dtype, status); + if (!status->status.ok()) { + return nullptr; + } int num_dims = dl_tensor->ndim; const int64_t* dims = dl_tensor->shape; void* data = dl_tensor->data; diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 63593f1a428..70ae3aa96f1 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -190,6 +190,7 @@ py_library( "//tensorflow/python/distribute:estimator_training", "//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/dlpack", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:monitoring", "//tensorflow/python/eager:profiler", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 6d88cb566ae..c5a4207b476 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -170,6 +170,9 @@ from tensorflow.python.debug.lib import check_numerics_callback from tensorflow.python.debug.lib import dumping_callback from tensorflow.python.ops import gen_debug_ops +# DLPack +from tensorflow.python.dlpack.dlpack import from_dlpack, to_dlpack + # XLA JIT compiler APIs. from tensorflow.python.compiler.xla import jit from tensorflow.python.compiler.xla import xla diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py index 7a04fca3933..5b278db36ba 100644 --- a/tensorflow/python/dlpack/dlpack.py +++ b/tensorflow/python/dlpack/dlpack.py @@ -1,4 +1,4 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,11 +23,11 @@ from tensorflow.python.util.tf_export import tf_export # tf.dlpack.to_dlpack/from_dlpack doesn't work. How to fix? -@tf_export("dlpack.to_dlpack") +@tf_export("experimental.dlpack.to_dlpack", v1=[]) def to_dlpack(tf_tensor): return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) -@tf_export("dlpack.from_dlpack") +@tf_export("experimental.dlpack.from_dlpack", v1=[]) def from_dlpack(dlcapsule): return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) diff --git a/tensorflow/python/dlpack/dlpack_test.py b/tensorflow/python/dlpack/dlpack_test.py index 8b47c71dc6b..206c2b7d926 100644 --- a/tensorflow/python/dlpack/dlpack_test.py +++ b/tensorflow/python/dlpack/dlpack_test.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -81,7 +81,7 @@ class DLPackTest(parameterized.TestCase, test.TestCase): ".*a DLPack tensor may be consumed at most once.*", ConsumeDLPackTensor) - def testUnsupportedType(self): + def testUnsupportedTypeToDLPack(self): def case1(): tf_tensor = constant_op.constant( [[1, 4], [5, 2]], dtype=dtypes.qint16) diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index 3aab59e50aa..99981a5ce2e 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES = [ "errors/__init__.py", "experimental/__init__.py", "experimental/tensorrt/__init__.py", + "experimental/dlpack/__init__.py", "feature_column/__init__.py", "io/gfile/__init__.py", "graph_util/__init__.py",