diff --git a/tensorflow/python/framework/op_def_util.cc b/tensorflow/python/framework/op_def_util.cc index c915c494be9..4e1569f190d 100644 --- a/tensorflow/python/framework/op_def_util.cc +++ b/tensorflow/python/framework/op_def_util.cc @@ -17,19 +17,28 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/python/lib/core/safe_ptr.h" #include "tensorflow/python/util/util.h" using ::tensorflow::swig::GetRegisteredPyObject; #if PY_MAJOR_VERSION < 3 +// Python 2.x: #define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x)) +#define PY_STRING_FROMSTRING(x) (PyString_FromString(x)) #define PY_INT_CHECK(x) (PyInt_Check(x)) #define PY_INT_TYPE PyInt_Type +#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x)) #else +// Python 3.x: #define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x)) +#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x)) #define PY_INT_CHECK(x) (PyLong_Check(x)) #define PY_INT_TYPE PyLong_Type +#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x)) #endif namespace tensorflow { @@ -239,6 +248,64 @@ Safe_PyObjectPtr ConvertAttrOrNull(PyObject* value, AttributeType attr_type) { } } +// Returns a new reference to Py_True or Py_False depending on b. +PyObject* PyBool_FromBool(bool b) { + PyObject* result = b ? Py_True : Py_False; + Py_INCREF(result); + return result; +} + +Safe_PyObjectPtr AttrValueListToPyObject(AttrValue::ListValue list) { + if (list.s_size()) { + Safe_PyObjectPtr result(PyList_New(list.s_size())); + for (int i = 0; i < list.s_size(); ++i) { + PyList_SET_ITEM(result.get(), i, PY_STRING_FROMSTRING(list.s(i).c_str())); + } + return result; + } else if (list.i_size()) { + Safe_PyObjectPtr result(PyList_New(list.i_size())); + for (int i = 0; i < list.i_size(); ++i) { + PyList_SET_ITEM(result.get(), i, PY_INT_FROM_LONG(list.i(i))); + } + return result; + } else if (list.f_size()) { + Safe_PyObjectPtr result(PyList_New(list.f_size())); + for (int i = 0; i < list.f_size(); ++i) { + PyList_SET_ITEM(result.get(), i, PyFloat_FromDouble(list.f(i))); + } + return result; + } else if (list.b_size()) { + Safe_PyObjectPtr result(PyList_New(list.b_size())); + for (int i = 0; i < list.b_size(); ++i) { + PyList_SET_ITEM(result.get(), i, PyBool_FromBool(list.b(i))); + } + return result; + } else if (list.type_size()) { + Safe_PyObjectPtr result(PyList_New(list.type_size())); + for (int i = 0; i < list.type_size(); ++i) { + Safe_PyObjectPtr item(DataTypeToPyObject(list.type(i))); + Py_INCREF(item.get()); + PyList_SET_ITEM(result.get(), i, item.get()); + } + return result; + } else if (list.shape_size()) { + Safe_PyObjectPtr result(PyList_New(list.shape_size())); + for (int i = 0; i < list.shape_size(); ++i) { + Safe_PyObjectPtr item(TensorShapeProtoToPyObject(list.shape(i))); + Py_INCREF(item.get()); + PyList_SET_ITEM(result.get(), i, item.get()); + } + return result; + } else if (list.tensor_size() || list.func_size()) { + // TODO(edloper): Add support for tensorflow::AttrValue::kTensor. + PyErr_SetString(PyExc_TypeError, "Unsupported AttrValue type"); + return nullptr; + } else { + // Empty list + return Safe_PyObjectPtr(PyList_New(0)); + } +} + } // namespace AttributeType AttributeTypeFromName(const std::string& type_name) { @@ -269,4 +336,46 @@ Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value, return result; } +Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value) { + switch (attr_value.value_case()) { + case tensorflow::AttrValue::kS: + return Safe_PyObjectPtr(PY_STRING_FROMSTRING(attr_value.s().c_str())); + case tensorflow::AttrValue::kI: + return Safe_PyObjectPtr(PY_INT_FROM_LONG(attr_value.i())); + case tensorflow::AttrValue::kF: + return Safe_PyObjectPtr(PyFloat_FromDouble(attr_value.f())); + case tensorflow::AttrValue::kB: + return Safe_PyObjectPtr(PyBool_FromBool(attr_value.b())); + case tensorflow::AttrValue::kType: + return DataTypeToPyObject(attr_value.type()); + case tensorflow::AttrValue::kShape: + return TensorShapeProtoToPyObject(attr_value.shape()); + case tensorflow::AttrValue::kList: + return AttrValueListToPyObject(attr_value.list()); + default: + // TODO(edloper): Add support for tensorflow::AttrValue::kTensor. + PyErr_SetString(PyExc_ValueError, "Unsupported AttrValue type"); + return nullptr; + } +} + +Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type) { + Safe_PyObjectPtr enum_value(PY_INT_FROM_LONG(data_type)); + return ConvertDTypeFunctor()(enum_value.get()); +} + +Safe_PyObjectPtr TensorShapeProtoToPyObject( + const TensorShapeProto& tensor_shape) { + if (tensor_shape.unknown_rank()) { + return ConvertTensorShapeFunctor()(Py_None); + } else { + Safe_PyObjectPtr dims(PyTuple_New(tensor_shape.dim_size())); + for (int i = 0; i < tensor_shape.dim_size(); ++i) { + PyTuple_SET_ITEM(dims.get(), i, + PY_INT_FROM_LONG(tensor_shape.dim(i).size())); + } + return ConvertTensorShapeFunctor()(dims.get()); + } +} + } // namespace tensorflow diff --git a/tensorflow/python/framework/op_def_util.h b/tensorflow/python/framework/op_def_util.h index ef5e64e68fa..3b35c3ef7ad 100644 --- a/tensorflow/python/framework/op_def_util.h +++ b/tensorflow/python/framework/op_def_util.h @@ -15,8 +15,14 @@ limitations under the License. #ifndef TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_ #define TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_ +#include + #include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/python/lib/core/safe_ptr.h" namespace tensorflow { @@ -72,6 +78,21 @@ std::string AttributeTypeToName(AttributeType attr_type); Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value, AttributeType type); +// Converts a c++ `AttrValue` protobuf message to a Python object; or sets a +// Python exception and returns nullptr if an error occurs. +Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value); + +// Converts a c++ `DataType` protobuf enum to a Python object; or sets a +// Python exception and returns nullptr if an error occurs. +Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type); + +// Converts a c++ `TensorShapeProto` message to a Python object; or sets a +// Python exception and returns nullptr if an error occurs. +Safe_PyObjectPtr TensorShapeProtoToPyObject( + const TensorShapeProto& tensor_shape); + +// TODO(edloper): Define TensorProtoToPyObject? + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_ diff --git a/tensorflow/python/framework/op_def_util_pybind.cc b/tensorflow/python/framework/op_def_util_pybind.cc index d13f605b599..a7843322840 100644 --- a/tensorflow/python/framework/op_def_util_pybind.cc +++ b/tensorflow/python/framework/op_def_util_pybind.cc @@ -30,12 +30,26 @@ py::handle ConvertAttr(py::handle value, std::string attr_type) { return result.release(); } +py::handle SerializedAttrValueToPyObject(std::string attr_value_string) { + tensorflow::AttrValue attr_value; + attr_value.ParseFromString(attr_value_string); + tensorflow::Safe_PyObjectPtr result = + ::tensorflow::AttrValueToPyObject(attr_value); + if (!result) { + throw py::error_already_set(); + } + Py_INCREF(result.get()); + return result.release(); +} + } // namespace -// Expose ConvertPyObjectToAttributeType via Python. Note: this is done to -// simplify testing; ConvertPyObjectToAttributeType is expected to be called -// directly from c++. +// Expose op_def_util.h functions via Python. PYBIND11_MODULE(_op_def_util, m) { + // Note: the bindings below are added for testing purposes; but the functions + // are expected to be called from c++, not Python. m.def("ConvertPyObjectToAttributeType", ConvertAttr, py::arg("value"), py::arg("attr_type_enum")); + m.def("SerializedAttrValueToPyObject", SerializedAttrValueToPyObject, + py::arg("attr_value_string")); } diff --git a/tensorflow/python/framework/op_def_util_test.py b/tensorflow/python/framework/op_def_util_test.py index 69aaffbf19f..9f2ce61996f 100644 --- a/tensorflow/python/framework/op_def_util_test.py +++ b/tensorflow/python/framework/op_def_util_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Tests for tensorflow.python.ops.op_def_library.""" from __future__ import absolute_import @@ -23,6 +22,8 @@ from absl.testing import parameterized import numpy as np +from google.protobuf import text_format +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import tensor_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.python import _op_def_util @@ -63,7 +64,7 @@ class OpDefUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): ("list(int)", (1, 2.3), [1, 2]), ("list(float)", (1, 2.3), [1.0, 2.3]), ("list(bool)", [True, False], [True, False]), - ]) + ]) # pyformat: disable def testConvert(self, attr_type, value, expected): result = _op_def_util.ConvertPyObjectToAttributeType(value, attr_type) @@ -93,6 +94,45 @@ class OpDefUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaisesRegex(TypeError, "Failed to convert value"): _op_def_util.ConvertPyObjectToAttributeType(value, attr_type) + # Test AttrValueToPyObject(). Note: this test also exercises the code in + # DataTypeToPyObject() and TensorShapeToPyObject(), since those are used + # when the AttrValue contains a DataType or TensorShape. + @parameterized.parameters([ + ("s: 'foo'", "foo"), + ("i: 5", 5), + ("f: 8", 8.0), + ("b: True", True), + ("type: DT_INT32", dtypes.int32), + ("shape { dim: [{size: 3}, {size: 4}] }", + tensor_shape.TensorShape([3, 4])), + ("list { }", []), + ("list { s: [] }", []), + ("list { s: ['a', 'b', 'c'] }", ["a", "b", "c"]), + ("list { i: [1, 2, 3] }", [1, 2, 3]), + ("list { f: [2.0, 4.0] }", [2.0, 4.0]), + ]) # pyformat: disable + def testAttrValueToPyObject(self, pbtxt, expected): + proto = attr_value_pb2.AttrValue() + text_format.Parse(pbtxt, proto) + result = _op_def_util.SerializedAttrValueToPyObject( + proto.SerializeToString()) + + self.assertEqual(expected, result) + + @parameterized.parameters([ + "", # Empty value (oneof not set) + "tensor {}", # 'TensorProto' not supported (yet). + "func {}", # 'func' not supported. + "placeholder: ''", # 'placeholder' not supported. + "list { tensor [{}] }", # 'TensorProto' not supported (yet). + "list { func [{}] }", # 'func' not supported. + ]) # pyformat: disable + def testAttrValueToPyObjectError(self, pbtxt): + proto = attr_value_pb2.AttrValue() + text_format.Parse(pbtxt, proto) + with self.assertRaises((TypeError, ValueError)): + _op_def_util.SerializedAttrValueToPyObject(proto.SerializeToString()) + + if __name__ == "__main__": googletest.main() -