From a0dbffedac9a540377fa1e98fd0da8f4d6cd5f90 Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Wed, 12 Aug 2020 13:16:25 -0700
Subject: [PATCH] Add c++ utility function that converts AttrValue protobufs to
 corresponding PyObjects.

PiperOrigin-RevId: 326297504
Change-Id: Id70b00d2708c44459b1b4038a273387b1dc54075
---
 tensorflow/python/framework/op_def_util.cc    | 109 ++++++++++++++++++
 tensorflow/python/framework/op_def_util.h     |  21 ++++
 .../python/framework/op_def_util_pybind.cc    |  20 +++-
 .../python/framework/op_def_util_test.py      |  46 +++++++-
 4 files changed, 190 insertions(+), 6 deletions(-)

diff --git a/tensorflow/python/framework/op_def_util.cc b/tensorflow/python/framework/op_def_util.cc
index c915c494be9..4e1569f190d 100644
--- a/tensorflow/python/framework/op_def_util.cc
+++ b/tensorflow/python/framework/op_def_util.cc
@@ -17,19 +17,28 @@ limitations under the License.
 #include <map>
 
 #include "absl/strings/str_cat.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
 #include "tensorflow/python/util/util.h"
 
 using ::tensorflow::swig::GetRegisteredPyObject;
 
 #if PY_MAJOR_VERSION < 3
+// Python 2.x:
 #define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
+#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
 #define PY_INT_CHECK(x) (PyInt_Check(x))
 #define PY_INT_TYPE PyInt_Type
+#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
 #else
+// Python 3.x:
 #define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
+#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
 #define PY_INT_CHECK(x) (PyLong_Check(x))
 #define PY_INT_TYPE PyLong_Type
+#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
 #endif
 
 namespace tensorflow {
@@ -239,6 +248,64 @@ Safe_PyObjectPtr ConvertAttrOrNull(PyObject* value, AttributeType attr_type) {
   }
 }
 
+// Returns a new reference to Py_True or Py_False depending on b.
+PyObject* PyBool_FromBool(bool b) {
+  PyObject* result = b ? Py_True : Py_False;
+  Py_INCREF(result);
+  return result;
+}
+
+Safe_PyObjectPtr AttrValueListToPyObject(AttrValue::ListValue list) {
+  if (list.s_size()) {
+    Safe_PyObjectPtr result(PyList_New(list.s_size()));
+    for (int i = 0; i < list.s_size(); ++i) {
+      PyList_SET_ITEM(result.get(), i, PY_STRING_FROMSTRING(list.s(i).c_str()));
+    }
+    return result;
+  } else if (list.i_size()) {
+    Safe_PyObjectPtr result(PyList_New(list.i_size()));
+    for (int i = 0; i < list.i_size(); ++i) {
+      PyList_SET_ITEM(result.get(), i, PY_INT_FROM_LONG(list.i(i)));
+    }
+    return result;
+  } else if (list.f_size()) {
+    Safe_PyObjectPtr result(PyList_New(list.f_size()));
+    for (int i = 0; i < list.f_size(); ++i) {
+      PyList_SET_ITEM(result.get(), i, PyFloat_FromDouble(list.f(i)));
+    }
+    return result;
+  } else if (list.b_size()) {
+    Safe_PyObjectPtr result(PyList_New(list.b_size()));
+    for (int i = 0; i < list.b_size(); ++i) {
+      PyList_SET_ITEM(result.get(), i, PyBool_FromBool(list.b(i)));
+    }
+    return result;
+  } else if (list.type_size()) {
+    Safe_PyObjectPtr result(PyList_New(list.type_size()));
+    for (int i = 0; i < list.type_size(); ++i) {
+      Safe_PyObjectPtr item(DataTypeToPyObject(list.type(i)));
+      Py_INCREF(item.get());
+      PyList_SET_ITEM(result.get(), i, item.get());
+    }
+    return result;
+  } else if (list.shape_size()) {
+    Safe_PyObjectPtr result(PyList_New(list.shape_size()));
+    for (int i = 0; i < list.shape_size(); ++i) {
+      Safe_PyObjectPtr item(TensorShapeProtoToPyObject(list.shape(i)));
+      Py_INCREF(item.get());
+      PyList_SET_ITEM(result.get(), i, item.get());
+    }
+    return result;
+  } else if (list.tensor_size() || list.func_size()) {
+    // TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
+    PyErr_SetString(PyExc_TypeError, "Unsupported AttrValue type");
+    return nullptr;
+  } else {
+    // Empty list
+    return Safe_PyObjectPtr(PyList_New(0));
+  }
+}
+
 }  // namespace
 
 AttributeType AttributeTypeFromName(const std::string& type_name) {
@@ -269,4 +336,46 @@ Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
   return result;
 }
 
+Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value) {
+  switch (attr_value.value_case()) {
+    case tensorflow::AttrValue::kS:
+      return Safe_PyObjectPtr(PY_STRING_FROMSTRING(attr_value.s().c_str()));
+    case tensorflow::AttrValue::kI:
+      return Safe_PyObjectPtr(PY_INT_FROM_LONG(attr_value.i()));
+    case tensorflow::AttrValue::kF:
+      return Safe_PyObjectPtr(PyFloat_FromDouble(attr_value.f()));
+    case tensorflow::AttrValue::kB:
+      return Safe_PyObjectPtr(PyBool_FromBool(attr_value.b()));
+    case tensorflow::AttrValue::kType:
+      return DataTypeToPyObject(attr_value.type());
+    case tensorflow::AttrValue::kShape:
+      return TensorShapeProtoToPyObject(attr_value.shape());
+    case tensorflow::AttrValue::kList:
+      return AttrValueListToPyObject(attr_value.list());
+    default:
+      // TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
+      PyErr_SetString(PyExc_ValueError, "Unsupported AttrValue type");
+      return nullptr;
+  }
+}
+
+Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type) {
+  Safe_PyObjectPtr enum_value(PY_INT_FROM_LONG(data_type));
+  return ConvertDTypeFunctor()(enum_value.get());
+}
+
+Safe_PyObjectPtr TensorShapeProtoToPyObject(
+    const TensorShapeProto& tensor_shape) {
+  if (tensor_shape.unknown_rank()) {
+    return ConvertTensorShapeFunctor()(Py_None);
+  } else {
+    Safe_PyObjectPtr dims(PyTuple_New(tensor_shape.dim_size()));
+    for (int i = 0; i < tensor_shape.dim_size(); ++i) {
+      PyTuple_SET_ITEM(dims.get(), i,
+                       PY_INT_FROM_LONG(tensor_shape.dim(i).size()));
+    }
+    return ConvertTensorShapeFunctor()(dims.get());
+  }
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/framework/op_def_util.h b/tensorflow/python/framework/op_def_util.h
index ef5e64e68fa..3b35c3ef7ad 100644
--- a/tensorflow/python/framework/op_def_util.h
+++ b/tensorflow/python/framework/op_def_util.h
@@ -15,8 +15,14 @@ limitations under the License.
 #ifndef TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
 #define TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
 
+#include <Python.h>
+
 #include <string>
 
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/python/lib/core/safe_ptr.h"
 
 namespace tensorflow {
@@ -72,6 +78,21 @@ std::string AttributeTypeToName(AttributeType attr_type);
 Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
                                                 AttributeType type);
 
+// Converts a c++ `AttrValue` protobuf message to a Python object; or sets a
+// Python exception and returns nullptr if an error occurs.
+Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value);
+
+// Converts a c++ `DataType` protobuf enum to a Python object; or sets a
+// Python exception and returns nullptr if an error occurs.
+Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type);
+
+// Converts a c++ `TensorShapeProto` message to a Python object; or sets a
+// Python exception and returns nullptr if an error occurs.
+Safe_PyObjectPtr TensorShapeProtoToPyObject(
+    const TensorShapeProto& tensor_shape);
+
+// TODO(edloper): Define TensorProtoToPyObject?
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
diff --git a/tensorflow/python/framework/op_def_util_pybind.cc b/tensorflow/python/framework/op_def_util_pybind.cc
index d13f605b599..a7843322840 100644
--- a/tensorflow/python/framework/op_def_util_pybind.cc
+++ b/tensorflow/python/framework/op_def_util_pybind.cc
@@ -30,12 +30,26 @@ py::handle ConvertAttr(py::handle value, std::string attr_type) {
   return result.release();
 }
 
