STT-tensorflow/tensorflow/python/framework/op_def_util.cc
Edward Loper 35e84682fa Narrow dependencies for op_def_util from safe_ptr.h to safe_pyobject_ptr.h
PiperOrigin-RevId: 336747454
Change-Id: Idef5c4b26cb5676611753d97d54c337323c9c4f3
2020-10-12 15:10:05 -07:00

382 lines
13 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/framework/op_def_util.h"
#include <map>
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"
using ::tensorflow::swig::GetRegisteredPyObject;
#if PY_MAJOR_VERSION < 3
// Python 2.x:
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
#define PY_INT_CHECK(x) (PyInt_Check(x))
#define PY_INT_TYPE PyInt_Type
#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
#else
// Python 3.x:
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
#define PY_INT_CHECK(x) (PyLong_Check(x))
#define PY_INT_TYPE PyLong_Type
#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
#endif
namespace tensorflow {
namespace {
// Returns the table mapping op-def attribute type names (e.g. "int",
// "list(shape)") to their AttributeType enum values.
//
// The table is built lazily on first call. It is allocated with `new` and never
// deleted, so it is not torn down during static destruction.
const std::map<std::string, AttributeType>* AttributeTypeNameMap() {
  static const std::map<std::string, AttributeType>* const kNameToType =
      new std::map<std::string, AttributeType>{
          // Scalar attribute types.
          {"any", AttributeType::ANY},
          {"float", AttributeType::FLOAT},
          {"int", AttributeType::INT},
          {"string", AttributeType::STRING},
          {"bool", AttributeType::BOOL},
          {"shape", AttributeType::SHAPE},
          {"type", AttributeType::DTYPE},
          {"tensor", AttributeType::TENSOR},
          // List attribute types.
          {"list(any)", AttributeType::LIST_ANY},
          {"list(float)", AttributeType::LIST_FLOAT},
          {"list(int)", AttributeType::LIST_INT},
          {"list(string)", AttributeType::LIST_STRING},
          {"list(bool)", AttributeType::LIST_BOOL},
          {"list(type)", AttributeType::LIST_DTYPE},
          {"list(shape)", AttributeType::LIST_SHAPE},
          {"list(tensor)", AttributeType::LIST_TENSOR},
      };
  return kNameToType;
}
// Note: we define functors for converting value types (rather than simple
// functions) so we can define a generic ConvertListAttr method. These
// functors all return a new reference on success, or nullptr on failure.
// They do not (necessarily) call PyErr_SetString.
// Accepts every value unchanged.
struct ConvertAnyFunctor {
  // Returns a new reference to `value` itself; never fails.
  Safe_PyObjectPtr operator()(PyObject* value) {
    Safe_PyObjectPtr same(value);
    Py_INCREF(same.get());
    return same;
  }
};
// Converts `value` to a Python float.
struct ConvertFloatFunctor {
  // Floats are returned as-is. Strings are rejected (nullptr) even if `float()`
  // could parse them. Anything else is passed to `float(value)`, and a failure
  // of that call yields nullptr.
  Safe_PyObjectPtr operator()(PyObject* value) {
    if (PyFloat_Check(value)) {
      Py_INCREF(value);
      return Safe_PyObjectPtr(value);
    }
    if (PY_STRING_CHECK(value)) return nullptr;
    PyObject* float_type = reinterpret_cast<PyObject*>(&PyFloat_Type);
    return Safe_PyObjectPtr(
        PyObject_CallFunctionObjArgs(float_type, value, nullptr));
  }
};
// Converts `value` to a Python int (`long` on Python 3, `int` on Python 2).
struct ConvertIntFunctor {
  // Ints are returned as-is. Strings are rejected (nullptr). Anything else is
  // passed to the int type's constructor, and a failure of that call yields
  // nullptr.
  Safe_PyObjectPtr operator()(PyObject* value) {
    if (PY_INT_CHECK(value)) {
      Py_INCREF(value);
      return Safe_PyObjectPtr(value);
    }
    if (PY_STRING_CHECK(value)) return nullptr;
    PyObject* int_type = reinterpret_cast<PyObject*>(&PY_INT_TYPE);
    return Safe_PyObjectPtr(
        PyObject_CallFunctionObjArgs(int_type, value, nullptr));
  }
};
// Accepts only string-like values (bytes or unicode on Python 3; str or
// unicode on Python 2). No conversion is attempted.
struct ConvertStringFunctor {
  // Returns a new reference to `value` if it is string-like, else nullptr.
  Safe_PyObjectPtr operator()(PyObject* value) {
    if (!PY_STRING_CHECK(value)) return nullptr;
    Py_INCREF(value);
    return Safe_PyObjectPtr(value);
  }
};
// TODO(edloper): Should we allow ints (or any other values) to be converted
// to booleans? Currently, TensorFlow does not do this conversion for attribute
// values in _MakeBool or make_bool.
//
// Accepts only Python bools; no other value is coerced.
struct ConvertBoolFunctor {
  // Returns a new reference to `value` if it is a bool, else nullptr.
  Safe_PyObjectPtr operator()(PyObject* value) {
    if (!PyBool_Check(value)) return nullptr;
    Py_INCREF(value);
    return Safe_PyObjectPtr(value);
  }
};
// Converts `value` to a `tf.dtypes.DType`.
struct ConvertDTypeFunctor {
  // Values whose exact type is DType are returned as-is. Everything else is
  // passed to `tf.dtypes.as_dtype(value)`, whose result (or nullptr on error)
  // is returned.
  Safe_PyObjectPtr operator()(PyObject* value) {
    // Both objects are registered in op_def_library.py. Being function-local
    // statics, they are looked up only on the first call.
    static PyObject* dtype = GetRegisteredPyObject("tf.dtypes.DType");
    static PyObject* as_dtype = GetRegisteredPyObject("tf.dtypes.as_dtype");
    const bool is_exact_dtype =
        reinterpret_cast<PyObject*>(value->ob_type) == dtype;
    if (!is_exact_dtype) {
      return Safe_PyObjectPtr(
          PyObject_CallFunctionObjArgs(as_dtype, value, nullptr));
    }
    Py_INCREF(value);
    return Safe_PyObjectPtr(value);
  }
};
// Converts `value` to a `tf.TensorShape`.
struct ConvertTensorShapeFunctor {
  // Values whose exact type is TensorShape are returned as-is. Everything else
  // is passed to `tf.as_shape(value)`, whose result (or nullptr on error) is
  // returned.
  Safe_PyObjectPtr operator()(PyObject* value) {
    // Both objects are registered in op_def_library.py. Being function-local
    // statics, they are looked up only on the first call.
    static PyObject* shape = GetRegisteredPyObject("tf.TensorShape");
    static PyObject* as_shape = GetRegisteredPyObject("tf.as_shape");
    const bool is_exact_shape =
        reinterpret_cast<PyObject*>(value->ob_type) == shape;
    if (!is_exact_shape) {
      return Safe_PyObjectPtr(
          PyObject_CallFunctionObjArgs(as_shape, value, nullptr));
    }
    Py_INCREF(value);
    return Safe_PyObjectPtr(value);
  }
};
// Converts `value` to a `tf.TensorProto` message.
struct ConvertTensorProtoFunctor {
  // Values whose exact type is TensorProto are returned as-is. Strings are
  // parsed as text-format TensorProtos into a fresh message. Any other value,
  // or a failure to construct or parse the message, yields nullptr.
  Safe_PyObjectPtr operator()(PyObject* value) {
    // Both objects are registered in op_def_library.py. Being function-local
    // statics, they are looked up only on the first call.
    static PyObject* tensor_proto = GetRegisteredPyObject("tf.TensorProto");
    static PyObject* text_format_parse =
        GetRegisteredPyObject("text_format.Parse");
    if (reinterpret_cast<PyObject*>(value->ob_type) == tensor_proto) {
      Py_INCREF(value);
      return Safe_PyObjectPtr(value);
    }
    if (!PY_STRING_CHECK(value)) return nullptr;

    Safe_PyObjectPtr result(PyObject_CallObject(tensor_proto, nullptr));
    if (!result) return nullptr;
    // text_format.Parse(text, message) is called for its effect on `result`.
    // The object it returns is a new reference, so it is held in a smart
    // pointer and released here (it was previously leaked).
    Safe_PyObjectPtr parse_result(PyObject_CallFunctionObjArgs(
        text_format_parse, value, result.get(), nullptr));
    if (!parse_result) return nullptr;
    return result;
  }
};
// Converts `value` to a list of elements with the same type, using
// `convert_functor` to convert each element.
//
// `value` may be any sequence. It is first copied into a new Python list, so
// the caller's object is never modified. Returns a new reference to that list,
// or nullptr if `value` is not a sequence, any element fails to convert, or an
// element cannot be stored.
template <typename T>
Safe_PyObjectPtr ConvertListAttr(PyObject* value, T convert_functor) {
  Safe_PyObjectPtr result(PySequence_List(value));
  if (!result) return nullptr;

  const Py_ssize_t len = PySequence_Fast_GET_SIZE(result.get());
  // Replacing elements never resizes the list, so `items` stays valid.
  PyObject** items = PySequence_Fast_ITEMS(result.get());
  for (Py_ssize_t i = 0; i < len; ++i) {
    // Every element is converted. The old guard `!PyFloat_Check(value)` tested
    // the outer sequence instead of the element and was effectively always
    // true, so it has been removed.
    Safe_PyObjectPtr item = convert_functor(items[i]);
    if (!item) return nullptr;
    // PySequence_SetItem takes its own reference to `item` and releases the
    // element it replaces.
    if (PySequence_SetItem(result.get(), i, item.get()) < 0) return nullptr;
  }
  return result;
}
// Converts `value` to the Python representation expected for an attribute of
// type `attr_type`.
//
// Scalar types are handled by the matching Convert*Functor. List types apply
// that functor to every element via ConvertListAttr. Returns nullptr if `value`
// is not convertible or `attr_type` is not handled. A Python error may or may
// not be set in that case.
Safe_PyObjectPtr ConvertAttrOrNull(PyObject* value, AttributeType attr_type) {
  switch (attr_type) {
    // Scalar attribute types.
    case AttributeType::ANY: return ConvertAnyFunctor{}(value);
    case AttributeType::FLOAT: return ConvertFloatFunctor{}(value);
    case AttributeType::INT: return ConvertIntFunctor{}(value);
    case AttributeType::STRING: return ConvertStringFunctor{}(value);
    case AttributeType::BOOL: return ConvertBoolFunctor{}(value);
    case AttributeType::DTYPE: return ConvertDTypeFunctor{}(value);
    case AttributeType::SHAPE: return ConvertTensorShapeFunctor{}(value);
    case AttributeType::TENSOR: return ConvertTensorProtoFunctor{}(value);

    // List attribute types: each element is converted individually.
    case AttributeType::LIST_ANY:
      return ConvertListAttr(value, ConvertAnyFunctor{});
    case AttributeType::LIST_FLOAT:
      return ConvertListAttr(value, ConvertFloatFunctor{});
    case AttributeType::LIST_INT:
      return ConvertListAttr(value, ConvertIntFunctor{});
    case AttributeType::LIST_STRING:
      return ConvertListAttr(value, ConvertStringFunctor{});
    case AttributeType::LIST_BOOL:
      return ConvertListAttr(value, ConvertBoolFunctor{});
    case AttributeType::LIST_DTYPE:
      return ConvertListAttr(value, ConvertDTypeFunctor{});
    case AttributeType::LIST_SHAPE:
      return ConvertListAttr(value, ConvertTensorShapeFunctor{});
    case AttributeType::LIST_TENSOR:
      return ConvertListAttr(value, ConvertTensorProtoFunctor{});

    default:
      break;
  }
  // Unhandled attribute type (e.g. UNKNOWN).
  return nullptr;
}
// Returns a new reference to Py_True when `b` is true, or to Py_False
// otherwise. Never fails.
PyObject* PyBool_FromBool(bool b) {
  if (b) {
    Py_INCREF(Py_True);
    return Py_True;
  }
  Py_INCREF(Py_False);
  return Py_False;
}
// Builds a new Python list of `size` elements. Element `i` is the new reference
// returned by `make_item(i)`, and ownership of it passes to the list.
//
// Returns nullptr if the list or any element could not be created. Any Python
// error raised by the failing call is left set.
template <typename MakeItem>
Safe_PyObjectPtr NewPyListFromItems(int size, MakeItem make_item) {
  Safe_PyObjectPtr result(PyList_New(size));
  if (!result) return nullptr;
  for (int i = 0; i < size; ++i) {
    PyObject* item = make_item(i);
    // Slots not yet filled are still NULL. A list tolerates NULL slots when it
    // is released, so bailing out here is safe.
    if (item == nullptr) return nullptr;
    PyList_SET_ITEM(result.get(), i, item);  // Steals the reference to `item`.
  }
  return result;
}

// Converts an AttrValue::ListValue to a Python list.
//
// Only one repeated field is converted: the first non-empty one, checked in
// the order s, i, f, b, type, shape. An entirely empty ListValue yields `[]`.
// Lists of tensors or funcs are unsupported and raise TypeError.
//
// Returns a new reference, or nullptr on failure. Previously, failed element
// conversions either stored NULL in the list or crashed in Py_INCREF.
Safe_PyObjectPtr AttrValueListToPyObject(const AttrValue::ListValue& list) {
  if (list.s_size()) {
    // NOTE(review): c_str() stops at the first NUL byte, so strings with
    // embedded NULs are truncated.
    return NewPyListFromItems(list.s_size(), [&list](int i) {
      return PY_STRING_FROMSTRING(list.s(i).c_str());
    });
  }
  if (list.i_size()) {
    return NewPyListFromItems(list.i_size(), [&list](int i) {
      return PY_INT_FROM_LONG(list.i(i));
    });
  }
  if (list.f_size()) {
    return NewPyListFromItems(list.f_size(), [&list](int i) {
      return PyFloat_FromDouble(list.f(i));
    });
  }
  if (list.b_size()) {
    return NewPyListFromItems(list.b_size(), [&list](int i) {
      return PyBool_FromBool(list.b(i));
    });
  }
  if (list.type_size()) {
    // DataTypeToPyObject returns an owning pointer. release() hands its
    // reference to the list (or yields nullptr on failure).
    return NewPyListFromItems(list.type_size(), [&list](int i) {
      return DataTypeToPyObject(list.type(i)).release();
    });
  }
  if (list.shape_size()) {
    return NewPyListFromItems(list.shape_size(), [&list](int i) {
      return TensorShapeProtoToPyObject(list.shape(i)).release();
    });
  }
  if (list.tensor_size() || list.func_size()) {
    // TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
    PyErr_SetString(PyExc_TypeError, "Unsupported AttrValue type");
    return nullptr;
  }
  // Empty list.
  return Safe_PyObjectPtr(PyList_New(0));
}
} // namespace
// Looks up the AttributeType for an op-def attribute type name such as "int"
// or "list(shape)". Returns AttributeType::UNKNOWN for unrecognized names.
AttributeType AttributeTypeFromName(const std::string& type_name) {
  const std::map<std::string, AttributeType>& names = *AttributeTypeNameMap();
  const auto found = names.find(type_name);
  if (found == names.end()) return AttributeType::UNKNOWN;
  return found->second;
}
// Returns the op-def type name for `attr_type` (the inverse of
// AttributeTypeFromName), or "<unknown>" if no name maps to it.
// This is a linear scan over the name table, which visits names in key order.
std::string AttributeTypeToName(AttributeType attr_type) {
  const auto* names = AttributeTypeNameMap();
  for (auto it = names->begin(); it != names->end(); ++it) {
    if (it->second == attr_type) return it->first;
  }
  return "<unknown>";
}
// Converts `value` to the Python representation expected for an attribute of
// type `type`.
//
// On failure, sets a TypeError naming the source Python type and the target
// attribute type and returns nullptr. That TypeError replaces any error a
// converter may already have set.
Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
                                                AttributeType type) {
  Safe_PyObjectPtr converted = ConvertAttrOrNull(value, type);
  if (converted) return converted;

  const std::string message = absl::StrCat(
      "Failed to convert value of type '", value->ob_type->tp_name,
      "' to type '", AttributeTypeToName(type), "'.");
  PyErr_SetString(PyExc_TypeError, message.c_str());
  return converted;  // Still null.
}
// Converts an AttrValue to the equivalent Python object.
//
// Handles strings, ints, floats, bools, dtypes, shapes and lists. Any other
// value case (e.g. tensors, per the TODO below) sets a ValueError and returns
// nullptr.
Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value) {
  // Scalar cases only create a raw new reference. It is wrapped once below.
  PyObject* scalar = nullptr;
  switch (attr_value.value_case()) {
    case tensorflow::AttrValue::kS:
      scalar = PY_STRING_FROMSTRING(attr_value.s().c_str());
      break;
    case tensorflow::AttrValue::kI:
      scalar = PY_INT_FROM_LONG(attr_value.i());
      break;
    case tensorflow::AttrValue::kF:
      scalar = PyFloat_FromDouble(attr_value.f());
      break;
    case tensorflow::AttrValue::kB:
      scalar = PyBool_FromBool(attr_value.b());
      break;
    // Structured cases delegate to helpers that already return owning pointers.
    case tensorflow::AttrValue::kType:
      return DataTypeToPyObject(attr_value.type());
    case tensorflow::AttrValue::kShape:
      return TensorShapeProtoToPyObject(attr_value.shape());
    case tensorflow::AttrValue::kList:
      return AttrValueListToPyObject(attr_value.list());
    default:
      // TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
      PyErr_SetString(PyExc_ValueError, "Unsupported AttrValue type");
      return nullptr;
  }
  return Safe_PyObjectPtr(scalar);
}
// Converts a DataType enum value to a Python `tf.dtypes.DType` by passing its
// integer value through ConvertDTypeFunctor (i.e. `tf.dtypes.as_dtype`).
// Returns nullptr on failure.
Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type) {
  Safe_PyObjectPtr enum_value(PY_INT_FROM_LONG(data_type));
  // ConvertDTypeFunctor dereferences its argument, so a failed int creation
  // must stop here rather than pass nullptr along.
  if (!enum_value) return nullptr;
  return ConvertDTypeFunctor()(enum_value.get());
}
// Converts a TensorShapeProto to a Python `tf.TensorShape` via
// ConvertTensorShapeFunctor (i.e. `tf.as_shape`).
//
// An unknown-rank proto becomes `as_shape(None)`. Otherwise a tuple of
// dimensions is built, with unknown dimensions (encoded as size -1 in the
// proto) mapped to None. Returns nullptr on failure.
Safe_PyObjectPtr TensorShapeProtoToPyObject(
    const TensorShapeProto& tensor_shape) {
  if (tensor_shape.unknown_rank()) {
    return ConvertTensorShapeFunctor()(Py_None);
  }

  const int rank = tensor_shape.dim_size();
  Safe_PyObjectPtr dims(PyTuple_New(rank));
  if (!dims) return nullptr;
  for (int i = 0; i < rank; ++i) {
    const auto size = tensor_shape.dim(i).size();
    PyObject* dim = nullptr;
    if (size == -1) {
      // Unknown dimension: Python TensorShape represents this as None rather
      // than the proto's -1 sentinel.
      Py_INCREF(Py_None);
      dim = Py_None;
    } else {
      dim = PY_INT_FROM_LONG(size);
      // Unfilled tuple slots are NULL, which is safe when `dims` is released.
      if (dim == nullptr) return nullptr;
    }
    PyTuple_SET_ITEM(dims.get(), i, dim);  // Steals the reference to `dim`.
  }
  return ConvertTensorShapeFunctor()(dims.get());
}
} // namespace tensorflow