fix python sympol export

This commit is contained in:
VoVAllen 2020-02-26 14:48:30 +00:00
parent b90808b7b4
commit 29856e6354
6 changed files with 13 additions and 5 deletions

View File

@ -305,6 +305,9 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
return nullptr; return nullptr;
} }
TF_DataType dtype = TfDataTypeFormDlDataType(dl_tensor->dtype, status); TF_DataType dtype = TfDataTypeFormDlDataType(dl_tensor->dtype, status);
if (!status->status.ok()) {
return nullptr;
}
int num_dims = dl_tensor->ndim; int num_dims = dl_tensor->ndim;
const int64_t* dims = dl_tensor->shape; const int64_t* dims = dl_tensor->shape;
void* data = dl_tensor->data; void* data = dl_tensor->data;

View File

@ -190,6 +190,7 @@ py_library(
"//tensorflow/python/distribute:estimator_training", "//tensorflow/python/distribute:estimator_training",
"//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/dlpack",
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:monitoring", "//tensorflow/python/eager:monitoring",
"//tensorflow/python/eager:profiler", "//tensorflow/python/eager:profiler",

View File

@ -170,6 +170,9 @@ from tensorflow.python.debug.lib import check_numerics_callback
from tensorflow.python.debug.lib import dumping_callback from tensorflow.python.debug.lib import dumping_callback
from tensorflow.python.ops import gen_debug_ops from tensorflow.python.ops import gen_debug_ops
# DLPack
from tensorflow.python.dlpack.dlpack import from_dlpack, to_dlpack
# XLA JIT compiler APIs. # XLA JIT compiler APIs.
from tensorflow.python.compiler.xla import jit from tensorflow.python.compiler.xla import jit
from tensorflow.python.compiler.xla import xla from tensorflow.python.compiler.xla import xla

View File

@ -1,4 +1,4 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -23,11 +23,11 @@ from tensorflow.python.util.tf_export import tf_export
# tf.dlpack.to_dlpack/from_dlpack doesn't work. How to fix? # tf.dlpack.to_dlpack/from_dlpack doesn't work. How to fix?
@tf_export("dlpack.to_dlpack") @tf_export("experimental.dlpack.to_dlpack", v1=[])
def to_dlpack(tf_tensor): def to_dlpack(tf_tensor):
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
@tf_export("dlpack.from_dlpack") @tf_export("experimental.dlpack.from_dlpack", v1=[])
def from_dlpack(dlcapsule): def from_dlpack(dlcapsule):
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)

View File

@ -1,4 +1,4 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -81,7 +81,7 @@ class DLPackTest(parameterized.TestCase, test.TestCase):
".*a DLPack tensor may be consumed at most once.*", ".*a DLPack tensor may be consumed at most once.*",
ConsumeDLPackTensor) ConsumeDLPackTensor)
def testUnsupportedType(self): def testUnsupportedTypeToDLPack(self):
def case1(): def case1():
tf_tensor = constant_op.constant( tf_tensor = constant_op.constant(
[[1, 4], [5, 2]], dtype=dtypes.qint16) [[1, 4], [5, 2]], dtype=dtypes.qint16)

View File

@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES = [
"errors/__init__.py", "errors/__init__.py",
"experimental/__init__.py", "experimental/__init__.py",
"experimental/tensorrt/__init__.py", "experimental/tensorrt/__init__.py",
"experimental/dlpack/__init__.py",
"feature_column/__init__.py", "feature_column/__init__.py",
"io/gfile/__init__.py", "io/gfile/__init__.py",
"graph_util/__init__.py", "graph_util/__init__.py",