Export TFE_TensorHandleToNumpy in pywrap_tensor.h so that in case be used in python binding for unified API.
PiperOrigin-RevId: 330640762 Change-Id: I079573af8e08b907bc062480f51304f1d476b58f
This commit is contained in:
parent
ad6e452065
commit
e19872cc65
@ -28,6 +28,14 @@ void Set_TF_Status_from_Status(TF_Status* tf_status,
|
||||
// Returns a "status" from "tf_status".
|
||||
tensorflow::Status StatusFromTF_Status(const TF_Status* tf_status);
|
||||
|
||||
namespace internal {
|
||||
struct TF_StatusDeleter {
|
||||
void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using TF_StatusPtr = std::unique_ptr<TF_Status, internal::TF_StatusDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_TF_STATUS_HELPER_H_
|
||||
|
@ -80,6 +80,36 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "pywrap_tensor_test_util",
|
||||
testonly = True,
|
||||
srcs = ["pywrap_tensor_test_util.cc"],
|
||||
module_name = "pywrap_tensor_test_util",
|
||||
deps = [
|
||||
":pywrap_tfe_lib",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/eager:c_api_test_util",
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "pywrap_tensor_test",
|
||||
size = "small",
|
||||
srcs = ["pywrap_tensor_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_oss", # TODO(b/168051787): Enable.
|
||||
"no_pip", # TODO(b/168051787): Enable.
|
||||
],
|
||||
deps = [
|
||||
":pywrap_tensor_test_util",
|
||||
":test",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
|
@ -40,9 +40,42 @@ limitations under the License.
|
||||
|
||||
// forward declare
|
||||
struct EagerTensor;
|
||||
namespace tensorflow {
|
||||
|
||||
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
|
||||
// The two may share underlying storage so changes to one may reflect in the
|
||||
// other.
|
||||
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
|
||||
if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"Cannot convert a Tensor of dtype resource to a NumPy array.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::Safe_TF_TensorPtr tensor = nullptr;
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
|
||||
Py_END_ALLOW_THREADS;
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyObject* ret = nullptr;
|
||||
auto cppstatus =
|
||||
tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
|
||||
tensorflow::Set_TF_Status_from_Status(status, cppstatus);
|
||||
if (!status->status.ok()) {
|
||||
Py_XDECREF(ret);
|
||||
return nullptr;
|
||||
}
|
||||
CHECK_NE(ret, nullptr);
|
||||
return ret;
|
||||
}
|
||||
} // namespace tensorflow
|
||||
namespace {
|
||||
|
||||
using tensorflow::TFE_TensorHandleToNumpy;
|
||||
|
||||
// An instance of _EagerTensorProfiler that will receive callbacks about
|
||||
// events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
|
||||
PyObject* eager_tensor_profiler = nullptr;
|
||||
@ -87,35 +120,6 @@ TFE_Context* GetContextHandle(PyObject* py_context) {
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
|
||||
// The two may share underlying storage so changes to one may reflect in the
|
||||
// other.
|
||||
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
|
||||
if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"Cannot convert a Tensor of dtype resource to a NumPy array.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::Safe_TF_TensorPtr tensor = nullptr;
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
|
||||
Py_END_ALLOW_THREADS;
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyObject* ret = nullptr;
|
||||
auto cppstatus =
|
||||
tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
|
||||
tensorflow::Set_TF_Status_from_Status(status, cppstatus);
|
||||
if (!status->status.ok()) {
|
||||
Py_XDECREF(ret);
|
||||
return nullptr;
|
||||
}
|
||||
CHECK_NE(ret, nullptr);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Helper function to convert `v` to a tensorflow::DataType and store it in
|
||||
// `*out`. Returns true on success, false otherwise.
|
||||
|
@ -37,6 +37,8 @@ TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
|
||||
DataType dtype,
|
||||
const char* device_name = nullptr);
|
||||
|
||||
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
|
||||
|
35
tensorflow/python/eager/pywrap_tensor_test.py
Normal file
35
tensorflow/python/eager/pywrap_tensor_test.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for TFE_TensorHandleToNumpy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.python.eager import pywrap_tensor_test_util as util
|
||||
from tensorflow.python.eager import test
|
||||
|
||||
|
||||
class PywrapTensorTest(test.TestCase):
|
||||
|
||||
def testGetScalarOne(self):
|
||||
result = util.get_scalar_one()
|
||||
self.assertIsInstance(result, np.ndarray)
|
||||
self.assertAllEqual(result, 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
41
tensorflow/python/eager/pywrap_tensor_test_util.cc
Normal file
41
tensorflow/python/eager/pywrap_tensor_test_util.cc
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/python/eager/pywrap_tensor.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
|
||||
using tensorflow::Pyo;
|
||||
using tensorflow::TF_StatusPtr;
|
||||
using tensorflow::TFE_TensorHandleToNumpy;
|
||||
|
||||
PYBIND11_MODULE(pywrap_tensor_test_util, m) {
|
||||
m.def("get_scalar_one", []() {
|
||||
// Builds a TFE_TensorHandle and then converts to NumPy ndarray
|
||||
// using TFE_TensorHandleToNumpy.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_StatusPtr status(TF_NewStatus());
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
TFE_TensorHandle* handle = TestScalarTensorHandle(ctx, 1.0f);
|
||||
auto result = Pyo(TFE_TensorHandleToNumpy(handle, status.get()));
|
||||
TFE_DeleteTensorHandle(handle);
|
||||
TFE_DeleteContext(ctx);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return result;
|
||||
});
|
||||
}
|
Loading…
Reference in New Issue
Block a user