+py::handle SerializedAttrValueToPyObject(std::string attr_value_string) {
+  tensorflow::AttrValue attr_value;
+  attr_value.ParseFromString(attr_value_string);
+  tensorflow::Safe_PyObjectPtr result =
+      ::tensorflow::AttrValueToPyObject(attr_value);
+  if (!result) {
+    throw py::error_already_set();
+  }
+  Py_INCREF(result.get());
+  return result.release();
+}
+
 }  // namespace
 
-// Expose ConvertPyObjectToAttributeType via Python.  Note: this is done to
-// simplify testing; ConvertPyObjectToAttributeType is expected to be called
-// directly from c++.
+// Expose op_def_util.h functions via Python.
 PYBIND11_MODULE(_op_def_util, m) {
+  // Note: the bindings below are added for testing purposes; but the functions
+  // are expected to be called from c++, not Python.
   m.def("ConvertPyObjectToAttributeType", ConvertAttr, py::arg("value"),
         py::arg("attr_type_enum"));
+  m.def("SerializedAttrValueToPyObject", SerializedAttrValueToPyObject,
+        py::arg("attr_value_string"));
 }
diff --git a/tensorflow/python/framework/op_def_util_test.py b/tensorflow/python/framework/op_def_util_test.py
index 69aaffbf19f..9f2ce61996f 100644
--- a/tensorflow/python/framework/op_def_util_test.py
+++ b/tensorflow/python/framework/op_def_util_test.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Tests for tensorflow.python.ops.op_def_library."""
 
 from __future__ import absolute_import
@@ -23,6 +22,8 @@ from absl.testing import parameterized
 
 import numpy as np
 
+from google.protobuf import text_format
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import tensor_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.python import _op_def_util
@@ -63,7 +64,7 @@ class OpDefUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       ("list(int)", (1, 2.3), [1, 2]),
       ("list(float)", (1, 2.3), [1.0, 2.3]),
       ("list(bool)", [True, False], [True, False]),
-  ])
+  ])  # pyformat: disable
   def testConvert(self, attr_type, value, expected):
     result = _op_def_util.ConvertPyObjectToAttributeType(value, attr_type)
 
@@ -93,6 +94,45 @@ class OpDefUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     with self.assertRaisesRegex(TypeError, "Failed to convert value"):
       _op_def_util.ConvertPyObjectToAttributeType(value, attr_type)
 
+  # Test AttrValueToPyObject().  Note: this test also exercises the code in
+  # DataTypeToPyObject() and TensorShapeToPyObject(), since those are used
+  # when the AttrValue contains a DataType or TensorShape.
+  @parameterized.parameters([
+      ("s: 'foo'", "foo"),
+      ("i: 5", 5),
+      ("f: 8", 8.0),
+      ("b: True", True),
+      ("type: DT_INT32", dtypes.int32),
+      ("shape { dim: [{size: 3}, {size: 4}] }",
+       tensor_shape.TensorShape([3, 4])),
+      ("list { }", []),
+      ("list { s: [] }", []),
+      ("list { s: ['a', 'b', 'c'] }", ["a", "b", "c"]),
+      ("list { i: [1, 2, 3] }", [1, 2, 3]),
+      ("list { f: [2.0, 4.0] }", [2.0, 4.0]),
+  ])  # pyformat: disable
+  def testAttrValueToPyObject(self, pbtxt, expected):
+    proto = attr_value_pb2.AttrValue()
+    text_format.Parse(pbtxt, proto)
+    result = _op_def_util.SerializedAttrValueToPyObject(
+        proto.SerializeToString())
+
+    self.assertEqual(expected, result)
+
+  @parameterized.parameters([
+      "",                           # Empty value (oneof not set)
+      "tensor {}",                  # 'TensorProto' not supported (yet).
+      "func {}",                    # 'func' not supported.
+      "placeholder: ''",            # 'placeholder' not supported.
+      "list { tensor [{}] }",       # 'TensorProto' not supported (yet).
+      "list { func [{}] }",         # 'func' not supported.
+  ])  # pyformat: disable
+  def testAttrValueToPyObjectError(self, pbtxt):
+    proto = attr_value_pb2.AttrValue()
+    text_format.Parse(pbtxt, proto)
+    with self.assertRaises((TypeError, ValueError)):
+      _op_def_util.SerializedAttrValueToPyObject(proto.SerializeToString())
+
+
 if __name__ == "__main__":
   googletest.main()
-