Add c++ utility function that converts AttrValue protobufs to corresponding PyObjects.
PiperOrigin-RevId: 326297504 Change-Id: Id70b00d2708c44459b1b4038a273387b1dc54075
This commit is contained in:
parent
479fa040f6
commit
a0dbffedac
@ -17,19 +17,28 @@ limitations under the License.
|
||||
#include <map>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
#include "tensorflow/python/util/util.h"
|
||||
|
||||
using ::tensorflow::swig::GetRegisteredPyObject;
|
||||
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
// Python 2.x:
|
||||
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
|
||||
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
|
||||
#define PY_INT_CHECK(x) (PyInt_Check(x))
|
||||
#define PY_INT_TYPE PyInt_Type
|
||||
#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
|
||||
#else
|
||||
// Python 3.x:
|
||||
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
|
||||
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
|
||||
#define PY_INT_CHECK(x) (PyLong_Check(x))
|
||||
#define PY_INT_TYPE PyLong_Type
|
||||
#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
@ -239,6 +248,64 @@ Safe_PyObjectPtr ConvertAttrOrNull(PyObject* value, AttributeType attr_type) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a new reference to Py_True or Py_False depending on b.
|
||||
PyObject* PyBool_FromBool(bool b) {
|
||||
PyObject* result = b ? Py_True : Py_False;
|
||||
Py_INCREF(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
Safe_PyObjectPtr AttrValueListToPyObject(AttrValue::ListValue list) {
|
||||
if (list.s_size()) {
|
||||
Safe_PyObjectPtr result(PyList_New(list.s_size()));
|
||||
for (int i = 0; i < list.s_size(); ++i) {
|
||||
PyList_SET_ITEM(result.get(), i, PY_STRING_FROMSTRING(list.s(i).c_str()));
|
||||
}
|
||||
return result;
|
||||
} else if (list.i_size()) {
|
||||
Safe_PyObjectPtr result(PyList_New(list.i_size()));
|
||||
for (int i = 0; i < list.i_size(); ++i) {
|
||||
PyList_SET_ITEM(result.get(), i, PY_INT_FROM_LONG(list.i(i)));
|
||||
}
|
||||
return result;
|
||||
} else if (list.f_size()) {
|
||||
Safe_PyObjectPtr result(PyList_New(list.f_size()));
|
||||
for (int i = 0; i < list.f_size(); ++i) {
|
||||
PyList_SET_ITEM(result.get(), i, PyFloat_FromDouble(list.f(i)));
|
||||
}
|
||||
return result;
|
||||
} else if (list.b_size()) {
|
||||
Safe_PyObjectPtr result(PyList_New(list.b_size()));
|
||||
for (int i = 0; i < list.b_size(); ++i) {
|
||||
PyList_SET_ITEM(result.get(), i, PyBool_FromBool(list.b(i)));
|
||||
}
|
||||
return result;
|
||||
} else if (list.type_size()) {
|
||||
Safe_PyObjectPtr result(PyList_New(list.type_size()));
|
||||
for (int i = 0; i < list.type_size(); ++i) {
|
||||
Safe_PyObjectPtr item(DataTypeToPyObject(list.type(i)));
|
||||
Py_INCREF(item.get());
|
||||
PyList_SET_ITEM(result.get(), i, item.get());
|
||||
}
|
||||
return result;
|
||||
} else if (list.shape_size()) {
|
||||
Safe_PyObjectPtr result(PyList_New(list.shape_size()));
|
||||
for (int i = 0; i < list.shape_size(); ++i) {
|
||||
Safe_PyObjectPtr item(TensorShapeProtoToPyObject(list.shape(i)));
|
||||
Py_INCREF(item.get());
|
||||
PyList_SET_ITEM(result.get(), i, item.get());
|
||||
}
|
||||
return result;
|
||||
} else if (list.tensor_size() || list.func_size()) {
|
||||
// TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
|
||||
PyErr_SetString(PyExc_TypeError, "Unsupported AttrValue type");
|
||||
return nullptr;
|
||||
} else {
|
||||
// Empty list
|
||||
return Safe_PyObjectPtr(PyList_New(0));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
AttributeType AttributeTypeFromName(const std::string& type_name) {
|
||||
@ -269,4 +336,46 @@ Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
|
||||
return result;
|
||||
}
|
||||
|
||||
Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value) {
|
||||
switch (attr_value.value_case()) {
|
||||
case tensorflow::AttrValue::kS:
|
||||
return Safe_PyObjectPtr(PY_STRING_FROMSTRING(attr_value.s().c_str()));
|
||||
case tensorflow::AttrValue::kI:
|
||||
return Safe_PyObjectPtr(PY_INT_FROM_LONG(attr_value.i()));
|
||||
case tensorflow::AttrValue::kF:
|
||||
return Safe_PyObjectPtr(PyFloat_FromDouble(attr_value.f()));
|
||||
case tensorflow::AttrValue::kB:
|
||||
return Safe_PyObjectPtr(PyBool_FromBool(attr_value.b()));
|
||||
case tensorflow::AttrValue::kType:
|
||||
return DataTypeToPyObject(attr_value.type());
|
||||
case tensorflow::AttrValue::kShape:
|
||||
return TensorShapeProtoToPyObject(attr_value.shape());
|
||||
case tensorflow::AttrValue::kList:
|
||||
return AttrValueListToPyObject(attr_value.list());
|
||||
default:
|
||||
// TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
|
||||
PyErr_SetString(PyExc_ValueError, "Unsupported AttrValue type");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type) {
|
||||
Safe_PyObjectPtr enum_value(PY_INT_FROM_LONG(data_type));
|
||||
return ConvertDTypeFunctor()(enum_value.get());
|
||||
}
|
||||
|
||||
Safe_PyObjectPtr TensorShapeProtoToPyObject(
|
||||
const TensorShapeProto& tensor_shape) {
|
||||
if (tensor_shape.unknown_rank()) {
|
||||
return ConvertTensorShapeFunctor()(Py_None);
|
||||
} else {
|
||||
Safe_PyObjectPtr dims(PyTuple_New(tensor_shape.dim_size()));
|
||||
for (int i = 0; i < tensor_shape.dim_size(); ++i) {
|
||||
PyTuple_SET_ITEM(dims.get(), i,
|
||||
PY_INT_FROM_LONG(tensor_shape.dim(i).size()));
|
||||
}
|
||||
return ConvertTensorShapeFunctor()(dims.get());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,8 +15,14 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
|
||||
#define TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -72,6 +78,21 @@ std::string AttributeTypeToName(AttributeType attr_type);
|
||||
Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
|
||||
AttributeType type);
|
||||
|
||||
// Converts a c++ `AttrValue` protobuf message to a Python object; or sets a
|
||||
// Python exception and returns nullptr if an error occurs.
|
||||
Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value);
|
||||
|
||||
// Converts a c++ `DataType` protobuf enum to a Python object; or sets a
|
||||
// Python exception and returns nullptr if an error occurs.
|
||||
Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type);
|
||||
|
||||
// Converts a c++ `TensorShapeProto` message to a Python object; or sets a
|
||||
// Python exception and returns nullptr if an error occurs.
|
||||
Safe_PyObjectPtr TensorShapeProtoToPyObject(
|
||||
const TensorShapeProto& tensor_shape);
|
||||
|
||||
// TODO(edloper): Define TensorProtoToPyObject?
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
|
||||
|
@ -30,12 +30,26 @@ py::handle ConvertAttr(py::handle value, std::string attr_type) {
|
||||
return result.release();
|
||||
}
|
||||
|
||||
py::handle SerializedAttrValueToPyObject(std::string attr_value_string) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
attr_value.ParseFromString(attr_value_string);
|
||||
tensorflow::Safe_PyObjectPtr result =
|
||||
::tensorflow::AttrValueToPyObject(attr_value);
|
||||
if (!result) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
Py_INCREF(result.get());
|
||||
return result.release();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Expose ConvertPyObjectToAttributeType via Python. Note: this is done to
|
||||
// simplify testing; ConvertPyObjectToAttributeType is expected to be called
|
||||
// directly from c++.
|
||||
// Expose op_def_util.h functions via Python.
|
||||
PYBIND11_MODULE(_op_def_util, m) {
|
||||
// Note: the bindings below are added for testing purposes; but the functions
|
||||
// are expected to be called from c++, not Python.
|
||||
m.def("ConvertPyObjectToAttributeType", ConvertAttr, py::arg("value"),
|
||||
py::arg("attr_type_enum"));
|
||||
m.def("SerializedAttrValueToPyObject", SerializedAttrValueToPyObject,
|
||||
py::arg("attr_value_string"));
|
||||
}
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for tensorflow.python.ops.op_def_library."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -23,6 +22,8 @@ from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import tensor_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python import _op_def_util
|
||||
@ -63,7 +64,7 @@ class OpDefUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
("list(int)", (1, 2.3), [1, 2]),
|
||||
("list(float)", (1, 2.3), [1.0, 2.3]),
|
||||
("list(bool)", [True, False], [True, False]),
|
||||
])
|
||||
]) # pyformat: disable
|
||||
def testConvert(self, attr_type, value, expected):
|
||||
result = _op_def_util.ConvertPyObjectToAttributeType(value, attr_type)
|
||||
|
||||
@ -93,6 +94,45 @@ class OpDefUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
with self.assertRaisesRegex(TypeError, "Failed to convert value"):
|
||||
_op_def_util.ConvertPyObjectToAttributeType(value, attr_type)
|
||||
|
||||
# Test AttrValueToPyObject(). Note: this test also exercises the code in
|
||||
# DataTypeToPyObject() and TensorShapeToPyObject(), since those are used
|
||||
# when the AttrValue contains a DataType or TensorShape.
|
||||
@parameterized.parameters([
|
||||
("s: 'foo'", "foo"),
|
||||
("i: 5", 5),
|
||||
("f: 8", 8.0),
|
||||
("b: True", True),
|
||||
("type: DT_INT32", dtypes.int32),
|
||||
("shape { dim: [{size: 3}, {size: 4}] }",
|
||||
tensor_shape.TensorShape([3, 4])),
|
||||
("list { }", []),
|
||||
("list { s: [] }", []),
|
||||
("list { s: ['a', 'b', 'c'] }", ["a", "b", "c"]),
|
||||
("list { i: [1, 2, 3] }", [1, 2, 3]),
|
||||
("list { f: [2.0, 4.0] }", [2.0, 4.0]),
|
||||
]) # pyformat: disable
|
||||
def testAttrValueToPyObject(self, pbtxt, expected):
|
||||
proto = attr_value_pb2.AttrValue()
|
||||
text_format.Parse(pbtxt, proto)
|
||||
result = _op_def_util.SerializedAttrValueToPyObject(
|
||||
proto.SerializeToString())
|
||||
|
||||
self.assertEqual(expected, result)
|
||||
|
||||
@parameterized.parameters([
|
||||
"", # Empty value (oneof not set)
|
||||
"tensor {}", # 'TensorProto' not supported (yet).
|
||||
"func {}", # 'func' not supported.
|
||||
"placeholder: ''", # 'placeholder' not supported.
|
||||
"list { tensor [{}] }", # 'TensorProto' not supported (yet).
|
||||
"list { func [{}] }", # 'func' not supported.
|
||||
]) # pyformat: disable
|
||||
def testAttrValueToPyObjectError(self, pbtxt):
|
||||
proto = attr_value_pb2.AttrValue()
|
||||
text_format.Parse(pbtxt, proto)
|
||||
with self.assertRaises((TypeError, ValueError)):
|
||||
_op_def_util.SerializedAttrValueToPyObject(proto.SerializeToString())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user