Export TFE_TensorHandleToNumpy in pywrap_tensor.h so that in case be used in python binding for unified API.
PiperOrigin-RevId: 330640762 Change-Id: I079573af8e08b907bc062480f51304f1d476b58f
This commit is contained in:
parent
ad6e452065
commit
e19872cc65
@ -28,6 +28,14 @@ void Set_TF_Status_from_Status(TF_Status* tf_status,
|
|||||||
// Returns a "status" from "tf_status".
|
// Returns a "status" from "tf_status".
|
||||||
tensorflow::Status StatusFromTF_Status(const TF_Status* tf_status);
|
tensorflow::Status StatusFromTF_Status(const TF_Status* tf_status);
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
struct TF_StatusDeleter {
|
||||||
|
void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
|
||||||
|
};
|
||||||
|
} // namespace internal
|
||||||
|
|
||||||
|
using TF_StatusPtr = std::unique_ptr<TF_Status, internal::TF_StatusDeleter>;
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_TF_STATUS_HELPER_H_
|
#endif // TENSORFLOW_C_TF_STATUS_HELPER_H_
|
||||||
|
@ -80,6 +80,36 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "pywrap_tensor_test_util",
|
||||||
|
testonly = True,
|
||||||
|
srcs = ["pywrap_tensor_test_util.cc"],
|
||||||
|
module_name = "pywrap_tensor_test_util",
|
||||||
|
deps = [
|
||||||
|
":pywrap_tfe_lib",
|
||||||
|
"//tensorflow/c:tf_status_helper",
|
||||||
|
"//tensorflow/c/eager:c_api_test_util",
|
||||||
|
"//tensorflow/python:pybind11_lib",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "pywrap_tensor_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["pywrap_tensor_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
tags = [
|
||||||
|
"no_oss", # TODO(b/168051787): Enable.
|
||||||
|
"no_pip", # TODO(b/168051787): Enable.
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":pywrap_tensor_test_util",
|
||||||
|
":test",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "pywrap_required_hdrs",
|
name = "pywrap_required_hdrs",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -40,9 +40,42 @@ limitations under the License.
|
|||||||
|
|
||||||
// forward declare
|
// forward declare
|
||||||
struct EagerTensor;
|
struct EagerTensor;
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
|
||||||
|
// The two may share underlying storage so changes to one may reflect in the
|
||||||
|
// other.
|
||||||
|
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
|
||||||
|
if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
|
||||||
|
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||||
|
"Cannot convert a Tensor of dtype resource to a NumPy array.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::Safe_TF_TensorPtr tensor = nullptr;
|
||||||
|
Py_BEGIN_ALLOW_THREADS;
|
||||||
|
tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
|
||||||
|
Py_END_ALLOW_THREADS;
|
||||||
|
if (!status->status.ok()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject* ret = nullptr;
|
||||||
|
auto cppstatus =
|
||||||
|
tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
|
||||||
|
tensorflow::Set_TF_Status_from_Status(status, cppstatus);
|
||||||
|
if (!status->status.ok()) {
|
||||||
|
Py_XDECREF(ret);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
CHECK_NE(ret, nullptr);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
} // namespace tensorflow
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using tensorflow::TFE_TensorHandleToNumpy;
|
||||||
|
|
||||||
// An instance of _EagerTensorProfiler that will receive callbacks about
|
// An instance of _EagerTensorProfiler that will receive callbacks about
|
||||||
// events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
|
// events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
|
||||||
PyObject* eager_tensor_profiler = nullptr;
|
PyObject* eager_tensor_profiler = nullptr;
|
||||||
@ -87,35 +120,6 @@ TFE_Context* GetContextHandle(PyObject* py_context) {
|
|||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
|
|
||||||
// The two may share underlying storage so changes to one may reflect in the
|
|
||||||
// other.
|
|
||||||
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
|
|
||||||
if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
|
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
|
||||||
"Cannot convert a Tensor of dtype resource to a NumPy array.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::Safe_TF_TensorPtr tensor = nullptr;
|
|
||||||
Py_BEGIN_ALLOW_THREADS;
|
|
||||||
tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
|
|
||||||
Py_END_ALLOW_THREADS;
|
|
||||||
if (!status->status.ok()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
PyObject* ret = nullptr;
|
|
||||||
auto cppstatus =
|
|
||||||
tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
|
|
||||||
tensorflow::Set_TF_Status_from_Status(status, cppstatus);
|
|
||||||
if (!status->status.ok()) {
|
|
||||||
Py_XDECREF(ret);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
CHECK_NE(ret, nullptr);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to convert `v` to a tensorflow::DataType and store it in
|
// Helper function to convert `v` to a tensorflow::DataType and store it in
|
||||||
// `*out`. Returns true on success, false otherwise.
|
// `*out`. Returns true on success, false otherwise.
|
||||||
|
@ -37,6 +37,8 @@ TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
|
|||||||
DataType dtype,
|
DataType dtype,
|
||||||
const char* device_name = nullptr);
|
const char* device_name = nullptr);
|
||||||
|
|
||||||
|
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
|
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
|
||||||
|
35
tensorflow/python/eager/pywrap_tensor_test.py
Normal file
35
tensorflow/python/eager/pywrap_tensor_test.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for TFE_TensorHandleToNumpy."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tensorflow.python.eager import pywrap_tensor_test_util as util
|
||||||
|
from tensorflow.python.eager import test
|
||||||
|
|
||||||
|
|
||||||
|
class PywrapTensorTest(test.TestCase):
|
||||||
|
|
||||||
|
def testGetScalarOne(self):
|
||||||
|
result = util.get_scalar_one()
|
||||||
|
self.assertIsInstance(result, np.ndarray)
|
||||||
|
self.assertAllEqual(result, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
41
tensorflow/python/eager/pywrap_tensor_test_util.cc
Normal file
41
tensorflow/python/eager/pywrap_tensor_test_util.cc
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "pybind11/pytypes.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/python/eager/pywrap_tensor.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||||
|
|
||||||
|
using tensorflow::Pyo;
|
||||||
|
using tensorflow::TF_StatusPtr;
|
||||||
|
using tensorflow::TFE_TensorHandleToNumpy;
|
||||||
|
|
||||||
|
PYBIND11_MODULE(pywrap_tensor_test_util, m) {
|
||||||
|
m.def("get_scalar_one", []() {
|
||||||
|
// Builds a TFE_TensorHandle and then converts to NumPy ndarray
|
||||||
|
// using TFE_TensorHandleToNumpy.
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TF_StatusPtr status(TF_NewStatus());
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||||
|
TFE_TensorHandle* handle = TestScalarTensorHandle(ctx, 1.0f);
|
||||||
|
auto result = Pyo(TFE_TensorHandleToNumpy(handle, status.get()));
|
||||||
|
TFE_DeleteTensorHandle(handle);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
return result;
|
||||||
|
});
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user