diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h index ff8085f1229..a895e608159 100644 --- a/tensorflow/c/tf_status_helper.h +++ b/tensorflow/c/tf_status_helper.h @@ -28,6 +28,14 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, // Returns a "status" from "tf_status". tensorflow::Status StatusFromTF_Status(const TF_Status* tf_status); +namespace internal { +struct TF_StatusDeleter { + void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); } +}; +} // namespace internal + +using TF_StatusPtr = std::unique_ptr; + } // namespace tensorflow #endif // TENSORFLOW_C_TF_STATUS_HELPER_H_ diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index c9eaef82038..738e6faf68f 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -80,6 +80,36 @@ cc_library( ], ) +tf_python_pybind_extension( + name = "pywrap_tensor_test_util", + testonly = True, + srcs = ["pywrap_tensor_test_util.cc"], + module_name = "pywrap_tensor_test_util", + deps = [ + ":pywrap_tfe_lib", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/eager:c_api_test_util", + "//tensorflow/python:pybind11_lib", + "@pybind11", + ], +) + +cuda_py_test( + name = "pywrap_tensor_test", + size = "small", + srcs = ["pywrap_tensor_test.py"], + python_version = "PY3", + tags = [ + "no_oss", # TODO(b/168051787): Enable. + "no_pip", # TODO(b/168051787): Enable. + ], + deps = [ + ":pywrap_tensor_test_util", + ":test", + "//third_party/py/numpy", + ], +) + filegroup( name = "pywrap_required_hdrs", srcs = [ diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 0789eab6270..e5c74deaf80 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -40,9 +40,42 @@ limitations under the License. // forward declare struct EagerTensor; +namespace tensorflow { +// Convert a TFE_TensorHandle to a Python numpy.ndarray object. +// The two may share underlying storage so changes to one may reflect in the +// other. +PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) { + if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "Cannot convert a Tensor of dtype resource to a NumPy array."); + return nullptr; + } + + tensorflow::Safe_TF_TensorPtr tensor = nullptr; + Py_BEGIN_ALLOW_THREADS; + tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status)); + Py_END_ALLOW_THREADS; + if (!status->status.ok()) { + return nullptr; + } + + PyObject* ret = nullptr; + auto cppstatus = + tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret); + tensorflow::Set_TF_Status_from_Status(status, cppstatus); + if (!status->status.ok()) { + Py_XDECREF(ret); + return nullptr; + } + CHECK_NE(ret, nullptr); + return ret; +} +} // namespace tensorflow namespace { +using tensorflow::TFE_TensorHandleToNumpy; + // An instance of _EagerTensorProfiler that will receive callbacks about // events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all. PyObject* eager_tensor_profiler = nullptr; @@ -87,35 +120,6 @@ TFE_Context* GetContextHandle(PyObject* py_context) { return ctx; } -// Convert a TFE_TensorHandle to a Python numpy.ndarray object. -// The two may share underlying storage so changes to one may reflect in the -// other. -PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) { - if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - "Cannot convert a Tensor of dtype resource to a NumPy array."); - return nullptr; - } - - tensorflow::Safe_TF_TensorPtr tensor = nullptr; - Py_BEGIN_ALLOW_THREADS; - tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status)); - Py_END_ALLOW_THREADS; - if (!status->status.ok()) { - return nullptr; - } - - PyObject* ret = nullptr; - auto cppstatus = - tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret); - tensorflow::Set_TF_Status_from_Status(status, cppstatus); - if (!status->status.ok()) { - Py_XDECREF(ret); - return nullptr; - } - CHECK_NE(ret, nullptr); - return ret; -} // Helper function to convert `v` to a tensorflow::DataType and store it in // `*out`. Returns true on success, false otherwise. diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h index 4c84b5ce6ea..bc9548ac4ad 100644 --- a/tensorflow/python/eager/pywrap_tensor.h +++ b/tensorflow/python/eager/pywrap_tensor.h @@ -37,6 +37,8 @@ TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value, DataType dtype, const char* device_name = nullptr); +PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status); + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_ diff --git a/tensorflow/python/eager/pywrap_tensor_test.py b/tensorflow/python/eager/pywrap_tensor_test.py new file mode 100644 index 00000000000..ee1a3536546 --- /dev/null +++ b/tensorflow/python/eager/pywrap_tensor_test.py @@ -0,0 +1,35 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TFE_TensorHandleToNumpy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.eager import pywrap_tensor_test_util as util +from tensorflow.python.eager import test + + +class PywrapTensorTest(test.TestCase): + + def testGetScalarOne(self): + result = util.get_scalar_one() + self.assertIsInstance(result, np.ndarray) + self.assertAllEqual(result, 1.0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/eager/pywrap_tensor_test_util.cc b/tensorflow/python/eager/pywrap_tensor_test_util.cc new file mode 100644 index 00000000000..21ef8c45e43 --- /dev/null +++ b/tensorflow/python/eager/pywrap_tensor_test_util.cc @@ -0,0 +1,41 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/python/eager/pywrap_tensor.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" + +using tensorflow::Pyo; +using tensorflow::TF_StatusPtr; +using tensorflow::TFE_TensorHandleToNumpy; + +PYBIND11_MODULE(pywrap_tensor_test_util, m) { + m.def("get_scalar_one", []() { + // Builds a TFE_TensorHandle and then converts to NumPy ndarray + // using TFE_TensorHandleToNumpy. + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_StatusPtr status(TF_NewStatus()); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_TensorHandle* handle = TestScalarTensorHandle(ctx, 1.0f); + auto result = Pyo(TFE_TensorHandleToNumpy(handle, status.get())); + TFE_DeleteTensorHandle(handle); + TFE_DeleteContext(ctx); + TFE_DeleteContextOptions(opts); + return result; + }); +}