From e19872cc6558eee6a29f5e1be09f4e90e71a4850 Mon Sep 17 00:00:00 2001
From: Saurabh Saxena <srbs@google.com>
Date: Tue, 8 Sep 2020 20:04:17 -0700
Subject: [PATCH] Export TFE_TensorHandleToNumpy in pywrap_tensor.h so that in
 case be used in python binding for unified API.

PiperOrigin-RevId: 330640762
Change-Id: I079573af8e08b907bc062480f51304f1d476b58f
---
 tensorflow/c/tf_status_helper.h               |  8 +++
 tensorflow/python/eager/BUILD                 | 30 +++++++++
 tensorflow/python/eager/pywrap_tensor.cc      | 62 ++++++++++---------
 tensorflow/python/eager/pywrap_tensor.h       |  2 +
 tensorflow/python/eager/pywrap_tensor_test.py | 35 +++++++++++
 .../python/eager/pywrap_tensor_test_util.cc   | 41 ++++++++++++
 6 files changed, 149 insertions(+), 29 deletions(-)
 create mode 100644 tensorflow/python/eager/pywrap_tensor_test.py
 create mode 100644 tensorflow/python/eager/pywrap_tensor_test_util.cc

diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h
index ff8085f1229..a895e608159 100644
--- a/tensorflow/c/tf_status_helper.h
+++ b/tensorflow/c/tf_status_helper.h
@@ -28,6 +28,14 @@ void Set_TF_Status_from_Status(TF_Status* tf_status,
 // Returns a "status" from "tf_status".
 tensorflow::Status StatusFromTF_Status(const TF_Status* tf_status);
 
+namespace internal {
+struct TF_StatusDeleter {
+  void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
+};
+}  // namespace internal
+
+using TF_StatusPtr = std::unique_ptr<TF_Status, internal::TF_StatusDeleter>;
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_TF_STATUS_HELPER_H_
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index c9eaef82038..738e6faf68f 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -80,6 +80,36 @@ cc_library(
     ],
 )
 
+tf_python_pybind_extension(
+    name = "pywrap_tensor_test_util",
+    testonly = True,
+    srcs = ["pywrap_tensor_test_util.cc"],
+    module_name = "pywrap_tensor_test_util",
+    deps = [
+        ":pywrap_tfe_lib",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/c/eager:c_api_test_util",
+        "//tensorflow/python:pybind11_lib",
+        "@pybind11",
+    ],
+)
+
+cuda_py_test(
+    name = "pywrap_tensor_test",
+    size = "small",
+    srcs = ["pywrap_tensor_test.py"],
+    python_version = "PY3",
+    tags = [
+        "no_oss",  # TODO(b/168051787): Enable.
+        "no_pip",  # TODO(b/168051787): Enable.
+    ],
+    deps = [
+        ":pywrap_tensor_test_util",
+        ":test",
+        "//third_party/py/numpy",
+    ],
+)
+
 filegroup(
     name = "pywrap_required_hdrs",
     srcs = [
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 0789eab6270..e5c74deaf80 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -40,9 +40,42 @@ limitations under the License.
 
 // forward declare
 struct EagerTensor;
+namespace tensorflow {
 
+// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
+// The two may share underlying storage so changes to one may reflect in the
+// other.
+PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
+  if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
+    TF_SetStatus(status, TF_INVALID_ARGUMENT,
+                 "Cannot convert a Tensor of dtype resource to a NumPy array.");
+    return nullptr;
+  }
+
+  tensorflow::Safe_TF_TensorPtr tensor = nullptr;
+  Py_BEGIN_ALLOW_THREADS;
+  tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
+  Py_END_ALLOW_THREADS;
+  if (!status->status.ok()) {
+    return nullptr;
+  }
+
+  PyObject* ret = nullptr;
+  auto cppstatus =
+      tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
+  tensorflow::Set_TF_Status_from_Status(status, cppstatus);
+  if (!status->status.ok()) {
+    Py_XDECREF(ret);
+    return nullptr;
+  }
+  CHECK_NE(ret, nullptr);
+  return ret;
+}
+}  // namespace tensorflow
 namespace {
 
+using tensorflow::TFE_TensorHandleToNumpy;
+
 // An instance of _EagerTensorProfiler that will receive callbacks about
 // events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
 PyObject* eager_tensor_profiler = nullptr;
@@ -87,35 +120,6 @@ TFE_Context* GetContextHandle(PyObject* py_context) {
   return ctx;
 }
 
-// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
-// The two may share underlying storage so changes to one may reflect in the
-// other.
-PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
-  if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
-    TF_SetStatus(status, TF_INVALID_ARGUMENT,
-                 "Cannot convert a Tensor of dtype resource to a NumPy array.");
-    return nullptr;
-  }
-
-  tensorflow::Safe_TF_TensorPtr tensor = nullptr;
-  Py_BEGIN_ALLOW_THREADS;
-  tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
-  Py_END_ALLOW_THREADS;
-  if (!status->status.ok()) {
-    return nullptr;
-  }
-
-  PyObject* ret = nullptr;
-  auto cppstatus =
-      tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
-  tensorflow::Set_TF_Status_from_Status(status, cppstatus);
-  if (!status->status.ok()) {
-    Py_XDECREF(ret);
-    return nullptr;
-  }
-  CHECK_NE(ret, nullptr);
-  return ret;
-}
 
 // Helper function to convert `v` to a tensorflow::DataType and store it in
 // `*out`. Returns true on success, false otherwise.
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index 4c84b5ce6ea..bc9548ac4ad 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -37,6 +37,8 @@ TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
                                        DataType dtype,
                                        const char* device_name = nullptr);
 
+PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
diff --git a/tensorflow/python/eager/pywrap_tensor_test.py b/tensorflow/python/eager/pywrap_tensor_test.py
new file mode 100644
index 00000000000..ee1a3536546
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tensor_test.py
@@ -0,0 +1,35 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TFE_TensorHandleToNumpy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.python.eager import pywrap_tensor_test_util as util
+from tensorflow.python.eager import test
+
+
+class PywrapTensorTest(test.TestCase):
+
+  def testGetScalarOne(self):
+    result = util.get_scalar_one()
+    self.assertIsInstance(result, np.ndarray)
+    self.assertAllEqual(result, 1.0)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/eager/pywrap_tensor_test_util.cc b/tensorflow/python/eager/pywrap_tensor_test_util.cc
new file mode 100644
index 00000000000..21ef8c45e43
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tensor_test_util.cc
@@ -0,0 +1,41 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "tensorflow/c/eager/c_api_test_util.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/python/eager/pywrap_tensor.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+
+using tensorflow::Pyo;
+using tensorflow::TF_StatusPtr;
+using tensorflow::TFE_TensorHandleToNumpy;
+
+PYBIND11_MODULE(pywrap_tensor_test_util, m) {
+  m.def("get_scalar_one", []() {
+    // Builds a TFE_TensorHandle and then converts to NumPy ndarray
+    // using TFE_TensorHandleToNumpy.
+    TFE_ContextOptions* opts = TFE_NewContextOptions();
+    TF_StatusPtr status(TF_NewStatus());
+    TFE_Context* ctx = TFE_NewContext(opts, status.get());
+    TFE_TensorHandle* handle = TestScalarTensorHandle(ctx, 1.0f);
+    auto result = Pyo(TFE_TensorHandleToNumpy(handle, status.get()));
+    TFE_DeleteTensorHandle(handle);
+    TFE_DeleteContext(ctx);
+    TFE_DeleteContextOptions(opts);
+    return result;
+  });
+}