Export TFE_TensorHandleToNumpy in pywrap_tensor.h so that in case be used in python binding for unified API.

PiperOrigin-RevId: 330640762
Change-Id: I079573af8e08b907bc062480f51304f1d476b58f
This commit is contained in:
Saurabh Saxena 2020-09-08 20:04:17 -07:00 committed by TensorFlower Gardener
parent ad6e452065
commit e19872cc65
6 changed files with 149 additions and 29 deletions

View File

@ -28,6 +28,14 @@ void Set_TF_Status_from_Status(TF_Status* tf_status,
// Returns a "status" from "tf_status".
tensorflow::Status StatusFromTF_Status(const TF_Status* tf_status);
namespace internal {
struct TF_StatusDeleter {
void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
};
} // namespace internal
using TF_StatusPtr = std::unique_ptr<TF_Status, internal::TF_StatusDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_TF_STATUS_HELPER_H_

View File

@ -80,6 +80,36 @@ cc_library(
],
)
tf_python_pybind_extension(
name = "pywrap_tensor_test_util",
testonly = True,
srcs = ["pywrap_tensor_test_util.cc"],
module_name = "pywrap_tensor_test_util",
deps = [
":pywrap_tfe_lib",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/python:pybind11_lib",
"@pybind11",
],
)
cuda_py_test(
name = "pywrap_tensor_test",
size = "small",
srcs = ["pywrap_tensor_test.py"],
python_version = "PY3",
tags = [
"no_oss", # TODO(b/168051787): Enable.
"no_pip", # TODO(b/168051787): Enable.
],
deps = [
":pywrap_tensor_test_util",
":test",
"//third_party/py/numpy",
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [

View File

@ -40,9 +40,42 @@ limitations under the License.
// forward declare
struct EagerTensor;
namespace tensorflow {
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
// The two may share underlying storage so changes to one may reflect in the
// other.
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"Cannot convert a Tensor of dtype resource to a NumPy array.");
return nullptr;
}
tensorflow::Safe_TF_TensorPtr tensor = nullptr;
Py_BEGIN_ALLOW_THREADS;
tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
Py_END_ALLOW_THREADS;
if (!status->status.ok()) {
return nullptr;
}
PyObject* ret = nullptr;
auto cppstatus =
tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
tensorflow::Set_TF_Status_from_Status(status, cppstatus);
if (!status->status.ok()) {
Py_XDECREF(ret);
return nullptr;
}
CHECK_NE(ret, nullptr);
return ret;
}
} // namespace tensorflow
namespace {
using tensorflow::TFE_TensorHandleToNumpy;
// An instance of _EagerTensorProfiler that will receive callbacks about
// events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
PyObject* eager_tensor_profiler = nullptr;
@ -87,35 +120,6 @@ TFE_Context* GetContextHandle(PyObject* py_context) {
return ctx;
}
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
// The two may share underlying storage so changes to one may reflect in the
// other.
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"Cannot convert a Tensor of dtype resource to a NumPy array.");
return nullptr;
}
tensorflow::Safe_TF_TensorPtr tensor = nullptr;
Py_BEGIN_ALLOW_THREADS;
tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
Py_END_ALLOW_THREADS;
if (!status->status.ok()) {
return nullptr;
}
PyObject* ret = nullptr;
auto cppstatus =
tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
tensorflow::Set_TF_Status_from_Status(status, cppstatus);
if (!status->status.ok()) {
Py_XDECREF(ret);
return nullptr;
}
CHECK_NE(ret, nullptr);
return ret;
}
// Helper function to convert `v` to a tensorflow::DataType and store it in
// `*out`. Returns true on success, false otherwise.

View File

@ -37,6 +37,8 @@ TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
DataType dtype,
const char* device_name = nullptr);
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_

View File

@ -0,0 +1,35 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TFE_TensorHandleToNumpy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.eager import pywrap_tensor_test_util as util
from tensorflow.python.eager import test
class PywrapTensorTest(test.TestCase):
def testGetScalarOne(self):
result = util.get_scalar_one()
self.assertIsInstance(result, np.ndarray)
self.assertAllEqual(result, 1.0)
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,41 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
using tensorflow::Pyo;
using tensorflow::TF_StatusPtr;
using tensorflow::TFE_TensorHandleToNumpy;
PYBIND11_MODULE(pywrap_tensor_test_util, m) {
m.def("get_scalar_one", []() {
// Builds a TFE_TensorHandle and then converts to NumPy ndarray
// using TFE_TensorHandleToNumpy.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_StatusPtr status(TF_NewStatus());
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_TensorHandle* handle = TestScalarTensorHandle(ctx, 1.0f);
auto result = Pyo(TFE_TensorHandleToNumpy(handle, status.get()));
TFE_DeleteTensorHandle(handle);
TFE_DeleteContext(ctx);
TFE_DeleteContextOptions(opts);
return result;
});
}