Add c++ utility function that converts AttrValue protobufs to corresponding PyObjects.

PiperOrigin-RevId: 326297504
Change-Id: Id70b00d2708c44459b1b4038a273387b1dc54075
This commit is contained in:
Edward Loper 2020-08-12 13:16:25 -07:00 committed by TensorFlower Gardener
parent 479fa040f6
commit a0dbffedac
4 changed files with 190 additions and 6 deletions

View File

@ -17,19 +17,28 @@ limitations under the License.
#include <map>
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/util/util.h"
using ::tensorflow::swig::GetRegisteredPyObject;
#if PY_MAJOR_VERSION < 3
// Python 2.x:
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
#define PY_INT_CHECK(x) (PyInt_Check(x))
#define PY_INT_TYPE PyInt_Type
#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
#else
// Python 3.x:
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
#define PY_INT_CHECK(x) (PyLong_Check(x))
#define PY_INT_TYPE PyLong_Type
#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
#endif
namespace tensorflow {
@ -239,6 +248,64 @@ Safe_PyObjectPtr ConvertAttrOrNull(PyObject* value, AttributeType attr_type) {
}
}
// Returns a new reference to Py_True or Py_False depending on b.
PyObject* PyBool_FromBool(bool b) {
PyObject* result = b ? Py_True : Py_False;
Py_INCREF(result);
return result;
}
Safe_PyObjectPtr AttrValueListToPyObject(AttrValue::ListValue list) {
if (list.s_size()) {
Safe_PyObjectPtr result(PyList_New(list.s_size()));
for (int i = 0; i < list.s_size(); ++i) {
PyList_SET_ITEM(result.get(), i, PY_STRING_FROMSTRING(list.s(i).c_str()));
}
return result;
} else if (list.i_size()) {
Safe_PyObjectPtr result(PyList_New(list.i_size()));
for (int i = 0; i < list.i_size(); ++i) {
PyList_SET_ITEM(result.get(), i, PY_INT_FROM_LONG(list.i(i)));
}
return result;
} else if (list.f_size()) {
Safe_PyObjectPtr result(PyList_New(list.f_size()));
for (int i = 0; i < list.f_size(); ++i) {
PyList_SET_ITEM(result.get(), i, PyFloat_FromDouble(list.f(i)));
}
return result;
} else if (list.b_size()) {
Safe_PyObjectPtr result(PyList_New(list.b_size()));
for (int i = 0; i < list.b_size(); ++i) {
PyList_SET_ITEM(result.get(), i, PyBool_FromBool(list.b(i)));
}
return result;
} else if (list.type_size()) {
Safe_PyObjectPtr result(PyList_New(list.type_size()));
for (int i = 0; i < list.type_size(); ++i) {
Safe_PyObjectPtr item(DataTypeToPyObject(list.type(i)));
Py_INCREF(item.get());
PyList_SET_ITEM(result.get(), i, item.get());
}
return result;
} else if (list.shape_size()) {
Safe_PyObjectPtr result(PyList_New(list.shape_size()));
for (int i = 0; i < list.shape_size(); ++i) {
Safe_PyObjectPtr item(TensorShapeProtoToPyObject(list.shape(i)));
Py_INCREF(item.get());
PyList_SET_ITEM(result.get(), i, item.get());
}
return result;
} else if (list.tensor_size() || list.func_size()) {
// TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
PyErr_SetString(PyExc_TypeError, "Unsupported AttrValue type");
return nullptr;
} else {
// Empty list
return Safe_PyObjectPtr(PyList_New(0));
}
}
} // namespace
AttributeType AttributeTypeFromName(const std::string& type_name) {
@ -269,4 +336,46 @@ Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
return result;
}
Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value) {
switch (attr_value.value_case()) {
case tensorflow::AttrValue::kS:
return Safe_PyObjectPtr(PY_STRING_FROMSTRING(attr_value.s().c_str()));
case tensorflow::AttrValue::kI:
return Safe_PyObjectPtr(PY_INT_FROM_LONG(attr_value.i()));
case tensorflow::AttrValue::kF:
return Safe_PyObjectPtr(PyFloat_FromDouble(attr_value.f()));
case tensorflow::AttrValue::kB:
return Safe_PyObjectPtr(PyBool_FromBool(attr_value.b()));
case tensorflow::AttrValue::kType:
return DataTypeToPyObject(attr_value.type());
case tensorflow::AttrValue::kShape:
return TensorShapeProtoToPyObject(attr_value.shape());
case tensorflow::AttrValue::kList:
return AttrValueListToPyObject(attr_value.list());
default:
// TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
PyErr_SetString(PyExc_ValueError, "Unsupported AttrValue type");
return nullptr;
}
}
Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type) {
Safe_PyObjectPtr enum_value(PY_INT_FROM_LONG(data_type));
return ConvertDTypeFunctor()(enum_value.get());
}
Safe_PyObjectPtr TensorShapeProtoToPyObject(
const TensorShapeProto& tensor_shape) {
if (tensor_shape.unknown_rank()) {
return ConvertTensorShapeFunctor()(Py_None);
} else {
Safe_PyObjectPtr dims(PyTuple_New(tensor_shape.dim_size()));
for (int i = 0; i < tensor_shape.dim_size(); ++i) {
PyTuple_SET_ITEM(dims.get(), i,
PY_INT_FROM_LONG(tensor_shape.dim(i).size()));
}
return ConvertTensorShapeFunctor()(dims.get());
}
}
} // namespace tensorflow

