Class to hold precomputed information about a TensorFlow Python API, to allow the API to be executed rapidly.
PiperOrigin-RevId: 341499017 Change-Id: I875ea89efcd86a7fe9e2f8fcefab1cbd3aa2c0e9
This commit is contained in:
parent
47388e6e56
commit
30b69242f8
@ -1371,6 +1371,7 @@ py_library(
|
||||
":_pywrap_py_exception_registry",
|
||||
":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
|
||||
":_pywrap_python_api_dispatcher",
|
||||
":_pywrap_python_api_info",
|
||||
":_pywrap_python_op_gen",
|
||||
":_pywrap_quantize_training",
|
||||
":_pywrap_stacktrace_handler",
|
||||
@ -1696,6 +1697,7 @@ cc_library(
|
||||
":cpp_python_util",
|
||||
":safe_pyobject_ptr",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -1739,6 +1741,86 @@ tf_py_test(
|
||||
tags = ["no_pip"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "python_api_info",
|
||||
srcs = ["framework/python_api_info.cc"],
|
||||
hdrs = ["framework/python_api_info.h"],
|
||||
deps = [
|
||||
":cpp_python_util",
|
||||
":op_def_util_cc",
|
||||
":python_tensor_converter",
|
||||
":safe_pyobject_ptr",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/python/eager:pywrap_tfe_lib",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
# Note: this target is only used by python_api_info_test.
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_python_api_info",
|
||||
srcs = ["framework/python_api_info_wrapper.cc"],
|
||||
hdrs = [
|
||||
"framework/op_def_util.h",
|
||||
"framework/python_api_info.h",
|
||||
"framework/python_tensor_converter.h",
|
||||
"lib/core/numpy.h",
|
||||
"//tensorflow/c:headers",
|
||||
"//tensorflow/c/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/c/experimental/ops:pywrap_required_hdrs",
|
||||
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/python/eager:pywrap_required_hdrs",
|
||||
],
|
||||
module_name = "_pywrap_python_api_info",
|
||||
deps = [
|
||||
":safe_pyobject_ptr_required_hdrs",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@pybind11",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/common_runtime:core_cpu_headers_lib",
|
||||
"//tensorflow/core:lib_headers_for_pybind",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//tensorflow/c:pywrap_required_hdrs",
|
||||
"@com_google_absl//absl/types:span",
|
||||
] + if_static(
|
||||
extra_deps = [
|
||||
"//tensorflow/core/protobuf:eager_service_proto_cc",
|
||||
"//tensorflow/core/protobuf:master_proto_cc",
|
||||
"//tensorflow/core/protobuf:worker_proto_cc",
|
||||
],
|
||||
otherwise = [
|
||||
"//tensorflow/core/protobuf:eager_service_proto_cc_headers_only",
|
||||
"//tensorflow/core/protobuf:master_proto_cc_headers_only",
|
||||
"//tensorflow/core/protobuf:worker_proto_cc_headers_only",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "python_api_info_test",
|
||||
srcs = ["framework/python_api_info_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":_pywrap_python_api_info",
|
||||
":_pywrap_python_tensor_converter",
|
||||
":client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "python_api_dispatcher",
|
||||
srcs = ["framework/python_api_dispatcher.cc"],
|
||||
@ -6109,6 +6191,7 @@ pywrap_tensorflow_macro(
|
||||
":pybind11_status",
|
||||
":pybind11_proto",
|
||||
":python_api_dispatcher",
|
||||
":python_api_info",
|
||||
":python_op_gen",
|
||||
":python_tensor_converter",
|
||||
":safe_pyobject_ptr",
|
||||
@ -6181,6 +6264,7 @@ filegroup(
|
||||
":py_exception_registry", # py_exception_registry
|
||||
":py_func_lib", # py_func
|
||||
":python_api_dispatcher", # python_api_dispatcher
|
||||
":python_api_info", # python_api_info
|
||||
":python_tensor_converter", # python_tensor_converter
|
||||
":python_op_gen", # python_op_gen
|
||||
":safe_ptr", # checkpoint_reader
|
||||
|
508
tensorflow/python/framework/python_api_info.cc
Normal file
508
tensorflow/python/framework/python_api_info.cc
Normal file
@ -0,0 +1,508 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/python/framework/python_api_info.h"
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/python/eager/pywrap_tensor.h"
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
#include "tensorflow/python/framework/op_def_util.h"
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
#include "tensorflow/python/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
// Python 2.x:
|
||||
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
|
||||
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
|
||||
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
|
||||
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
|
||||
#define PY_STRING_AS_CSTR(x) (PyString_AsString(x))
|
||||
#else
|
||||
// Python 3.x:
|
||||
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
|
||||
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
|
||||
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
|
||||
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
|
||||
#define PY_STRING_AS_CSTR(x) (PyUnicode_AsUTF8AndSize((x), nullptr))
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// Converts the given object to an interned Python string, and returns its
|
||||
// data pointer. (This means we don't need to worry about ownership for
|
||||
// this string.)
|
||||
const char* InternPyString(const std::string& s) {
|
||||
Safe_PyObjectPtr interned(PY_STRING_INTERN_FROM_STRING(s.c_str()));
|
||||
return PY_STRING_AS_CSTR(interned.get());
|
||||
}
|
||||
|
||||
template <typename T, typename UnaryPredicate>
|
||||
void RemoveIf(UnaryPredicate p, std::vector<T>* vec) {
|
||||
vec->erase(std::remove_if(vec->begin(), vec->end(), p), vec->end());
|
||||
}
|
||||
|
||||
struct DataTypeFormatter {
|
||||
void operator()(std::string* out, DataType dtype) const {
|
||||
out->append(DataType_Name(dtype));
|
||||
}
|
||||
};
|
||||
|
||||
// Populates `param_names` and `defaults_tuple` based on the given OpDef.
|
||||
void GetOpDefNamesAndDefaults(const tensorflow::OpDef& op_def,
|
||||
std::vector<string>& param_names,
|
||||
Safe_PyObjectPtr& defaults_tuple) {
|
||||
param_names.reserve(op_def.input_arg_size() + op_def.attr_size());
|
||||
std::set<std::string> inferred_attrs;
|
||||
|
||||
// Input parameters come first, in the order they occur in the OpDef.
|
||||
for (const auto& input : op_def.input_arg()) {
|
||||
param_names.push_back(input.name());
|
||||
if (!input.type_attr().empty()) {
|
||||
inferred_attrs.insert(input.type_attr());
|
||||
}
|
||||
if (!input.type_list_attr().empty()) {
|
||||
inferred_attrs.insert(input.type_list_attr());
|
||||
}
|
||||
if (!input.number_attr().empty()) {
|
||||
inferred_attrs.insert(input.number_attr());
|
||||
}
|
||||
}
|
||||
|
||||
// Next come attribute params without defaults, followed by attributes with
|
||||
// defaults (but inferred attributes are not included).
|
||||
std::vector<std::string> param_names_with_default;
|
||||
std::vector<Safe_PyObjectPtr> defaults;
|
||||
for (const auto& attr : op_def.attr()) {
|
||||
if (inferred_attrs.count(attr.name()) == 0) {
|
||||
if (attr.has_default_value()) {
|
||||
param_names_with_default.push_back(attr.name());
|
||||
defaults.push_back(AttrValueToPyObject(attr.default_value()));
|
||||
} else {
|
||||
param_names.push_back(attr.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
param_names.insert(param_names.end(), param_names_with_default.begin(),
|
||||
param_names_with_default.end());
|
||||
|
||||
// Finally, the 'name' parameter comes at the end, and its default value
|
||||
// is the operation's name.
|
||||
param_names.push_back("name");
|
||||
defaults.emplace_back(PY_STRING_FROMSTRING(op_def.name().c_str()));
|
||||
|
||||
defaults_tuple.reset(PyTuple_New(defaults.size()));
|
||||
for (int i = 0; i < defaults.size(); ++i) {
|
||||
PyTuple_SET_ITEM(defaults_tuple.get(), i, defaults[i].release());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PythonAPIInfo::PythonAPIInfo(const std::string& api_name)
|
||||
: api_name_(InternPyString(api_name)) {}
|
||||
|
||||
Status PythonAPIInfo::Initialize(const OpDef& op_def,
|
||||
const std::vector<string> param_names,
|
||||
PyObject* defaults_tuple) {
|
||||
// Intern the parameter names.
|
||||
param_names_.reserve(param_names.size());
|
||||
for (const auto& param_name : param_names) {
|
||||
param_names_.push_back(InternPyString(param_name));
|
||||
}
|
||||
|
||||
Py_INCREF(defaults_tuple);
|
||||
defaults_tuple_.reset(defaults_tuple);
|
||||
|
||||
// Build an index to look up parameter index by name. (Does not include
|
||||
// inferred attributes.)
|
||||
std::map<std::string, int> param_name_to_index;
|
||||
for (int i = 0; i < param_names_.size(); ++i) {
|
||||
param_name_to_index[param_names_[i]] = i;
|
||||
}
|
||||
|
||||
// Initialize each attribute & input parameter.
|
||||
attributes_.reserve(op_def.attr_size());
|
||||
for (const auto& attr_def : op_def.attr()) {
|
||||
TF_RETURN_IF_ERROR(InitializeAttribute(attr_def, param_name_to_index));
|
||||
}
|
||||
|
||||
inputs_.reserve(op_def.input_arg_size());
|
||||
for (const auto& arg_def : op_def.input_arg()) {
|
||||
TF_RETURN_IF_ERROR(InitializeInput(arg_def, param_name_to_index));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(CheckParamNames());
|
||||
|
||||
// Filter out any unused entries from inputs_with_*_attrs_.
|
||||
RemoveIf(
|
||||
[](const InputsWithTypeAttr& input) {
|
||||
return input.tensor_params.empty() && input.tensor_list_params.empty();
|
||||
},
|
||||
&inputs_with_type_attrs_);
|
||||
RemoveIf(
|
||||
[](const InputsWithTypeListAttr& input) {
|
||||
return input.tensor_list_params.empty();
|
||||
},
|
||||
&inputs_with_type_list_attrs_);
|
||||
RemoveIf(
|
||||
[](const InputsWithNumberAttr& input) {
|
||||
return input.tensor_list_params.empty();
|
||||
},
|
||||
&inputs_with_number_attrs_);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PythonAPIInfo::CheckParamNames() const {
|
||||
std::vector<bool> param_found(param_names_.size());
|
||||
for (const auto& attr : attributes_) {
|
||||
if (attr.index != -1) {
|
||||
param_found[attr.index] = true;
|
||||
}
|
||||
}
|
||||
for (const auto& input : inputs_) {
|
||||
param_found[input.index] = true;
|
||||
}
|
||||
|
||||
for (int i = 0; i < param_names_.size(); ++i) {
|
||||
if (param_names_[i] == std::string("name")) {
|
||||
continue;
|
||||
}
|
||||
if (!param_found[i]) {
|
||||
return errors::InvalidArgument(
|
||||
api_name_, ": missing specification for parameter ", param_names_[i]);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PythonAPIInfo::InitializeFromRegisteredOp(const std::string& op_name) {
|
||||
const tensorflow::OpDef* op_def = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def));
|
||||
std::vector<std::string> param_names;
|
||||
Safe_PyObjectPtr defaults_tuple;
|
||||
GetOpDefNamesAndDefaults(*op_def, param_names, defaults_tuple);
|
||||
TF_RETURN_IF_ERROR(Initialize(*op_def, param_names, defaults_tuple.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PythonAPIInfo::InitializeFromParamSpecs(
|
||||
const std::map<std::string, std::string>& input_specs,
|
||||
const std::map<std::string, std::string>& attr_specs,
|
||||
const std::vector<string> param_names, PyObject* defaults_tuple) {
|
||||
OpDefBuilder op_def_builder(api_name_);
|
||||
op_def_builder.AllowAttrTypeAny();
|
||||
for (const auto& attr_spec : attr_specs) {
|
||||
op_def_builder.Attr(absl::StrCat(attr_spec.first, ": ", attr_spec.second));
|
||||
}
|
||||
for (const auto& input_spec : input_specs) {
|
||||
op_def_builder.Input(
|
||||
absl::StrCat(input_spec.first, ": ", input_spec.second));
|
||||
}
|
||||
OpRegistrationData op_reg_data;
|
||||
TF_RETURN_IF_ERROR(op_def_builder.Finalize(&op_reg_data));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
Initialize(op_reg_data.op_def, param_names, defaults_tuple));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PythonAPIInfo::InitializeAttribute(
|
||||
const OpDef::AttrDef& attr_def,
|
||||
const std::map<std::string, int>& param_name_to_index) {
|
||||
if (attr_def.name() == "name") {
|
||||
return errors::InvalidArgument(
|
||||
api_name_, ": Reserved parameter `name` was used as an attribute.");
|
||||
}
|
||||
const char* name = InternPyString(attr_def.name());
|
||||
|
||||
const int param_index =
|
||||
gtl::FindWithDefault(param_name_to_index, attr_def.name(), -1);
|
||||
const AttributeType dtype = AttributeTypeFromName(attr_def.type());
|
||||
const int inferred_index = -1;
|
||||
attributes_.push_back({param_index, dtype, name, inferred_index});
|
||||
Attribute& attr = attributes_.back();
|
||||
if (attr.type == AttributeType::UNKNOWN) {
|
||||
return errors::InvalidArgument(api_name_, ": Bad attribute type for ",
|
||||
attr_def.name(), ": '", attr_def.type(),
|
||||
"'");
|
||||
}
|
||||
std::vector<DataType>* ok_dtypes = nullptr;
|
||||
|
||||
if (attr.type == AttributeType::DTYPE) {
|
||||
DataType default_dtype = attr_def.has_default_value()
|
||||
? attr_def.default_value().type()
|
||||
: DT_INVALID;
|
||||
inputs_with_type_attrs_.push_back({&attr, default_dtype});
|
||||
ok_dtypes = &inputs_with_type_attrs_.back().ok_dtypes;
|
||||
|
||||
} else if (attr.type == AttributeType::LIST_DTYPE) {
|
||||
inputs_with_type_list_attrs_.push_back({&attr});
|
||||
for (int d : attr_def.default_value().list().type()) {
|
||||
inputs_with_type_list_attrs_.back().default_dtypes.push_back(
|
||||
static_cast<DataType>(d));
|
||||
}
|
||||
ok_dtypes = &inputs_with_type_list_attrs_.back().ok_dtypes;
|
||||
}
|
||||
|
||||
if (attr_def.has_allowed_values() && ok_dtypes) {
|
||||
const auto& dtypes = attr_def.allowed_values().list();
|
||||
for (int i = 0; i < dtypes.type_size(); ++i) {
|
||||
ok_dtypes->push_back(dtypes.type(i));
|
||||
}
|
||||
}
|
||||
|
||||
if (attr.type == AttributeType::INT) {
|
||||
int64 default_len =
|
||||
attr_def.has_default_value() ? attr_def.default_value().i() : -1;
|
||||
inputs_with_number_attrs_.push_back({&attr, default_len});
|
||||
}
|
||||
|
||||
// If this is an inferred attribute, then record its name and index.
|
||||
if (attr.index == -1) {
|
||||
std::vector<const char*>* inferred_attr_names =
|
||||
attr.type == AttributeType::DTYPE ? &inferred_type_attrs_
|
||||
: attr.type == AttributeType::LIST_DTYPE ? &inferred_type_list_attrs_
|
||||
: attr.type == AttributeType::INT ? &inferred_length_attrs_
|
||||
: nullptr;
|
||||
if (inferred_attr_names == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
api_name_, ": Missing specification for parameter ", attr_def.name());
|
||||
} else {
|
||||
attr.inferred_index = inferred_attr_names->size();
|
||||
inferred_attr_names->push_back(attr.name);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PythonAPIInfo::InitializeInput(
|
||||
const OpDef::ArgDef& arg_def,
|
||||
const std::map<std::string, ParamIndex>& param_name_to_index) {
|
||||
if (arg_def.name() == "name") {
|
||||
return errors::InvalidArgument(
|
||||
api_name_, ": Reserved parameter `name` was used as a tensor input.");
|
||||
}
|
||||
const ParamIndex param_index =
|
||||
gtl::FindWithDefault(param_name_to_index, arg_def.name(), -1);
|
||||
if (param_index == -1) {
|
||||
return errors::InvalidArgument(
|
||||
api_name_, ": Missing specification for parameter ", arg_def.name());
|
||||
}
|
||||
if (arg_def.is_ref()) {
|
||||
// TODO(b/164980194): Support reference parameters.
|
||||
// - Pass as_ref to convert_to_tensor
|
||||
// - Check that values for ref inputs have ref types.
|
||||
return errors::InvalidArgument(api_name_,
|
||||
": PythonAPIInfo doesn't support reference "
|
||||
"parameters yet.");
|
||||
}
|
||||
bool is_list =
|
||||
!arg_def.number_attr().empty() || !arg_def.type_list_attr().empty();
|
||||
inputs_.push_back({param_index, is_list});
|
||||
|
||||
if (!arg_def.type_list_attr().empty()) {
|
||||
// list(input) with dtypes specified by a `list(type)` attribute.
|
||||
InputsWithTypeListAttr* input =
|
||||
FindInputsWithTypeListAttr(arg_def.type_list_attr());
|
||||
if (!input) {
|
||||
return errors::InvalidArgument(
|
||||
api_name_, ": Type attribute ", arg_def.type_list_attr(),
|
||||
" for parameter ", arg_def.name(), " not found.");
|
||||
}
|
||||
input->tensor_list_params.push_back(param_index);
|
||||
} else if (!arg_def.type_attr().empty()) {
|
||||
InputsWithTypeAttr* input = FindInputsWithTypeAttr(arg_def.type_attr());
|
||||
// input or list(input) with dtype specified by a `type` attribute.
|
||||
if (!input) {
|
||||
return errors::InvalidArgument(api_name_, ": Type attribute ",
|
||||
arg_def.type_attr(), " for parameter ",
|
||||
arg_def.name(), " not found.");
|
||||
}
|
||||
if (arg_def.number_attr().empty()) {
|
||||
input->tensor_params.push_back(param_index);
|
||||
} else {
|
||||
input->tensor_list_params.push_back(param_index);
|
||||
}
|
||||
} else {
|
||||
// input or list(input) with fixed dtype
|
||||
inputs_with_fixed_dtype_.push_back({arg_def.type(), param_index, is_list});
|
||||
}
|
||||
|
||||
if (!arg_def.number_attr().empty()) {
|
||||
InputsWithNumberAttr* input =
|
||||
FindInputsWithNumberAttr(arg_def.number_attr());
|
||||
if (!input) {
|
||||
return errors::InvalidArgument(api_name_, ": Length attribute ",
|
||||
arg_def.number_attr(), " for parameter ",
|
||||
arg_def.name(), " not found.");
|
||||
}
|
||||
input->tensor_list_params.push_back(param_index);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PythonAPIInfo::InputsWithTypeAttr* PythonAPIInfo::FindInputsWithTypeAttr(
|
||||
const string& name) {
|
||||
for (auto& input : inputs_with_type_attrs_) {
|
||||
if (name == input.type_attr->name) {
|
||||
return &input;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PythonAPIInfo::InputsWithTypeListAttr*
|
||||
PythonAPIInfo::FindInputsWithTypeListAttr(const string& name) {
|
||||
for (auto& input : inputs_with_type_list_attrs_) {
|
||||
if (name == input.type_list_attr->name) {
|
||||
return &input;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PythonAPIInfo::InputsWithNumberAttr* PythonAPIInfo::FindInputsWithNumberAttr(
|
||||
const string& name) {
|
||||
for (auto& input : inputs_with_number_attrs_) {
|
||||
if (name == input.number_attr->name) {
|
||||
return &input;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
string PythonAPIInfo::DebugInfo() const {
|
||||
string s = absl::StrCat("DebugInfo for ", api_name_, ":\n");
|
||||
absl::StrAppend(&s, " param_names=[", absl::StrJoin(param_names_, ", "),
|
||||
"]\n");
|
||||
Safe_PyObjectPtr defaults_repr(PyObject_Repr(defaults_tuple_.get()));
|
||||
absl::StrAppend(
|
||||
&s, " defaults_tuple=", TFE_GetPythonString(defaults_repr.get()), "\n");
|
||||
if (!attributes_.empty()) {
|
||||
absl::StrAppend(&s, " attributes=[");
|
||||
for (const auto& attrib : attributes_) {
|
||||
if (attrib.index != -1) {
|
||||
absl::StrAppend(&s, "\n {index=", attrib.index);
|
||||
DCHECK_EQ(attrib.inferred_index, -1);
|
||||
} else {
|
||||
absl::StrAppend(&s, "\n {inferred_index=", attrib.inferred_index);
|
||||
}
|
||||
absl::StrAppend(&s, ", name=", attrib.name,
|
||||
", type=", AttributeTypeToName(attrib.type), "},");
|
||||
}
|
||||
absl::StrAppend(&s, "]\n");
|
||||
}
|
||||
if (!inputs_.empty()) {
|
||||
absl::StrAppend(&s, " inputs=[");
|
||||
for (const auto& input : inputs_) {
|
||||
absl::StrAppend(&s, "\n {index=", input.index,
|
||||
", name=", param_names_[input.index],
|
||||
", is_list=", input.is_list, "},");
|
||||
}
|
||||
absl::StrAppend(&s, "]\n");
|
||||
}
|
||||
if (!inputs_with_fixed_dtype_.empty()) {
|
||||
absl::StrAppend(&s, " inputs_with_fixed_dtype=[");
|
||||
for (const auto& input : inputs_with_fixed_dtype_) {
|
||||
absl::StrAppend(&s, "\n {index=", input.index,
|
||||
", dtype=", DataType_Name(input.dtype),
|
||||
", is_list=", input.is_list, "},");
|
||||
}
|
||||
absl::StrAppend(&s, "]\n");
|
||||
}
|
||||
if (!inputs_with_type_attrs_.empty()) {
|
||||
absl::StrAppend(&s, " inputs_with_type_attr=[");
|
||||
for (const auto& input : inputs_with_type_attrs_) {
|
||||
absl::StrAppend(&s, "\n {type_attr=", input.type_attr->name);
|
||||
if (input.default_dtype != DT_INVALID) {
|
||||
absl::StrAppend(&s,
|
||||
", default_dtype=", DataType_Name(input.default_dtype));
|
||||
}
|
||||
if (!input.tensor_params.empty()) {
|
||||
absl::StrAppend(&s, ", tensor_params=[",
|
||||
absl::StrJoin(input.tensor_params, ", "), "]");
|
||||
}
|
||||
if (!input.tensor_list_params.empty()) {
|
||||
absl::StrAppend(&s, ", tensor_list_params=[",
|
||||
absl::StrJoin(input.tensor_list_params, ", "), "]");
|
||||
}
|
||||
if (!input.ok_dtypes.empty()) {
|
||||
absl::StrAppend(
|
||||
&s, ", ok_dtypes=[",
|
||||
absl::StrJoin(input.ok_dtypes, ", ", DataTypeFormatter()), "]");
|
||||
}
|
||||
absl::StrAppend(&s, "},");
|
||||
}
|
||||
absl::StrAppend(&s, "]\n");
|
||||
}
|
||||
if (!inputs_with_type_list_attrs_.empty()) {
|
||||
absl::StrAppend(&s, " inputs_with_type_list_attrs=[");
|
||||
for (const auto& input : inputs_with_type_list_attrs_) {
|
||||
absl::StrAppend(&s, "\n {type_list_attr=", input.type_list_attr->name);
|
||||
if (!input.default_dtypes.empty()) {
|
||||
absl::StrAppend(
|
||||
&s, ", default_dtypes=[",
|
||||
absl::StrJoin(input.default_dtypes, ", ", DataTypeFormatter()),
|
||||
"]");
|
||||
}
|
||||
if (!input.tensor_list_params.empty()) {
|
||||
absl::StrAppend(&s, ", tensor_list_params=[",
|
||||
absl::StrJoin(input.tensor_list_params, ", "), "]");
|
||||
}
|
||||
if (!input.ok_dtypes.empty()) {
|
||||
absl::StrAppend(
|
||||
&s, ", ok_dtypes=[",
|
||||
absl::StrJoin(input.ok_dtypes, ", ", DataTypeFormatter()), "]");
|
||||
}
|
||||
absl::StrAppend(&s, "},");
|
||||
}
|
||||
absl::StrAppend(&s, "]\n");
|
||||
}
|
||||
if (!inputs_with_number_attrs_.empty()) {
|
||||
absl::StrAppend(&s, " inputs_with_number_attrs=[");
|
||||
for (const auto& input : inputs_with_number_attrs_) {
|
||||
absl::StrAppend(&s, "\n {number_attr=", input.number_attr->name,
|
||||
", default_length=", input.default_length,
|
||||
", tensor_list_params=[",
|
||||
absl::StrJoin(input.tensor_list_params, ", "), "],\n");
|
||||
}
|
||||
absl::StrAppend(&s, "]\n");
|
||||
}
|
||||
if (!inferred_type_attrs_.empty()) {
|
||||
absl::StrAppend(&s, " inferred_type_attrs=[",
|
||||
absl::StrJoin(inferred_type_attrs_, ", "), "]\n");
|
||||
}
|
||||
if (!inferred_type_list_attrs_.empty()) {
|
||||
absl::StrAppend(&s, " inferred_type_list_attrs=[",
|
||||
absl::StrJoin(inferred_type_list_attrs_, ", "), "]\n");
|
||||
}
|
||||
if (!inferred_length_attrs_.empty()) {
|
||||
absl::StrAppend(&s, " inferred_length_attrs=[",
|
||||
absl::StrJoin(inferred_length_attrs_, ", "), "]\n");
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
298
tensorflow/python/framework/python_api_info.h
Normal file
298
tensorflow/python/framework/python_api_info.h
Normal file
@ -0,0 +1,298 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_
|
||||
#define TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/python/framework/op_def_util.h"
|
||||
#include "tensorflow/python/framework/python_tensor_converter.h"
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Precomputed information about a TensorFlow Python API.
|
||||
//
|
||||
// PythonAPIInfo records information about a single TensorFlow Python API,
|
||||
// in order to allow calls to the API to be executed more efficiently. This
|
||||
// information includes:
|
||||
//
|
||||
// * The name of the API. (E.g. "tf.math.add")
|
||||
//
|
||||
// * The name of the registered op that implements the API, if applicable
|
||||
// (e.g. "AddV2").
|
||||
//
|
||||
// * Information about the API's parameters. Parameters are divided into two
|
||||
// "kinds": inputs and attributes. An *input* is a parameter that
|
||||
// expects a Tensor or list of Tensors, and it is described by an `ArgDef`.
|
||||
// An *attribute* is a parameter that expects any other value type, and it is
|
||||
// described by an `AttrDef`.
|
||||
//
|
||||
// * Default values for the API's attribute parameters.
|
||||
//
|
||||
// * Information about "inferred attributes" -- attributes whose values are
|
||||
// inferred from `input` parameters. There are two kinds of inferred
|
||||
// attributes: Tensor dtypes, which are inferred from tensor and list(tensor)
|
||||
// parameters; and list lengths, which are inferred from list(tensor)
|
||||
// parameters.
|
||||
class PythonAPIInfo {
|
||||
public:
|
||||
// The index of a parameter in the canonicalized parameter list. The
|
||||
// canonicalized parameter list includes inputs and attributes (but does
|
||||
// not include inferred attributes). `-1` is used for inferred attributes.
|
||||
using ParamIndex = int;
|
||||
|
||||
// Information about a parameter that expects a non-Tensor value.
|
||||
struct Attribute {
|
||||
ParamIndex index; // -1 if this is an inferred attribute
|
||||
AttributeType type;
|
||||
const char* name; // Interned python string
|
||||
int inferred_index; // index to store attribute in InferredAttributes
|
||||
};
|
||||
|
||||
// Information about a parameter that expects a Tensor or list(Tensor).
|
||||
// Additional information about tensor parameters is stored in types
|
||||
// defined below, in order to simplify dtype/length inference:
|
||||
// * FixedDTypeInput: inputs with fixed dtypes.
|
||||
// * InputsWithTypeAttr: groups inputs that use a type_attr for dtype.
|
||||
// * InputsWithTypeListAttr: groups inputs that use a type_list_attr.
|
||||
// * InputsWithNumberAttr: groups inputs by a number_attr for length.
|
||||
struct Input {
|
||||
ParamIndex index;
|
||||
bool is_list;
|
||||
};
|
||||
|
||||
// Information about a Tensor parameter w/ fixed dtype.
|
||||
struct InputWithFixedDType {
|
||||
DataType dtype;
|
||||
ParamIndex index;
|
||||
bool is_list;
|
||||
};
|
||||
|
||||
// Information about Tensor parameters whose DType is specified by a single
|
||||
// `type_attr` attribute.
|
||||
struct InputsWithTypeAttr {
|
||||
Attribute* type_attr; // not owned.
|
||||
DataType default_dtype; // DT_INVALID if no default.
|
||||
std::vector<ParamIndex> tensor_params; // single-tensor inputs.
|
||||
std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs.
|
||||
std::vector<DataType> ok_dtypes;
|
||||
};
|
||||
|
||||
// Information about Tensor parameters whose DType is specified by a single
|
||||
// `type_list_attr` attribute.
|
||||
struct InputsWithTypeListAttr {
|
||||
Attribute* type_list_attr; // not owned.
|
||||
std::vector<DataType> default_dtypes; // empty if no default.
|
||||
std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs.
|
||||
std::vector<DataType> ok_dtypes;
|
||||
};
|
||||
|
||||
// Information about Tensor-list parameters whose length is specified by a
|
||||
// single `int` attribute.
|
||||
struct InputsWithNumberAttr {
|
||||
Attribute* number_attr; // not owned.
|
||||
int64 default_length; // -1 for no default.
|
||||
std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs.
|
||||
};
|
||||
|
||||
// Structure used to return inferred attribute values.
|
||||
// * types[i] is the inferred value for inferred_type_attrs()[i]
|
||||
// * type_lists[i] is the inferred value for inferred_type_list_attrs()[i]
|
||||
// * lengths[i] is the inferred value for inferred_length_attrs()[i]
|
||||
struct InferredAttributes {
|
||||
std::vector<DataType> types;
|
||||
std::vector<std::vector<DataType>> type_lists;
|
||||
std::vector<int64> lengths;
|
||||
};
|
||||
|
||||
// Constructs a new PythonAPIInfo.
|
||||
//
|
||||
// Note: One of the `Initialize()` functions must be called before the
|
||||
// `PythonAPIInfo` is used.
|
||||
//
|
||||
// Args:
|
||||
// api_name: The fully-qualified name of the python API (e.g., tf.math.sum).
|
||||
explicit PythonAPIInfo(const std::string& api_name);
|
||||
|
||||
// Initializes this PythonAPIInfo.
|
||||
//
|
||||
// Args:
|
||||
// op_def: Contains information about the parameters.
|
||||
// param_names: The argument names for the python API, in canonical order.
|
||||
// defaults_tuple: Tuple containing default values for the parameters,
|
||||
// right-aligned with `param_names` -- i.e., `defaults[-i]` is the default
|
||||
// for `param_names[-i]`.
|
||||
Status Initialize(const OpDef& op_def, const std::vector<string> param_names,
|
||||
PyObject* defaults_tuple);
|
||||
|
||||
// Initialize this PythonAPIInfo based on the registered OpDef for the given
|
||||
// operation.
|
||||
//
|
||||
// Args:
|
||||
// op_name: The registered name of the operation (e.g. "AddV2").
|
||||
Status InitializeFromRegisteredOp(const std::string& op_name);
|
||||
|
||||
// Initializes this PythonAPIInfo based on a set of parameter specifications.
|
||||
//
|
||||
// Args:
|
||||
// input_specs: Mapping from parameter name to specification string for
|
||||
// each input (parameter that expects a tensor value).
|
||||
// attr_specs: Mapping from parameter name to specification string for
|
||||
// each attribute (parameter that expects a non-tensor value).
|
||||
// param_names: The argument names for the python API, in canonical order.
|
||||
// defaults_tuple: Tuple containing default values for the parameters,
|
||||
// right-aligned with `param_names` -- i.e., `defaults[-i]` is the default
|
||||
// for `param_names[-i]`.
|
||||
//
|
||||
// Note: the `name` parameter should not be included in `input_specs` or
|
||||
// `attr_specs`.
|
||||
Status InitializeFromParamSpecs(
|
||||
const std::map<std::string, std::string>& input_specs,
|
||||
const std::map<std::string, std::string>& attr_specs,
|
||||
const std::vector<string> param_names, PyObject* defaults_tuple);
|
||||
|
||||
// The name of the API that is described by this PythonAPIInfo.
|
||||
const char* api_name() const { return api_name_; }
|
||||
|
||||
// The ordered names of the canononical parameters that this API expects.
|
||||
const std::vector<const char*>& param_names() const { return param_names_; }
|
||||
|
||||
// A Python tuple containing the default values for parameters. This is
|
||||
// right-aligned with `param_name` -- i.e., `defaults[-i]` is the default
|
||||
// for `param_names[-i]`.
|
||||
const PyObject* defaults_tuple() const { return defaults_tuple_.get(); }
|
||||
|
||||
// Information about the attribute (non-tensor) parameters for this API.
|
||||
const std::vector<Attribute>& attributes() const { return attributes_; }
|
||||
|
||||
// Information about the input (tensor) parameters for this API.
|
||||
const std::vector<Input>& inputs() const { return inputs_; }
|
||||
const std::vector<InputWithFixedDType>& inputs_with_fixed_dtype() const {
|
||||
return inputs_with_fixed_dtype_;
|
||||
}
|
||||
const std::vector<InputsWithTypeAttr>& inputs_with_type_attrs() const {
|
||||
return inputs_with_type_attrs_;
|
||||
}
|
||||
const std::vector<InputsWithTypeListAttr>& inputs_with_type_list_attrs()
|
||||
const {
|
||||
return inputs_with_type_list_attrs_;
|
||||
}
|
||||
const std::vector<InputsWithNumberAttr>& inputs_with_number_attrs() const {
|
||||
return inputs_with_number_attrs_;
|
||||
}
|
||||
|
||||
// Names of inferred attributes.
|
||||
const std::vector<const char*>& inferred_type_attrs() const {
|
||||
return inferred_type_attrs_;
|
||||
}
|
||||
const std::vector<const char*>& inferred_type_list_attrs() const {
|
||||
return inferred_type_list_attrs_;
|
||||
}
|
||||
const std::vector<const char*>& inferred_length_attrs() const {
|
||||
return inferred_length_attrs_;
|
||||
}
|
||||
|
||||
// Returns a string summarizing the internal state of this type converter.
|
||||
string DebugInfo() const;
|
||||
|
||||
private:
|
||||
// Adds an entry to the attributes_ vector based on the given `AttrDef`.
|
||||
//
|
||||
// If `attr_def` describes a type attribute, then adds a value to
|
||||
// inputs_with_type_attrs_ or inputs_with_type_list_attrs_ (to record any
|
||||
// tensor inputs that use this dtype).
|
||||
//
|
||||
// If `attr_def` describes an int attribute, then adds a value to
|
||||
// inputs_with_number_attrs_ (to record any tensor inputs that use this
|
||||
// value as a list length).
|
||||
Status InitializeAttribute(
|
||||
const OpDef::AttrDef& attr_def,
|
||||
const std::map<std::string, ParamIndex>& param_name_to_index);
|
||||
|
||||
// Adds an entry to the inputs_ vector based on the given `ArgDef`.
|
||||
//
|
||||
// If `arg_def` has a fixed dtype, then adds a value to `fixed_dtype_inputs`.
|
||||
//
|
||||
// If `arg_def`'s dtype is described by a `type` attr, then updates the
|
||||
// appropriate value in `inputs_with_type_attrs_` with information about the
|
||||
// `arg_def`.
|
||||
//
|
||||
// If `arg_def`'s dtype is described by a `list(type)` attr, then updates the
|
||||
// appropriate value in `inputs_with_type_list_attrs_` with information about
|
||||
// the `arg_def`.
|
||||
Status InitializeInput(const OpDef::ArgDef& arg_def,
|
||||
const std::map<std::string, int>& param_name_to_index);
|
||||
|
||||
// Checks that the OpDef used to initialize this PythonAPIInfo
|
||||
// had an AttrDef or ArgDef specification for each parameter.
|
||||
Status CheckParamNames() const;
|
||||
|
||||
// Searches inputs_with_type_attrs_ for an input with the given name.
|
||||
InputsWithTypeAttr* FindInputsWithTypeAttr(const string& name);
|
||||
|
||||
// Searches inputs_with_type_list_attrs_ for an input with the given name.
|
||||
InputsWithTypeListAttr* FindInputsWithTypeListAttr(const string& name);
|
||||
|
||||
// Searches inputs_with_type_list_attrs_ for an input with the given name.
|
||||
InputsWithNumberAttr* FindInputsWithNumberAttr(const string& name);
|
||||
|
||||
ABSL_MUST_USE_RESULT
|
||||
bool InferLengthAttributes(const absl::Span<PyObject*> params,
|
||||
std::vector<int64>& inferred_length_attrs) const;
|
||||
|
||||
// ==========================================================================
|
||||
// Member Variables
|
||||
// ==========================================================================
|
||||
|
||||
// The name of the API that is described by this PythonAPIInfo.
|
||||
// (Interned python string).
|
||||
const char* api_name_;
|
||||
|
||||
// The names of the parameters that this API expects.
|
||||
// (Interned python strings.)
|
||||
std::vector<const char*> param_names_;
|
||||
|
||||
// Tuple containing default values for the parameters, right-aligned with
|
||||
// `param_names` -- i.e., `defaults[-i]` is the default for `param_names[-i]`.
|
||||
Safe_PyObjectPtr defaults_tuple_;
|
||||
|
||||
// Information about the non-tensor-valued parameters that this API expects.
|
||||
std::vector<Attribute> attributes_;
|
||||
|
||||
// Information about the tensor-valued parameters that this API expects.
|
||||
std::vector<Input> inputs_;
|
||||
std::vector<InputWithFixedDType> inputs_with_fixed_dtype_;
|
||||
std::vector<InputsWithTypeAttr> inputs_with_type_attrs_;
|
||||
std::vector<InputsWithTypeListAttr> inputs_with_type_list_attrs_;
|
||||
std::vector<InputsWithNumberAttr> inputs_with_number_attrs_;
|
||||
|
||||
// Names of inferred attributes. (Interned python strings.)
|
||||
std::vector<const char*> inferred_type_attrs_;
|
||||
std::vector<const char*> inferred_type_list_attrs_;
|
||||
std::vector<const char*> inferred_length_attrs_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_
|
254
tensorflow/python/framework/python_api_info_test.py
Normal file
254
tensorflow/python/framework/python_api_info_test.py
Normal file
@ -0,0 +1,254 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.python.framework.python_api_info."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python import _pywrap_python_api_info
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
# pylint: disable=g-long-lambda
|
||||
|
||||
|
||||
# Helper function to make expected output in examples more compact:
|
||||
def Const(x):
|
||||
return constant_op.constant(x)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class PythonAPIInfoTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
context.ensure_initialized()
|
||||
super(PythonAPIInfoTest, self).setUp()
|
||||
|
||||
def makeConverterForGenOp(self, op_name):
|
||||
"""Returns a PythonAPIInfo for the given gen_op."""
|
||||
api_info = _pywrap_python_api_info.PythonAPIInfo(op_name)
|
||||
api_info.InitializeFromRegisteredOp(op_name)
|
||||
return api_info
|
||||
|
||||
def makeConverterFromParamSpecs(self,
|
||||
api_name,
|
||||
param_names,
|
||||
input_specs,
|
||||
attr_specs,
|
||||
defaults=()):
|
||||
"""Returns a PythonAPIInfo built from the given specs."""
|
||||
api_info = _pywrap_python_api_info.PythonAPIInfo(api_name)
|
||||
api_info.InitializeFromParamSpecs(input_specs, attr_specs, param_names,
|
||||
defaults)
|
||||
return api_info
|
||||
|
||||
# This test initializes a PythonAPIInfo from a registered
|
||||
# op, and then uses DebugInfo() to check that the internal state is
|
||||
# correct.
|
||||
@parameterized.named_parameters([
|
||||
# An op whose inputs have fixed dtypes.
|
||||
("RegexFullMatch", "RegexFullMatch", "DebugInfo for RegexFullMatch:\n"
|
||||
" param_names=[input, pattern, name]\n"
|
||||
" defaults_tuple=('RegexFullMatch',)\n"
|
||||
" inputs=[\n"
|
||||
" {index=0, name=input, is_list=0},\n"
|
||||
" {index=1, name=pattern, is_list=0},]\n"
|
||||
" inputs_with_fixed_dtype=[\n"
|
||||
" {index=0, dtype=DT_STRING, is_list=0},\n"
|
||||
" {index=1, dtype=DT_STRING, is_list=0},]\n"),
|
||||
# An op whose input has a variable dtype.
|
||||
("Abs", "Abs", "DebugInfo for Abs:\n"
|
||||
" param_names=[x, name]\n"
|
||||
" defaults_tuple=('Abs',)\n"
|
||||
" attributes=[\n"
|
||||
" {inferred_index=0, name=T, type=type},]\n"
|
||||
" inputs=[\n"
|
||||
" {index=0, name=x, is_list=0},]\n"
|
||||
" inputs_with_type_attr=[\n"
|
||||
" {type_attr=T, tensor_params=[0], ok_dtypes=[DT_BFLOAT16, DT_HALF, "
|
||||
"DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64]},]\n"
|
||||
" inferred_type_attrs=[T]\n"),
|
||||
# An op with two inputs that have the same (variable) dtype.
|
||||
("AddV2", "AddV2", "DebugInfo for AddV2:\n"
|
||||
" param_names=[x, y, name]\n"
|
||||
" defaults_tuple=('AddV2',)\n"
|
||||
" attributes=[\n"
|
||||
" {inferred_index=0, name=T, type=type},]\n"
|
||||
" inputs=[\n"
|
||||
" {index=0, name=x, is_list=0},\n"
|
||||
" {index=1, name=y, is_list=0},]\n"
|
||||
" inputs_with_type_attr=[\n"
|
||||
" {type_attr=T, tensor_params=[0, 1], ok_dtypes=[DT_BFLOAT16, "
|
||||
"DT_HALF, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_INT8, DT_INT16, DT_UINT32, "
|
||||
"DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128]},]\n"
|
||||
" inferred_type_attrs=[T]\n"),
|
||||
# An op with an int attribute.
|
||||
("GatherV2", "GatherV2", "DebugInfo for GatherV2:\n"
|
||||
" param_names=[params, indices, axis, batch_dims, name]\n"
|
||||
" defaults_tuple=(0, 'GatherV2')\n"
|
||||
" attributes=[\n"
|
||||
" {index=3, name=batch_dims, type=int},\n"
|
||||
" {inferred_index=0, name=Tparams, type=type},\n"
|
||||
" {inferred_index=1, name=Tindices, type=type},\n"
|
||||
" {inferred_index=2, name=Taxis, type=type},]\n"
|
||||
" inputs=[\n"
|
||||
" {index=0, name=params, is_list=0},\n"
|
||||
" {index=1, name=indices, is_list=0},\n"
|
||||
" {index=2, name=axis, is_list=0},]\n"
|
||||
" inputs_with_type_attr=[\n"
|
||||
" {type_attr=Tparams, tensor_params=[0]},\n"
|
||||
" {type_attr=Tindices, tensor_params=[1], "
|
||||
"ok_dtypes=[DT_INT32, DT_INT64]},\n"
|
||||
" {type_attr=Taxis, tensor_params=[2], "
|
||||
"ok_dtypes=[DT_INT32, DT_INT64]},]\n"
|
||||
" inferred_type_attrs=[Tparams, Tindices, Taxis]\n"),
|
||||
# An op with default attrib values.
|
||||
("ReduceJoin", "ReduceJoin", "DebugInfo for ReduceJoin:\n"
|
||||
" param_names=[inputs, reduction_indices, keep_dims, separator, name]\n"
|
||||
" defaults_tuple=(False, '', 'ReduceJoin')\n"
|
||||
" attributes=[\n"
|
||||
" {index=2, name=keep_dims, type=bool},\n"
|
||||
" {index=3, name=separator, type=string},]\n"
|
||||
" inputs=[\n"
|
||||
" {index=0, name=inputs, is_list=0},\n"
|
||||
" {index=1, name=reduction_indices, is_list=0},]\n"
|
||||
" inputs_with_fixed_dtype=[\n"
|
||||
" {index=0, dtype=DT_STRING, is_list=0},\n"
|
||||
" {index=1, dtype=DT_INT32, is_list=0},]\n"),
|
||||
# An op with a variable-dtype list input, and an int attribute.
|
||||
("ParseExampleV2", "ParseExampleV2", "DebugInfo for ParseExampleV2:\n"
|
||||
" param_names=[serialized, names, sparse_keys, dense_keys, "
|
||||
"ragged_keys, dense_defaults, num_sparse, sparse_types, "
|
||||
"ragged_value_types, ragged_split_types, dense_shapes, name]\n"
|
||||
" defaults_tuple=('ParseExampleV2',)\n"
|
||||
" attributes=[\n"
|
||||
" {inferred_index=0, name=Tdense, type=list(type)},\n"
|
||||
" {index=6, name=num_sparse, type=int},\n"
|
||||
" {index=7, name=sparse_types, type=list(type)},\n"
|
||||
" {index=8, name=ragged_value_types, type=list(type)},\n"
|
||||
" {index=9, name=ragged_split_types, type=list(type)},\n"
|
||||
" {index=10, name=dense_shapes, type=list(shape)},]\n"
|
||||
" inputs=[\n"
|
||||
" {index=0, name=serialized, is_list=0},\n"
|
||||
" {index=1, name=names, is_list=0},\n"
|
||||
" {index=2, name=sparse_keys, is_list=0},\n"
|
||||
" {index=3, name=dense_keys, is_list=0},\n"
|
||||
" {index=4, name=ragged_keys, is_list=0},\n"
|
||||
" {index=5, name=dense_defaults, is_list=1},]\n"
|
||||
" inputs_with_fixed_dtype=[\n"
|
||||
" {index=0, dtype=DT_STRING, is_list=0},\n"
|
||||
" {index=1, dtype=DT_STRING, is_list=0},\n"
|
||||
" {index=2, dtype=DT_STRING, is_list=0},\n"
|
||||
" {index=3, dtype=DT_STRING, is_list=0},\n"
|
||||
" {index=4, dtype=DT_STRING, is_list=0},]\n"
|
||||
" inputs_with_type_list_attrs=[\n"
|
||||
" {type_list_attr=Tdense, tensor_list_params=[5], "
|
||||
"ok_dtypes=[DT_FLOAT, DT_INT64, DT_STRING]},]\n"
|
||||
" inferred_type_list_attrs=[Tdense]\n"),
|
||||
# An op with a default dtype
|
||||
("BroadcastArgs", "BroadcastArgs", "DebugInfo for BroadcastArgs:\n"
|
||||
" param_names=[s0, s1, name]\n"
|
||||
" defaults_tuple=('BroadcastArgs',)\n"
|
||||
" attributes=[\n"
|
||||
" {inferred_index=0, name=T, type=type},]\n"
|
||||
" inputs=[\n"
|
||||
" {index=0, name=s0, is_list=0},\n"
|
||||
" {index=1, name=s1, is_list=0},]\n"
|
||||
" inputs_with_type_attr=[\n"
|
||||
" {type_attr=T, default_dtype=DT_INT32, tensor_params=[0, 1], "
|
||||
"ok_dtypes=[DT_INT32, DT_INT64]},]\n"
|
||||
" inferred_type_attrs=[T]\n"),
|
||||
])
|
||||
def testInitializeFromRegisteredOp(self, op_name, debug_info):
|
||||
api_info = self.makeConverterForGenOp(op_name)
|
||||
self.assertEqual(api_info.DebugInfo().strip(), debug_info.strip())
|
||||
|
||||
# This test initializes a PythonAPIInfo from parameter specs,
|
||||
# and then uses DebugInfo() to check that the internal state is correct.
|
||||
@parameterized.named_parameters([
|
||||
("NoParams", "NoParams", [], {}, {}, "DebugInfo for NoParams:\n"
|
||||
" param_names=[]\n"
|
||||
" defaults_tuple=()\n"),
|
||||
("OnlyNameParam", "OnlyNameParam", ["name"], {}, {},
|
||||
"DebugInfo for OnlyNameParam:\n"
|
||||
" param_names=[name]\n"
|
||||
" defaults_tuple=()\n"),
|
||||
("SomeBinaryOp", "SomeBinaryOp", ["x", "y"], dict(x="T", y="T"),
|
||||
dict(T="type"), "DebugInfo for SomeBinaryOp:\n"
|
||||
" param_names=[x, y]\n"
|
||||
" defaults_tuple=()\n"
|
||||
" attributes=[\n"
|
||||
" {inferred_index=0, name=T, type=type},]\n"
|
||||
" inputs=[\n"
|
||||
" {index=0, name=x, is_list=0},\n"
|
||||
" {index=1, name=y, is_list=0},]\n"
|
||||
" inputs_with_type_attr=[\n"
|
||||
" {type_attr=T, tensor_params=[0, 1]},]\n"
|
||||
" inferred_type_attrs=[T]\n"),
|
||||
("AllAttributeTypes", "AllAttributeTypes", [
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n",
|
||||
"o", "p"
|
||||
], {},
|
||||
dict(
|
||||
a="any",
|
||||
b="float",
|
||||
c="int",
|
||||
d="string",
|
||||
e="bool",
|
||||
f="type",
|
||||
g="shape",
|
||||
h="tensor",
|
||||
i="list(any)",
|
||||
j="list(float)",
|
||||
k="list(int)",
|
||||
l="list(string)",
|
||||
m="list(bool)",
|
||||
n="list(type)",
|
||||
o="list(shape)",
|
||||
p="list(tensor)"), "DebugInfo for AllAttributeTypes:\n"
|
||||
" param_names=[a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p]\n"
|
||||
" defaults_tuple=()\n"
|
||||
" attributes=[\n"
|
||||
" {index=0, name=a, type=any},\n"
|
||||
" {index=1, name=b, type=float},\n"
|
||||
" {index=2, name=c, type=int},\n"
|
||||
" {index=3, name=d, type=string},\n"
|
||||
" {index=4, name=e, type=bool},\n"
|
||||
" {index=5, name=f, type=type},\n"
|
||||
" {index=6, name=g, type=shape},\n"
|
||||
" {index=7, name=h, type=tensor},\n"
|
||||
" {index=8, name=i, type=list(any)},\n"
|
||||
" {index=9, name=j, type=list(float)},\n"
|
||||
" {index=10, name=k, type=list(int)},\n"
|
||||
" {index=11, name=l, type=list(string)},\n"
|
||||
" {index=12, name=m, type=list(bool)},\n"
|
||||
" {index=13, name=n, type=list(type)},\n"
|
||||
" {index=14, name=o, type=list(shape)},\n"
|
||||
" {index=15, name=p, type=list(tensor)},]\n"),
|
||||
])
|
||||
def testInitializeFromParamSpecs(self, api_name, param_names, input_specs,
|
||||
attr_specs, debug_info):
|
||||
api_info = self.makeConverterFromParamSpecs(api_name, param_names,
|
||||
input_specs, attr_specs)
|
||||
self.assertEqual(api_info.DebugInfo().strip(), debug_info.strip())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
75
tensorflow/python/framework/python_api_info_wrapper.cc
Normal file
75
tensorflow/python/framework/python_api_info_wrapper.cc
Normal file
@ -0,0 +1,75 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Note: This library is only used by python_api_info_test. It
|
||||
// is not meant to be used in other circumstances.
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/python/framework/python_api_info.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
void InitializeFromRegisteredOp(PythonAPIInfo* api_info,
|
||||
const std::string& op_name) {
|
||||
auto result = api_info->InitializeFromRegisteredOp(op_name);
|
||||
if (!result.ok()) {
|
||||
PyErr_SetString(PyExc_ValueError, result.ToString().c_str());
|
||||
throw py::error_already_set();
|
||||
}
|
||||
}
|
||||
|
||||
void InitializeFromParamSpecs(
|
||||
PythonAPIInfo* api_info,
|
||||
const std::map<std::string, std::string>& input_specs,
|
||||
const std::map<std::string, std::string>& attr_specs,
|
||||
const std::vector<string>& param_names, py::handle defaults_tuple) {
|
||||
auto result = api_info->InitializeFromParamSpecs(
|
||||
input_specs, attr_specs, param_names, defaults_tuple.ptr());
|
||||
if (!result.ok()) {
|
||||
PyErr_SetString(PyExc_ValueError, result.ToString().c_str());
|
||||
throw py::error_already_set();
|
||||
}
|
||||
}
|
||||
|
||||
std::string DebugInfo(PythonAPIInfo* api_info) { return api_info->DebugInfo(); }
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
using PythonAPIInfo = tensorflow::PythonAPIInfo;
|
||||
using InferredAttributes = tensorflow::PythonAPIInfo::InferredAttributes;
|
||||
|
||||
PYBIND11_MODULE(_pywrap_python_api_info, m) {
|
||||
py::class_<PythonAPIInfo>(m, "PythonAPIInfo")
|
||||
.def(py::init<const std::string&>())
|
||||
.def("InitializeFromRegisteredOp",
|
||||
&tensorflow::InitializeFromRegisteredOp)
|
||||
.def("InitializeFromParamSpecs", &tensorflow::InitializeFromParamSpecs)
|
||||
.def("DebugInfo", &tensorflow::DebugInfo)
|
||||
.def("InferredTypeAttrs",
|
||||
[](PythonAPIInfo* self) { return self->inferred_type_attrs(); })
|
||||
.def("InferredTypeListAttrs",
|
||||
[](PythonAPIInfo* self) { return self->inferred_type_list_attrs(); })
|
||||
.def("InferredLengthAttrs",
|
||||
[](PythonAPIInfo* self) { return self->inferred_length_attrs(); });
|
||||
py::class_<InferredAttributes>(m, "InferredAttributes")
|
||||
.def_readonly("types", &InferredAttributes::types)
|
||||
.def_readonly("type_lists", &InferredAttributes::type_lists)
|
||||
.def_readonly("lengths", &InferredAttributes::lengths);
|
||||
}
|
@ -404,3 +404,6 @@ tensorflow::PythonAPIDispatcher
|
||||
|
||||
[python_tensor_converter] # python_tensor_converter
|
||||
tensorflow::PythonTensorConverter
|
||||
|
||||
[python_api_info] # python_api_info
|
||||
tensorflow::PythonAPIInfo
|
||||
|
Loading…
Reference in New Issue
Block a user