STT-tensorflow/tensorflow/python/framework/python_api_info.cc
Edward Loper 30b69242f8 Class to hold precomputed information about a TensorFlow Python API, to allow the API to be executed rapidly.
PiperOrigin-RevId: 341499017
Change-Id: I875ea89efcd86a7fe9e2f8fcefab1cbd3aa2c0e9
2020-11-09 15:53:35 -08:00

509 lines
18 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/framework/python_api_info.h"
#include <Python.h>
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/framework/op_def_util.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"
namespace tensorflow {
#if PY_MAJOR_VERSION < 3
// Python 2.x:
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#define PY_STRING_AS_CSTR(x) (PyString_AsString(x))
#else
// Python 3.x:
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#define PY_STRING_AS_CSTR(x) (PyUnicode_AsUTF8AndSize((x), nullptr))
#endif
namespace {
// Converts the given object to an interned Python string, and returns its
// data pointer. (This means we don't need to worry about ownership for
// this string.)
const char* InternPyString(const std::string& s) {
Safe_PyObjectPtr interned(PY_STRING_INTERN_FROM_STRING(s.c_str()));
return PY_STRING_AS_CSTR(interned.get());
}
template <typename T, typename UnaryPredicate>
void RemoveIf(UnaryPredicate p, std::vector<T>* vec) {
vec->erase(std::remove_if(vec->begin(), vec->end(), p), vec->end());
}
struct DataTypeFormatter {
void operator()(std::string* out, DataType dtype) const {
out->append(DataType_Name(dtype));
}
};
// Populates `param_names` and `defaults_tuple` based on the given OpDef.
void GetOpDefNamesAndDefaults(const tensorflow::OpDef& op_def,
std::vector<string>& param_names,
Safe_PyObjectPtr& defaults_tuple) {
param_names.reserve(op_def.input_arg_size() + op_def.attr_size());
std::set<std::string> inferred_attrs;
// Input parameters come first, in the order they occur in the OpDef.
for (const auto& input : op_def.input_arg()) {
param_names.push_back(input.name());
if (!input.type_attr().empty()) {
inferred_attrs.insert(input.type_attr());
}
if (!input.type_list_attr().empty()) {
inferred_attrs.insert(input.type_list_attr());
}
if (!input.number_attr().empty()) {
inferred_attrs.insert(input.number_attr());
}
}
// Next come attribute params without defaults, followed by attributes with
// defaults (but inferred attributes are not included).
std::vector<std::string> param_names_with_default;
std::vector<Safe_PyObjectPtr> defaults;
for (const auto& attr : op_def.attr()) {
if (inferred_attrs.count(attr.name()) == 0) {
if (attr.has_default_value()) {
param_names_with_default.push_back(attr.name());
defaults.push_back(AttrValueToPyObject(attr.default_value()));
} else {
param_names.push_back(attr.name());
}
}
}
param_names.insert(param_names.end(), param_names_with_default.begin(),
param_names_with_default.end());
// Finally, the 'name' parameter comes at the end, and its default value
// is the operation's name.
param_names.push_back("name");
defaults.emplace_back(PY_STRING_FROMSTRING(op_def.name().c_str()));
defaults_tuple.reset(PyTuple_New(defaults.size()));
for (int i = 0; i < defaults.size(); ++i) {
PyTuple_SET_ITEM(defaults_tuple.get(), i, defaults[i].release());
}
}
} // namespace
PythonAPIInfo::PythonAPIInfo(const std::string& api_name)
: api_name_(InternPyString(api_name)) {}
Status PythonAPIInfo::Initialize(const OpDef& op_def,
const std::vector<string> param_names,
PyObject* defaults_tuple) {
// Intern the parameter names.
param_names_.reserve(param_names.size());
for (const auto& param_name : param_names) {
param_names_.push_back(InternPyString(param_name));
}
Py_INCREF(defaults_tuple);
defaults_tuple_.reset(defaults_tuple);
// Build an index to look up parameter index by name. (Does not include
// inferred attributes.)
std::map<std::string, int> param_name_to_index;
for (int i = 0; i < param_names_.size(); ++i) {
param_name_to_index[param_names_[i]] = i;
}
// Initialize each attribute & input parameter.
attributes_.reserve(op_def.attr_size());
for (const auto& attr_def : op_def.attr()) {
TF_RETURN_IF_ERROR(InitializeAttribute(attr_def, param_name_to_index));
}
inputs_.reserve(op_def.input_arg_size());
for (const auto& arg_def : op_def.input_arg()) {
TF_RETURN_IF_ERROR(InitializeInput(arg_def, param_name_to_index));
}
TF_RETURN_IF_ERROR(CheckParamNames());
// Filter out any unused entries from inputs_with_*_attrs_.
RemoveIf(
[](const InputsWithTypeAttr& input) {
return input.tensor_params.empty() && input.tensor_list_params.empty();
},
&inputs_with_type_attrs_);
RemoveIf(
[](const InputsWithTypeListAttr& input) {
return input.tensor_list_params.empty();
},
&inputs_with_type_list_attrs_);
RemoveIf(
[](const InputsWithNumberAttr& input) {
return input.tensor_list_params.empty();
},
&inputs_with_number_attrs_);
return Status::OK();
}
Status PythonAPIInfo::CheckParamNames() const {
std::vector<bool> param_found(param_names_.size());
for (const auto& attr : attributes_) {
if (attr.index != -1) {
param_found[attr.index] = true;
}
}
for (const auto& input : inputs_) {
param_found[input.index] = true;
}
for (int i = 0; i < param_names_.size(); ++i) {
if (param_names_[i] == std::string("name")) {
continue;
}
if (!param_found[i]) {
return errors::InvalidArgument(
api_name_, ": missing specification for parameter ", param_names_[i]);
}
}
return Status::OK();
}
Status PythonAPIInfo::InitializeFromRegisteredOp(const std::string& op_name) {
const tensorflow::OpDef* op_def = nullptr;
TF_RETURN_IF_ERROR(
tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def));
std::vector<std::string> param_names;
Safe_PyObjectPtr defaults_tuple;
GetOpDefNamesAndDefaults(*op_def, param_names, defaults_tuple);
TF_RETURN_IF_ERROR(Initialize(*op_def, param_names, defaults_tuple.get()));
return Status::OK();
}
Status PythonAPIInfo::InitializeFromParamSpecs(
const std::map<std::string, std::string>& input_specs,
const std::map<std::string, std::string>& attr_specs,
const std::vector<string> param_names, PyObject* defaults_tuple) {
OpDefBuilder op_def_builder(api_name_);
op_def_builder.AllowAttrTypeAny();
for (const auto& attr_spec : attr_specs) {
op_def_builder.Attr(absl::StrCat(attr_spec.first, ": ", attr_spec.second));
}
for (const auto& input_spec : input_specs) {
op_def_builder.Input(
absl::StrCat(input_spec.first, ": ", input_spec.second));
}
OpRegistrationData op_reg_data;
TF_RETURN_IF_ERROR(op_def_builder.Finalize(&op_reg_data));
TF_RETURN_IF_ERROR(
Initialize(op_reg_data.op_def, param_names, defaults_tuple));
return Status::OK();
}
Status PythonAPIInfo::InitializeAttribute(
const OpDef::AttrDef& attr_def,
const std::map<std::string, int>& param_name_to_index) {
if (attr_def.name() == "name") {
return errors::InvalidArgument(
api_name_, ": Reserved parameter `name` was used as an attribute.");
}
const char* name = InternPyString(attr_def.name());
const int param_index =
gtl::FindWithDefault(param_name_to_index, attr_def.name(), -1);
const AttributeType dtype = AttributeTypeFromName(attr_def.type());
const int inferred_index = -1;
attributes_.push_back({param_index, dtype, name, inferred_index});
Attribute& attr = attributes_.back();
if (attr.type == AttributeType::UNKNOWN) {
return errors::InvalidArgument(api_name_, ": Bad attribute type for ",
attr_def.name(), ": '", attr_def.type(),
"'");
}
std::vector<DataType>* ok_dtypes = nullptr;
if (attr.type == AttributeType::DTYPE) {
DataType default_dtype = attr_def.has_default_value()
? attr_def.default_value().type()
: DT_INVALID;
inputs_with_type_attrs_.push_back({&attr, default_dtype});
ok_dtypes = &inputs_with_type_attrs_.back().ok_dtypes;
} else if (attr.type == AttributeType::LIST_DTYPE) {
inputs_with_type_list_attrs_.push_back({&attr});
for (int d : attr_def.default_value().list().type()) {
inputs_with_type_list_attrs_.back().default_dtypes.push_back(
static_cast<DataType>(d));
}
ok_dtypes = &inputs_with_type_list_attrs_.back().ok_dtypes;
}
if (attr_def.has_allowed_values() && ok_dtypes) {
const auto& dtypes = attr_def.allowed_values().list();
for (int i = 0; i < dtypes.type_size(); ++i) {
ok_dtypes->push_back(dtypes.type(i));
}
}
if (attr.type == AttributeType::INT) {
int64 default_len =
attr_def.has_default_value() ? attr_def.default_value().i() : -1;
inputs_with_number_attrs_.push_back({&attr, default_len});
}
// If this is an inferred attribute, then record its name and index.
if (attr.index == -1) {
std::vector<const char*>* inferred_attr_names =
attr.type == AttributeType::DTYPE ? &inferred_type_attrs_
: attr.type == AttributeType::LIST_DTYPE ? &inferred_type_list_attrs_
: attr.type == AttributeType::INT ? &inferred_length_attrs_
: nullptr;
if (inferred_attr_names == nullptr) {
return errors::InvalidArgument(
api_name_, ": Missing specification for parameter ", attr_def.name());
} else {
attr.inferred_index = inferred_attr_names->size();
inferred_attr_names->push_back(attr.name);
}
}
return Status::OK();
}
Status PythonAPIInfo::InitializeInput(
const OpDef::ArgDef& arg_def,
const std::map<std::string, ParamIndex>& param_name_to_index) {
if (arg_def.name() == "name") {
return errors::InvalidArgument(
api_name_, ": Reserved parameter `name` was used as a tensor input.");
}
const ParamIndex param_index =
gtl::FindWithDefault(param_name_to_index, arg_def.name(), -1);
if (param_index == -1) {
return errors::InvalidArgument(
api_name_, ": Missing specification for parameter ", arg_def.name());
}
if (arg_def.is_ref()) {
// TODO(b/164980194): Support reference parameters.
// - Pass as_ref to convert_to_tensor
// - Check that values for ref inputs have ref types.
return errors::InvalidArgument(api_name_,
": PythonAPIInfo doesn't support reference "
"parameters yet.");
}
bool is_list =
!arg_def.number_attr().empty() || !arg_def.type_list_attr().empty();
inputs_.push_back({param_index, is_list});
if (!arg_def.type_list_attr().empty()) {
// list(input) with dtypes specified by a `list(type)` attribute.
InputsWithTypeListAttr* input =
FindInputsWithTypeListAttr(arg_def.type_list_attr());
if (!input) {
return errors::InvalidArgument(
api_name_, ": Type attribute ", arg_def.type_list_attr(),
" for parameter ", arg_def.name(), " not found.");
}
input->tensor_list_params.push_back(param_index);
} else if (!arg_def.type_attr().empty()) {
InputsWithTypeAttr* input = FindInputsWithTypeAttr(arg_def.type_attr());
// input or list(input) with dtype specified by a `type` attribute.
if (!input) {
return errors::InvalidArgument(api_name_, ": Type attribute ",
arg_def.type_attr(), " for parameter ",
arg_def.name(), " not found.");
}
if (arg_def.number_attr().empty()) {
input->tensor_params.push_back(param_index);
} else {
input->tensor_list_params.push_back(param_index);
}
} else {
// input or list(input) with fixed dtype
inputs_with_fixed_dtype_.push_back({arg_def.type(), param_index, is_list});
}
if (!arg_def.number_attr().empty()) {
InputsWithNumberAttr* input =
FindInputsWithNumberAttr(arg_def.number_attr());
if (!input) {
return errors::InvalidArgument(api_name_, ": Length attribute ",
arg_def.number_attr(), " for parameter ",
arg_def.name(), " not found.");
}
input->tensor_list_params.push_back(param_index);
}
return Status::OK();
}
PythonAPIInfo::InputsWithTypeAttr* PythonAPIInfo::FindInputsWithTypeAttr(
const string& name) {
for (auto& input : inputs_with_type_attrs_) {
if (name == input.type_attr->name) {
return &input;
}
}
return nullptr;
}
PythonAPIInfo::InputsWithTypeListAttr*
PythonAPIInfo::FindInputsWithTypeListAttr(const string& name) {
for (auto& input : inputs_with_type_list_attrs_) {
if (name == input.type_list_attr->name) {
return &input;
}
}
return nullptr;
}
PythonAPIInfo::InputsWithNumberAttr* PythonAPIInfo::FindInputsWithNumberAttr(
const string& name) {
for (auto& input : inputs_with_number_attrs_) {
if (name == input.number_attr->name) {
return &input;
}
}
return nullptr;
}
string PythonAPIInfo::DebugInfo() const {
string s = absl::StrCat("DebugInfo for ", api_name_, ":\n");
absl::StrAppend(&s, " param_names=[", absl::StrJoin(param_names_, ", "),
"]\n");
Safe_PyObjectPtr defaults_repr(PyObject_Repr(defaults_tuple_.get()));
absl::StrAppend(
&s, " defaults_tuple=", TFE_GetPythonString(defaults_repr.get()), "\n");
if (!attributes_.empty()) {
absl::StrAppend(&s, " attributes=[");
for (const auto& attrib : attributes_) {
if (attrib.index != -1) {
absl::StrAppend(&s, "\n {index=", attrib.index);
DCHECK_EQ(attrib.inferred_index, -1);
} else {
absl::StrAppend(&s, "\n {inferred_index=", attrib.inferred_index);
}
absl::StrAppend(&s, ", name=", attrib.name,
", type=", AttributeTypeToName(attrib.type), "},");
}
absl::StrAppend(&s, "]\n");
}
if (!inputs_.empty()) {
absl::StrAppend(&s, " inputs=[");
for (const auto& input : inputs_) {
absl::StrAppend(&s, "\n {index=", input.index,
", name=", param_names_[input.index],
", is_list=", input.is_list, "},");
}
absl::StrAppend(&s, "]\n");
}
if (!inputs_with_fixed_dtype_.empty()) {
absl::StrAppend(&s, " inputs_with_fixed_dtype=[");
for (const auto& input : inputs_with_fixed_dtype_) {
absl::StrAppend(&s, "\n {index=", input.index,
", dtype=", DataType_Name(input.dtype),
", is_list=", input.is_list, "},");
}
absl::StrAppend(&s, "]\n");
}
if (!inputs_with_type_attrs_.empty()) {
absl::StrAppend(&s, " inputs_with_type_attr=[");
for (const auto& input : inputs_with_type_attrs_) {
absl::StrAppend(&s, "\n {type_attr=", input.type_attr->name);
if (input.default_dtype != DT_INVALID) {
absl::StrAppend(&s,
", default_dtype=", DataType_Name(input.default_dtype));
}
if (!input.tensor_params.empty()) {
absl::StrAppend(&s, ", tensor_params=[",
absl::StrJoin(input.tensor_params, ", "), "]");
}
if (!input.tensor_list_params.empty()) {
absl::StrAppend(&s, ", tensor_list_params=[",
absl::StrJoin(input.tensor_list_params, ", "), "]");
}
if (!input.ok_dtypes.empty()) {
absl::StrAppend(
&s, ", ok_dtypes=[",
absl::StrJoin(input.ok_dtypes, ", ", DataTypeFormatter()), "]");
}
absl::StrAppend(&s, "},");
}
absl::StrAppend(&s, "]\n");
}
if (!inputs_with_type_list_attrs_.empty()) {
absl::StrAppend(&s, " inputs_with_type_list_attrs=[");
for (const auto& input : inputs_with_type_list_attrs_) {
absl::StrAppend(&s, "\n {type_list_attr=", input.type_list_attr->name);
if (!input.default_dtypes.empty()) {
absl::StrAppend(
&s, ", default_dtypes=[",
absl::StrJoin(input.default_dtypes, ", ", DataTypeFormatter()),
"]");
}
if (!input.tensor_list_params.empty()) {
absl::StrAppend(&s, ", tensor_list_params=[",
absl::StrJoin(input.tensor_list_params, ", "), "]");
}
if (!input.ok_dtypes.empty()) {
absl::StrAppend(
&s, ", ok_dtypes=[",
absl::StrJoin(input.ok_dtypes, ", ", DataTypeFormatter()), "]");
}
absl::StrAppend(&s, "},");
}
absl::StrAppend(&s, "]\n");
}
if (!inputs_with_number_attrs_.empty()) {
absl::StrAppend(&s, " inputs_with_number_attrs=[");
for (const auto& input : inputs_with_number_attrs_) {
absl::StrAppend(&s, "\n {number_attr=", input.number_attr->name,
", default_length=", input.default_length,
", tensor_list_params=[",
absl::StrJoin(input.tensor_list_params, ", "), "],\n");
}
absl::StrAppend(&s, "]\n");
}
if (!inferred_type_attrs_.empty()) {
absl::StrAppend(&s, " inferred_type_attrs=[",
absl::StrJoin(inferred_type_attrs_, ", "), "]\n");
}
if (!inferred_type_list_attrs_.empty()) {
absl::StrAppend(&s, " inferred_type_list_attrs=[",
absl::StrJoin(inferred_type_list_attrs_, ", "), "]\n");
}
if (!inferred_length_attrs_.empty()) {
absl::StrAppend(&s, " inferred_length_attrs=[",
absl::StrJoin(inferred_length_attrs_, ", "), "]\n");
}
return s;
}
} // namespace tensorflow