View File

@ -15,8 +15,14 @@ limitations under the License.
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
#define TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
#include <Python.h>
#include <string>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
namespace tensorflow {
@ -72,6 +78,21 @@ std::string AttributeTypeToName(AttributeType attr_type);
Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
AttributeType type);
// Converts a c++ `AttrValue` protobuf message to a Python object; or sets a
// Python exception and returns nullptr if an error occurs.
Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value);
// Converts a c++ `DataType` protobuf enum to a Python object; or sets a
// Python exception and returns nullptr if an error occurs.
Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type);
// Converts a c++ `TensorShapeProto` message to a Python object; or sets a
// Python exception and returns nullptr if an error occurs.
Safe_PyObjectPtr TensorShapeProtoToPyObject(
const TensorShapeProto& tensor_shape);
// TODO(edloper): Define TensorProtoToPyObject?
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_

View File

@ -30,12 +30,26 @@ py::handle ConvertAttr(py::handle value, std::string attr_type) {
return result.release();
}
py::handle SerializedAttrValueToPyObject(std::string attr_value_string) {
tensorflow::AttrValue attr_value;
attr_value.ParseFromString(attr_value_string);
tensorflow::Safe_PyObjectPtr result =
::tensorflow::AttrValueToPyObject(attr_value);
if (!result) {
throw py::error_already_set();
}
Py_INCREF(result.get());
return result.release();
}
} // namespace
// Expose ConvertPyObjectToAttributeType via Python. Note: this is done to
// simplify testing; ConvertPyObjectToAttributeType is expected to be called
// directly from c++.
// Expose op_def_util.h functions via Python.
PYBIND11_MODULE(_op_def_util, m) {
// Note: the bindings below are added for testing purposes; but the functions
// are expected to be called from c++, not Python.
m.def("ConvertPyObjectToAttributeType", ConvertAttr, py::arg("value"),
py::arg("attr_type_enum"));
m.def("SerializedAttrValueToPyObject", SerializedAttrValueToPyObject,
py::arg("attr_value_string"));
}

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.ops.op_def_library."""
from __future__ import absolute_import
@ -23,6 +22,8 @@ from absl.testing import parameterized
import numpy as np
from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python import _op_def_util
@ -63,7 +64,7 @@ class OpDefUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
("list(int)", (1, 2.3), [1, 2]),
("list(float)", (1, 2.3), [1.0, 2.3]),
("list(bool)", [True, False], [True, False]),
])
]) # pyformat: disable
def testConvert(self, attr_type, value, expected):
result = _op_def_util.ConvertPyObjectToAttributeType(value, attr_type)
@ -93,6 +94,45 @@ class OpDefUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
with self.assertRaisesRegex(TypeError, "Failed to convert value"):
_op_def_util.ConvertPyObjectToAttributeType(value, attr_type)
# Test AttrValueToPyObject(). Note: this test also exercises the code in
# DataTypeToPyObject() and TensorShapeToPyObject(), since those are used
# when the AttrValue contains a DataType or TensorShape.
@parameterized.parameters([
("s: 'foo'", "foo"),
("i: 5", 5),
("f: 8", 8.0),
("b: True", True),
("type: DT_INT32", dtypes.int32),
("shape { dim: [{size: 3}, {size: 4}] }",
tensor_shape.TensorShape([3, 4])),
("list { }", []),
("list { s: [] }", []),
("list { s: ['a', 'b', 'c'] }", ["a", "b", "c"]),
("list { i: [1, 2, 3] }", [1, 2, 3]),
("list { f: [2.0, 4.0] }", [2.0, 4.0]),
]) # pyformat: disable
def testAttrValueToPyObject(self, pbtxt, expected):
proto = attr_value_pb2.AttrValue()
text_format.Parse(pbtxt, proto)
result = _op_def_util.SerializedAttrValueToPyObject(
proto.SerializeToString())
self.assertEqual(expected, result)
@parameterized.parameters([
"", # Empty value (oneof not set)
"tensor {}", # 'TensorProto' not supported (yet).
"func {}", # 'func' not supported.
"placeholder: ''", # 'placeholder' not supported.
"list { tensor [{}] }", # 'TensorProto' not supported (yet).
"list { func [{}] }", # 'func' not supported.
]) # pyformat: disable
def testAttrValueToPyObjectError(self, pbtxt):
proto = attr_value_pb2.AttrValue()
text_format.Parse(pbtxt, proto)
with self.assertRaises((TypeError, ValueError)):
_op_def_util.SerializedAttrValueToPyObject(proto.SerializeToString())
if __name__ == "__main__":
googletest.main()