diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 1435b6852a8..55480f5b95e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1371,6 +1371,7 @@ py_library( ":_pywrap_py_exception_registry", ":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed. ":_pywrap_python_api_dispatcher", + ":_pywrap_python_api_info", ":_pywrap_python_op_gen", ":_pywrap_quantize_training", ":_pywrap_stacktrace_handler", @@ -1696,6 +1697,7 @@ cc_library( ":cpp_python_util", ":safe_pyobject_ptr", "//tensorflow/core:protos_all_cc", + "//third_party/python_runtime:headers", # buildcleaner: keep "@com_google_absl//absl/strings", ], ) @@ -1739,6 +1741,86 @@ tf_py_test( tags = ["no_pip"], ) +cc_library( + name = "python_api_info", + srcs = ["framework/python_api_info.cc"], + hdrs = ["framework/python_api_info.h"], + deps = [ + ":cpp_python_util", + ":op_def_util_cc", + ":python_tensor_converter", + ":safe_pyobject_ptr", + "//tensorflow/core:framework", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:status", + "//tensorflow/python/eager:pywrap_tfe_lib", + "//third_party/python_runtime:headers", # buildcleaner: keep + "@com_google_absl//absl/strings", + ], +) + +# Note: this target is only used by python_api_info_test. +tf_python_pybind_extension( + name = "_pywrap_python_api_info", + srcs = ["framework/python_api_info_wrapper.cc"], + hdrs = [ + "framework/op_def_util.h", + "framework/python_api_info.h", + "framework/python_tensor_converter.h", + "lib/core/numpy.h", + "//tensorflow/c:headers", + "//tensorflow/c/eager:pywrap_required_hdrs", + "//tensorflow/c/experimental/ops:pywrap_required_hdrs", + "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs", + "//tensorflow/core/distributed_runtime:pywrap_required_hdrs", + "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs", + "//tensorflow/python/eager:pywrap_required_hdrs", + ], + module_name = "_pywrap_python_api_info", + deps = [ + ":safe_pyobject_ptr_required_hdrs", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@pybind11", + "//third_party/python_runtime:headers", # buildcleaner: keep + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:lib", + "//tensorflow/core:framework", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", + "//tensorflow/core:lib_headers_for_pybind", + "//third_party/py/numpy:headers", + "//tensorflow/c:pywrap_required_hdrs", + "@com_google_absl//absl/types:span", + ] + if_static( + extra_deps = [ + "//tensorflow/core/protobuf:eager_service_proto_cc", + "//tensorflow/core/protobuf:master_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", + ], + otherwise = [ + "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", + "//tensorflow/core/protobuf:master_proto_cc_headers_only", + "//tensorflow/core/protobuf:worker_proto_cc_headers_only", + ], + ), +) + +tf_py_test( + name = "python_api_info_test", + srcs = ["framework/python_api_info_test.py"], + python_version = "PY3", + tags = ["no_pip"], + deps = [ + ":_pywrap_python_api_info", + ":_pywrap_python_tensor_converter", + ":client_testlib", + ], +) + cc_library( name = "python_api_dispatcher", srcs = ["framework/python_api_dispatcher.cc"], @@ -6109,6 +6191,7 @@ pywrap_tensorflow_macro( ":pybind11_status", ":pybind11_proto", ":python_api_dispatcher", + ":python_api_info", ":python_op_gen", ":python_tensor_converter", ":safe_pyobject_ptr", @@ -6181,6 +6264,7 @@ filegroup( ":py_exception_registry", # py_exception_registry ":py_func_lib", # py_func ":python_api_dispatcher", # python_api_dispatcher + ":python_api_info", # python_api_info ":python_tensor_converter", # python_tensor_converter ":python_op_gen", # python_op_gen ":safe_ptr", # checkpoint_reader diff --git a/tensorflow/python/framework/python_api_info.cc b/tensorflow/python/framework/python_api_info.cc new file mode 100644 index 00000000000..7c93afe0757 --- /dev/null +++ b/tensorflow/python/framework/python_api_info.cc @@ -0,0 +1,508 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/python/framework/python_api_info.h" + +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/python/eager/pywrap_tensor.h" +#include "tensorflow/python/eager/pywrap_tfe.h" +#include "tensorflow/python/framework/op_def_util.h" +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" +#include "tensorflow/python/util/util.h" + +namespace tensorflow { + +#if PY_MAJOR_VERSION < 3 +// Python 2.x: +#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x)) +#define PY_INT_AS_LONG(x) (PyInt_AsLong(x)) +#define PY_STRING_FROMSTRING(x) (PyString_FromString(x)) +#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x)) +#define PY_STRING_AS_CSTR(x) (PyString_AsString(x)) +#else +// Python 3.x: +#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x)) +#define PY_INT_AS_LONG(x) (PyLong_AsLong(x)) +#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x)) +#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x)) +#define PY_STRING_AS_CSTR(x) (PyUnicode_AsUTF8AndSize((x), nullptr)) +#endif + +namespace { + +// Converts the given object to an interned Python string, and returns its +// data pointer. (This means we don't need to worry about ownership for +// this string.) +const char* InternPyString(const std::string& s) { + Safe_PyObjectPtr interned(PY_STRING_INTERN_FROM_STRING(s.c_str())); + return PY_STRING_AS_CSTR(interned.get()); +} + +template +void RemoveIf(UnaryPredicate p, std::vector* vec) { + vec->erase(std::remove_if(vec->begin(), vec->end(), p), vec->end()); +} + +struct DataTypeFormatter { + void operator()(std::string* out, DataType dtype) const { + out->append(DataType_Name(dtype)); + } +}; + +// Populates `param_names` and `defaults_tuple` based on the given OpDef. +void GetOpDefNamesAndDefaults(const tensorflow::OpDef& op_def, + std::vector& param_names, + Safe_PyObjectPtr& defaults_tuple) { + param_names.reserve(op_def.input_arg_size() + op_def.attr_size()); + std::set inferred_attrs; + + // Input parameters come first, in the order they occur in the OpDef. + for (const auto& input : op_def.input_arg()) { + param_names.push_back(input.name()); + if (!input.type_attr().empty()) { + inferred_attrs.insert(input.type_attr()); + } + if (!input.type_list_attr().empty()) { + inferred_attrs.insert(input.type_list_attr()); + } + if (!input.number_attr().empty()) { + inferred_attrs.insert(input.number_attr()); + } + } + + // Next come attribute params without defaults, followed by attributes with + // defaults (but inferred attributes are not included). + std::vector param_names_with_default; + std::vector defaults; + for (const auto& attr : op_def.attr()) { + if (inferred_attrs.count(attr.name()) == 0) { + if (attr.has_default_value()) { + param_names_with_default.push_back(attr.name()); + defaults.push_back(AttrValueToPyObject(attr.default_value())); + } else { + param_names.push_back(attr.name()); + } + } + } + param_names.insert(param_names.end(), param_names_with_default.begin(), + param_names_with_default.end()); + + // Finally, the 'name' parameter comes at the end, and its default value + // is the operation's name. + param_names.push_back("name"); + defaults.emplace_back(PY_STRING_FROMSTRING(op_def.name().c_str())); + + defaults_tuple.reset(PyTuple_New(defaults.size())); + for (int i = 0; i < defaults.size(); ++i) { + PyTuple_SET_ITEM(defaults_tuple.get(), i, defaults[i].release()); + } +} + +} // namespace + +PythonAPIInfo::PythonAPIInfo(const std::string& api_name) + : api_name_(InternPyString(api_name)) {} + +Status PythonAPIInfo::Initialize(const OpDef& op_def, + const std::vector param_names, + PyObject* defaults_tuple) { + // Intern the parameter names. + param_names_.reserve(param_names.size()); + for (const auto& param_name : param_names) { + param_names_.push_back(InternPyString(param_name)); + } + + Py_INCREF(defaults_tuple); + defaults_tuple_.reset(defaults_tuple); + + // Build an index to look up parameter index by name. (Does not include + // inferred attributes.) + std::map param_name_to_index; + for (int i = 0; i < param_names_.size(); ++i) { + param_name_to_index[param_names_[i]] = i; + } + + // Initialize each attribute & input parameter. + attributes_.reserve(op_def.attr_size()); + for (const auto& attr_def : op_def.attr()) { + TF_RETURN_IF_ERROR(InitializeAttribute(attr_def, param_name_to_index)); + } + + inputs_.reserve(op_def.input_arg_size()); + for (const auto& arg_def : op_def.input_arg()) { + TF_RETURN_IF_ERROR(InitializeInput(arg_def, param_name_to_index)); + } + + TF_RETURN_IF_ERROR(CheckParamNames()); + + // Filter out any unused entries from inputs_with_*_attrs_. + RemoveIf( + [](const InputsWithTypeAttr& input) { + return input.tensor_params.empty() && input.tensor_list_params.empty(); + }, + &inputs_with_type_attrs_); + RemoveIf( + [](const InputsWithTypeListAttr& input) { + return input.tensor_list_params.empty(); + }, + &inputs_with_type_list_attrs_); + RemoveIf( + [](const InputsWithNumberAttr& input) { + return input.tensor_list_params.empty(); + }, + &inputs_with_number_attrs_); + + return Status::OK(); +} + +Status PythonAPIInfo::CheckParamNames() const { + std::vector param_found(param_names_.size()); + for (const auto& attr : attributes_) { + if (attr.index != -1) { + param_found[attr.index] = true; + } + } + for (const auto& input : inputs_) { + param_found[input.index] = true; + } + + for (int i = 0; i < param_names_.size(); ++i) { + if (param_names_[i] == std::string("name")) { + continue; + } + if (!param_found[i]) { + return errors::InvalidArgument( + api_name_, ": missing specification for parameter ", param_names_[i]); + } + } + return Status::OK(); +} + +Status PythonAPIInfo::InitializeFromRegisteredOp(const std::string& op_name) { + const tensorflow::OpDef* op_def = nullptr; + TF_RETURN_IF_ERROR( + tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def)); + std::vector param_names; + Safe_PyObjectPtr defaults_tuple; + GetOpDefNamesAndDefaults(*op_def, param_names, defaults_tuple); + TF_RETURN_IF_ERROR(Initialize(*op_def, param_names, defaults_tuple.get())); + return Status::OK(); +} + +Status PythonAPIInfo::InitializeFromParamSpecs( + const std::map& input_specs, + const std::map& attr_specs, + const std::vector param_names, PyObject* defaults_tuple) { + OpDefBuilder op_def_builder(api_name_); + op_def_builder.AllowAttrTypeAny(); + for (const auto& attr_spec : attr_specs) { + op_def_builder.Attr(absl::StrCat(attr_spec.first, ": ", attr_spec.second)); + } + for (const auto& input_spec : input_specs) { + op_def_builder.Input( + absl::StrCat(input_spec.first, ": ", input_spec.second)); + } + OpRegistrationData op_reg_data; + TF_RETURN_IF_ERROR(op_def_builder.Finalize(&op_reg_data)); + + TF_RETURN_IF_ERROR( + Initialize(op_reg_data.op_def, param_names, defaults_tuple)); + + return Status::OK(); +} + +Status PythonAPIInfo::InitializeAttribute( + const OpDef::AttrDef& attr_def, + const std::map& param_name_to_index) { + if (attr_def.name() == "name") { + return errors::InvalidArgument( + api_name_, ": Reserved parameter `name` was used as an attribute."); + } + const char* name = InternPyString(attr_def.name()); + + const int param_index = + gtl::FindWithDefault(param_name_to_index, attr_def.name(), -1); + const AttributeType dtype = AttributeTypeFromName(attr_def.type()); + const int inferred_index = -1; + attributes_.push_back({param_index, dtype, name, inferred_index}); + Attribute& attr = attributes_.back(); + if (attr.type == AttributeType::UNKNOWN) { + return errors::InvalidArgument(api_name_, ": Bad attribute type for ", + attr_def.name(), ": '", attr_def.type(), + "'"); + } + std::vector* ok_dtypes = nullptr; + + if (attr.type == AttributeType::DTYPE) { + DataType default_dtype = attr_def.has_default_value() + ? attr_def.default_value().type() + : DT_INVALID; + inputs_with_type_attrs_.push_back({&attr, default_dtype}); + ok_dtypes = &inputs_with_type_attrs_.back().ok_dtypes; + + } else if (attr.type == AttributeType::LIST_DTYPE) { + inputs_with_type_list_attrs_.push_back({&attr}); + for (int d : attr_def.default_value().list().type()) { + inputs_with_type_list_attrs_.back().default_dtypes.push_back( + static_cast(d)); + } + ok_dtypes = &inputs_with_type_list_attrs_.back().ok_dtypes; + } + + if (attr_def.has_allowed_values() && ok_dtypes) { + const auto& dtypes = attr_def.allowed_values().list(); + for (int i = 0; i < dtypes.type_size(); ++i) { + ok_dtypes->push_back(dtypes.type(i)); + } + } + + if (attr.type == AttributeType::INT) { + int64 default_len = + attr_def.has_default_value() ? attr_def.default_value().i() : -1; + inputs_with_number_attrs_.push_back({&attr, default_len}); + } + + // If this is an inferred attribute, then record its name and index. + if (attr.index == -1) { + std::vector* inferred_attr_names = + attr.type == AttributeType::DTYPE ? &inferred_type_attrs_ + : attr.type == AttributeType::LIST_DTYPE ? &inferred_type_list_attrs_ + : attr.type == AttributeType::INT ? &inferred_length_attrs_ + : nullptr; + if (inferred_attr_names == nullptr) { + return errors::InvalidArgument( + api_name_, ": Missing specification for parameter ", attr_def.name()); + } else { + attr.inferred_index = inferred_attr_names->size(); + inferred_attr_names->push_back(attr.name); + } + } + + return Status::OK(); +} + +Status PythonAPIInfo::InitializeInput( + const OpDef::ArgDef& arg_def, + const std::map& param_name_to_index) { + if (arg_def.name() == "name") { + return errors::InvalidArgument( + api_name_, ": Reserved parameter `name` was used as a tensor input."); + } + const ParamIndex param_index = + gtl::FindWithDefault(param_name_to_index, arg_def.name(), -1); + if (param_index == -1) { + return errors::InvalidArgument( + api_name_, ": Missing specification for parameter ", arg_def.name()); + } + if (arg_def.is_ref()) { + // TODO(b/164980194): Support reference parameters. + // - Pass as_ref to convert_to_tensor + // - Check that values for ref inputs have ref types. + return errors::InvalidArgument(api_name_, + ": PythonAPIInfo doesn't support reference " + "parameters yet."); + } + bool is_list = + !arg_def.number_attr().empty() || !arg_def.type_list_attr().empty(); + inputs_.push_back({param_index, is_list}); + + if (!arg_def.type_list_attr().empty()) { + // list(input) with dtypes specified by a `list(type)` attribute. + InputsWithTypeListAttr* input = + FindInputsWithTypeListAttr(arg_def.type_list_attr()); + if (!input) { + return errors::InvalidArgument( + api_name_, ": Type attribute ", arg_def.type_list_attr(), + " for parameter ", arg_def.name(), " not found."); + } + input->tensor_list_params.push_back(param_index); + } else if (!arg_def.type_attr().empty()) { + InputsWithTypeAttr* input = FindInputsWithTypeAttr(arg_def.type_attr()); + // input or list(input) with dtype specified by a `type` attribute. + if (!input) { + return errors::InvalidArgument(api_name_, ": Type attribute ", + arg_def.type_attr(), " for parameter ", + arg_def.name(), " not found."); + } + if (arg_def.number_attr().empty()) { + input->tensor_params.push_back(param_index); + } else { + input->tensor_list_params.push_back(param_index); + } + } else { + // input or list(input) with fixed dtype + inputs_with_fixed_dtype_.push_back({arg_def.type(), param_index, is_list}); + } + + if (!arg_def.number_attr().empty()) { + InputsWithNumberAttr* input = + FindInputsWithNumberAttr(arg_def.number_attr()); + if (!input) { + return errors::InvalidArgument(api_name_, ": Length attribute ", + arg_def.number_attr(), " for parameter ", + arg_def.name(), " not found."); + } + input->tensor_list_params.push_back(param_index); + } + + return Status::OK(); +} + +PythonAPIInfo::InputsWithTypeAttr* PythonAPIInfo::FindInputsWithTypeAttr( + const string& name) { + for (auto& input : inputs_with_type_attrs_) { + if (name == input.type_attr->name) { + return &input; + } + } + return nullptr; +} + +PythonAPIInfo::InputsWithTypeListAttr* +PythonAPIInfo::FindInputsWithTypeListAttr(const string& name) { + for (auto& input : inputs_with_type_list_attrs_) { + if (name == input.type_list_attr->name) { + return &input; + } + } + return nullptr; +} + +PythonAPIInfo::InputsWithNumberAttr* PythonAPIInfo::FindInputsWithNumberAttr( + const string& name) { + for (auto& input : inputs_with_number_attrs_) { + if (name == input.number_attr->name) { + return &input; + } + } + return nullptr; +} + +string PythonAPIInfo::DebugInfo() const { + string s = absl::StrCat("DebugInfo for ", api_name_, ":\n"); + absl::StrAppend(&s, " param_names=[", absl::StrJoin(param_names_, ", "), + "]\n"); + Safe_PyObjectPtr defaults_repr(PyObject_Repr(defaults_tuple_.get())); + absl::StrAppend( + &s, " defaults_tuple=", TFE_GetPythonString(defaults_repr.get()), "\n"); + if (!attributes_.empty()) { + absl::StrAppend(&s, " attributes=["); + for (const auto& attrib : attributes_) { + if (attrib.index != -1) { + absl::StrAppend(&s, "\n {index=", attrib.index); + DCHECK_EQ(attrib.inferred_index, -1); + } else { + absl::StrAppend(&s, "\n {inferred_index=", attrib.inferred_index); + } + absl::StrAppend(&s, ", name=", attrib.name, + ", type=", AttributeTypeToName(attrib.type), "},"); + } + absl::StrAppend(&s, "]\n"); + } + if (!inputs_.empty()) { + absl::StrAppend(&s, " inputs=["); + for (const auto& input : inputs_) { + absl::StrAppend(&s, "\n {index=", input.index, + ", name=", param_names_[input.index], + ", is_list=", input.is_list, "},"); + } + absl::StrAppend(&s, "]\n"); + } + if (!inputs_with_fixed_dtype_.empty()) { + absl::StrAppend(&s, " inputs_with_fixed_dtype=["); + for (const auto& input : inputs_with_fixed_dtype_) { + absl::StrAppend(&s, "\n {index=", input.index, + ", dtype=", DataType_Name(input.dtype), + ", is_list=", input.is_list, "},"); + } + absl::StrAppend(&s, "]\n"); + } + if (!inputs_with_type_attrs_.empty()) { + absl::StrAppend(&s, " inputs_with_type_attr=["); + for (const auto& input : inputs_with_type_attrs_) { + absl::StrAppend(&s, "\n {type_attr=", input.type_attr->name); + if (input.default_dtype != DT_INVALID) { + absl::StrAppend(&s, + ", default_dtype=", DataType_Name(input.default_dtype)); + } + if (!input.tensor_params.empty()) { + absl::StrAppend(&s, ", tensor_params=[", + absl::StrJoin(input.tensor_params, ", "), "]"); + } + if (!input.tensor_list_params.empty()) { + absl::StrAppend(&s, ", tensor_list_params=[", + absl::StrJoin(input.tensor_list_params, ", "), "]"); + } + if (!input.ok_dtypes.empty()) { + absl::StrAppend( + &s, ", ok_dtypes=[", + absl::StrJoin(input.ok_dtypes, ", ", DataTypeFormatter()), "]"); + } + absl::StrAppend(&s, "},"); + } + absl::StrAppend(&s, "]\n"); + } + if (!inputs_with_type_list_attrs_.empty()) { + absl::StrAppend(&s, " inputs_with_type_list_attrs=["); + for (const auto& input : inputs_with_type_list_attrs_) { + absl::StrAppend(&s, "\n {type_list_attr=", input.type_list_attr->name); + if (!input.default_dtypes.empty()) { + absl::StrAppend( + &s, ", default_dtypes=[", + absl::StrJoin(input.default_dtypes, ", ", DataTypeFormatter()), + "]"); + } + if (!input.tensor_list_params.empty()) { + absl::StrAppend(&s, ", tensor_list_params=[", + absl::StrJoin(input.tensor_list_params, ", "), "]"); + } + if (!input.ok_dtypes.empty()) { + absl::StrAppend( + &s, ", ok_dtypes=[", + absl::StrJoin(input.ok_dtypes, ", ", DataTypeFormatter()), "]"); + } + absl::StrAppend(&s, "},"); + } + absl::StrAppend(&s, "]\n"); + } + if (!inputs_with_number_attrs_.empty()) { + absl::StrAppend(&s, " inputs_with_number_attrs=["); + for (const auto& input : inputs_with_number_attrs_) { + absl::StrAppend(&s, "\n {number_attr=", input.number_attr->name, + ", default_length=", input.default_length, + ", tensor_list_params=[", + absl::StrJoin(input.tensor_list_params, ", "), "],\n"); + } + absl::StrAppend(&s, "]\n"); + } + if (!inferred_type_attrs_.empty()) { + absl::StrAppend(&s, " inferred_type_attrs=[", + absl::StrJoin(inferred_type_attrs_, ", "), "]\n"); + } + if (!inferred_type_list_attrs_.empty()) { + absl::StrAppend(&s, " inferred_type_list_attrs=[", + absl::StrJoin(inferred_type_list_attrs_, ", "), "]\n"); + } + if (!inferred_length_attrs_.empty()) { + absl::StrAppend(&s, " inferred_length_attrs=[", + absl::StrJoin(inferred_length_attrs_, ", "), "]\n"); + } + return s; +} + +} // namespace tensorflow diff --git a/tensorflow/python/framework/python_api_info.h b/tensorflow/python/framework/python_api_info.h new file mode 100644 index 00000000000..4da710fbbd9 --- /dev/null +++ b/tensorflow/python/framework/python_api_info.h @@ -0,0 +1,298 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_ +#define TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_ + +#include + +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/python/framework/op_def_util.h" +#include "tensorflow/python/framework/python_tensor_converter.h" +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" + +namespace tensorflow { + +// Precomputed information about a TensorFlow Python API. +// +// PythonAPIInfo records information about a single TensorFlow Python API, +// in order to allow calls to the API to be executed more efficiently. This +// information includes: +// +// * The name of the API. (E.g. "tf.math.add") +// +// * The name of the registered op that implements the API, if applicable +// (e.g. "AddV2"). +// +// * Information about the API's parameters. Parameters are divided into two +// "kinds": inputs and attributes. An *input* is a parameter that +// expects a Tensor or list of Tensors, and it is described by an `ArgDef`. +// An *attribute* is a parameter that expects any other value type, and it is +// described by an `AttrDef`. +// +// * Default values for the API's attribute parameters. +// +// * Information about "inferred attributes" -- attributes whose values are +// inferred from `input` parameters. There are two kinds of inferred +// attributes: Tensor dtypes, which are inferred from tensor and list(tensor) +// parameters; and list lengths, which are inferred from list(tensor) +// parameters. +class PythonAPIInfo { + public: + // The index of a parameter in the canonicalized parameter list. The + // canonicalized parameter list includes inputs and attributes (but does + // not include inferred attributes). `-1` is used for inferred attributes. + using ParamIndex = int; + + // Information about a parameter that expects a non-Tensor value. + struct Attribute { + ParamIndex index; // -1 if this is an inferred attribute + AttributeType type; + const char* name; // Interned python string + int inferred_index; // index to store attribute in InferredAttributes + }; + + // Information about a parameter that expects a Tensor or list(Tensor). + // Additional information about tensor parameters is stored in types + // defined below, in order to simplify dtype/length inference: + // * FixedDTypeInput: inputs with fixed dtypes. + // * InputsWithTypeAttr: groups inputs that use a type_attr for dtype. + // * InputsWithTypeListAttr: groups inputs that use a type_list_attr. + // * InputsWithNumberAttr: groups inputs by a number_attr for length. + struct Input { + ParamIndex index; + bool is_list; + }; + + // Information about a Tensor parameter w/ fixed dtype. + struct InputWithFixedDType { + DataType dtype; + ParamIndex index; + bool is_list; + }; + + // Information about Tensor parameters whose DType is specified by a single + // `type_attr` attribute. + struct InputsWithTypeAttr { + Attribute* type_attr; // not owned. + DataType default_dtype; // DT_INVALID if no default. + std::vector tensor_params; // single-tensor inputs. + std::vector tensor_list_params; // list(tensor) inputs. + std::vector ok_dtypes; + }; + + // Information about Tensor parameters whose DType is specified by a single + // `type_list_attr` attribute. + struct InputsWithTypeListAttr { + Attribute* type_list_attr; // not owned. + std::vector default_dtypes; // empty if no default. + std::vector tensor_list_params; // list(tensor) inputs. + std::vector ok_dtypes; + }; + + // Information about Tensor-list parameters whose length is specified by a + // single `int` attribute. + struct InputsWithNumberAttr { + Attribute* number_attr; // not owned. + int64 default_length; // -1 for no default. + std::vector tensor_list_params; // list(tensor) inputs. + }; + + // Structure used to return inferred attribute values. + // * types[i] is the inferred value for inferred_type_attrs()[i] + // * type_lists[i] is the inferred value for inferred_type_list_attrs()[i] + // * lengths[i] is the inferred value for inferred_length_attrs()[i] + struct InferredAttributes { + std::vector types; + std::vector> type_lists; + std::vector lengths; + }; + + // Constructs a new PythonAPIInfo. + // + // Note: One of the `Initialize()` functions must be called before the + // `PythonAPIInfo` is used. + // + // Args: + // api_name: The fully-qualified name of the python API (e.g., tf.math.sum). + explicit PythonAPIInfo(const std::string& api_name); + + // Initializes this PythonAPIInfo. + // + // Args: + // op_def: Contains information about the parameters. + // param_names: The argument names for the python API, in canonical order. + // defaults_tuple: Tuple containing default values for the parameters, + // right-aligned with `param_names` -- i.e., `defaults[-i]` is the default + // for `param_names[-i]`. + Status Initialize(const OpDef& op_def, const std::vector param_names, + PyObject* defaults_tuple); + + // Initialize this PythonAPIInfo based on the registered OpDef for the given + // operation. + // + // Args: + // op_name: The registered name of the operation (e.g. "AddV2"). + Status InitializeFromRegisteredOp(const std::string& op_name); + + // Initializes this PythonAPIInfo based on a set of parameter specifications. + // + // Args: + // input_specs: Mapping from parameter name to specification string for + // each input (parameter that expects a tensor value). + // attr_specs: Mapping from parameter name to specification string for + // each attribute (parameter that expects a non-tensor value). + // param_names: The argument names for the python API, in canonical order. + // defaults_tuple: Tuple containing default values for the parameters, + // right-aligned with `param_names` -- i.e., `defaults[-i]` is the default + // for `param_names[-i]`. + // + // Note: the `name` parameter should not be included in `input_specs` or + // `attr_specs`. + Status InitializeFromParamSpecs( + const std::map& input_specs, + const std::map& attr_specs, + const std::vector param_names, PyObject* defaults_tuple); + + // The name of the API that is described by this PythonAPIInfo. + const char* api_name() const { return api_name_; } + + // The ordered names of the canononical parameters that this API expects. + const std::vector& param_names() const { return param_names_; } + + // A Python tuple containing the default values for parameters. This is + // right-aligned with `param_name` -- i.e., `defaults[-i]` is the default + // for `param_names[-i]`. + const PyObject* defaults_tuple() const { return defaults_tuple_.get(); } + + // Information about the attribute (non-tensor) parameters for this API. + const std::vector& attributes() const { return attributes_; } + + // Information about the input (tensor) parameters for this API. + const std::vector& inputs() const { return inputs_; } + const std::vector& inputs_with_fixed_dtype() const { + return inputs_with_fixed_dtype_; + } + const std::vector& inputs_with_type_attrs() const { + return inputs_with_type_attrs_; + } + const std::vector& inputs_with_type_list_attrs() + const { + return inputs_with_type_list_attrs_; + } + const std::vector& inputs_with_number_attrs() const { + return inputs_with_number_attrs_; + } + + // Names of inferred attributes. + const std::vector& inferred_type_attrs() const { + return inferred_type_attrs_; + } + const std::vector& inferred_type_list_attrs() const { + return inferred_type_list_attrs_; + } + const std::vector& inferred_length_attrs() const { + return inferred_length_attrs_; + } + + // Returns a string summarizing the internal state of this type converter. + string DebugInfo() const; + + private: + // Adds an entry to the attributes_ vector based on the given `AttrDef`. + // + // If `attr_def` describes a type attribute, then adds a value to + // inputs_with_type_attrs_ or inputs_with_type_list_attrs_ (to record any + // tensor inputs that use this dtype). + // + // If `attr_def` describes an int attribute, then adds a value to + // inputs_with_number_attrs_ (to record any tensor inputs that use this + // value as a list length). + Status InitializeAttribute( + const OpDef::AttrDef& attr_def, + const std::map& param_name_to_index); + + // Adds an entry to the inputs_ vector based on the given `ArgDef`. + // + // If `arg_def` has a fixed dtype, then adds a value to `fixed_dtype_inputs`. + // + // If `arg_def`'s dtype is described by a `type` attr, then updates the + // appropriate value in `inputs_with_type_attrs_` with information about the + // `arg_def`. + // + // If `arg_def`'s dtype is described by a `list(type)` attr, then updates the + // appropriate value in `inputs_with_type_list_attrs_` with information about + // the `arg_def`. + Status InitializeInput(const OpDef::ArgDef& arg_def, + const std::map& param_name_to_index); + + // Checks that the OpDef used to initialize this PythonAPIInfo + // had an AttrDef or ArgDef specification for each parameter. + Status CheckParamNames() const; + + // Searches inputs_with_type_attrs_ for an input with the given name. + InputsWithTypeAttr* FindInputsWithTypeAttr(const string& name); + + // Searches inputs_with_type_list_attrs_ for an input with the given name. + InputsWithTypeListAttr* FindInputsWithTypeListAttr(const string& name); + + // Searches inputs_with_type_list_attrs_ for an input with the given name. + InputsWithNumberAttr* FindInputsWithNumberAttr(const string& name); + + ABSL_MUST_USE_RESULT + bool InferLengthAttributes(const absl::Span params, + std::vector& inferred_length_attrs) const; + + // ========================================================================== + // Member Variables + // ========================================================================== + + // The name of the API that is described by this PythonAPIInfo. + // (Interned python string). + const char* api_name_; + + // The names of the parameters that this API expects. + // (Interned python strings.) + std::vector param_names_; + + // Tuple containing default values for the parameters, right-aligned with + // `param_names` -- i.e., `defaults[-i]` is the default for `param_names[-i]`. + Safe_PyObjectPtr defaults_tuple_; + + // Information about the non-tensor-valued parameters that this API expects. + std::vector attributes_; + + // Information about the tensor-valued parameters that this API expects. + std::vector inputs_; + std::vector inputs_with_fixed_dtype_; + std::vector inputs_with_type_attrs_; + std::vector inputs_with_type_list_attrs_; + std::vector inputs_with_number_attrs_; + + // Names of inferred attributes. (Interned python strings.) + std::vector inferred_type_attrs_; + std::vector inferred_type_list_attrs_; + std::vector inferred_length_attrs_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_ diff --git a/tensorflow/python/framework/python_api_info_test.py b/tensorflow/python/framework/python_api_info_test.py new file mode 100644 index 00000000000..f8c9df1beaf --- /dev/null +++ b/tensorflow/python/framework/python_api_info_test.py @@ -0,0 +1,254 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.python_api_info.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python import _pywrap_python_api_info +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + +# pylint: disable=g-long-lambda + + +# Helper function to make expected output in examples more compact: +def Const(x): + return constant_op.constant(x) + + +@test_util.run_all_in_graph_and_eager_modes +class PythonAPIInfoTest(test_util.TensorFlowTestCase, parameterized.TestCase): + + def setUp(self): + context.ensure_initialized() + super(PythonAPIInfoTest, self).setUp() + + def makeConverterForGenOp(self, op_name): + """Returns a PythonAPIInfo for the given gen_op.""" + api_info = _pywrap_python_api_info.PythonAPIInfo(op_name) + api_info.InitializeFromRegisteredOp(op_name) + return api_info + + def makeConverterFromParamSpecs(self, + api_name, + param_names, + input_specs, + attr_specs, + defaults=()): + """Returns a PythonAPIInfo built from the given specs.""" + api_info = _pywrap_python_api_info.PythonAPIInfo(api_name) + api_info.InitializeFromParamSpecs(input_specs, attr_specs, param_names, + defaults) + return api_info + + # This test initializes a PythonAPIInfo from a registered + # op, and then uses DebugInfo() to check that the internal state is + # correct. + @parameterized.named_parameters([ + # An op whose inputs have fixed dtypes. + ("RegexFullMatch", "RegexFullMatch", "DebugInfo for RegexFullMatch:\n" + " param_names=[input, pattern, name]\n" + " defaults_tuple=('RegexFullMatch',)\n" + " inputs=[\n" + " {index=0, name=input, is_list=0},\n" + " {index=1, name=pattern, is_list=0},]\n" + " inputs_with_fixed_dtype=[\n" + " {index=0, dtype=DT_STRING, is_list=0},\n" + " {index=1, dtype=DT_STRING, is_list=0},]\n"), + # An op whose input has a variable dtype. + ("Abs", "Abs", "DebugInfo for Abs:\n" + " param_names=[x, name]\n" + " defaults_tuple=('Abs',)\n" + " attributes=[\n" + " {inferred_index=0, name=T, type=type},]\n" + " inputs=[\n" + " {index=0, name=x, is_list=0},]\n" + " inputs_with_type_attr=[\n" + " {type_attr=T, tensor_params=[0], ok_dtypes=[DT_BFLOAT16, DT_HALF, " + "DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64]},]\n" + " inferred_type_attrs=[T]\n"), + # An op with two inputs that have the same (variable) dtype. + ("AddV2", "AddV2", "DebugInfo for AddV2:\n" + " param_names=[x, y, name]\n" + " defaults_tuple=('AddV2',)\n" + " attributes=[\n" + " {inferred_index=0, name=T, type=type},]\n" + " inputs=[\n" + " {index=0, name=x, is_list=0},\n" + " {index=1, name=y, is_list=0},]\n" + " inputs_with_type_attr=[\n" + " {type_attr=T, tensor_params=[0, 1], ok_dtypes=[DT_BFLOAT16, " + "DT_HALF, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_INT8, DT_INT16, DT_UINT32, " + "DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128]},]\n" + " inferred_type_attrs=[T]\n"), + # An op with an int attribute. + ("GatherV2", "GatherV2", "DebugInfo for GatherV2:\n" + " param_names=[params, indices, axis, batch_dims, name]\n" + " defaults_tuple=(0, 'GatherV2')\n" + " attributes=[\n" + " {index=3, name=batch_dims, type=int},\n" + " {inferred_index=0, name=Tparams, type=type},\n" + " {inferred_index=1, name=Tindices, type=type},\n" + " {inferred_index=2, name=Taxis, type=type},]\n" + " inputs=[\n" + " {index=0, name=params, is_list=0},\n" + " {index=1, name=indices, is_list=0},\n" + " {index=2, name=axis, is_list=0},]\n" + " inputs_with_type_attr=[\n" + " {type_attr=Tparams, tensor_params=[0]},\n" + " {type_attr=Tindices, tensor_params=[1], " + "ok_dtypes=[DT_INT32, DT_INT64]},\n" + " {type_attr=Taxis, tensor_params=[2], " + "ok_dtypes=[DT_INT32, DT_INT64]},]\n" + " inferred_type_attrs=[Tparams, Tindices, Taxis]\n"), + # An op with default attrib values. + ("ReduceJoin", "ReduceJoin", "DebugInfo for ReduceJoin:\n" + " param_names=[inputs, reduction_indices, keep_dims, separator, name]\n" + " defaults_tuple=(False, '', 'ReduceJoin')\n" + " attributes=[\n" + " {index=2, name=keep_dims, type=bool},\n" + " {index=3, name=separator, type=string},]\n" + " inputs=[\n" + " {index=0, name=inputs, is_list=0},\n" + " {index=1, name=reduction_indices, is_list=0},]\n" + " inputs_with_fixed_dtype=[\n" + " {index=0, dtype=DT_STRING, is_list=0},\n" + " {index=1, dtype=DT_INT32, is_list=0},]\n"), + # An op with a variable-dtype list input, and an int attribute. + ("ParseExampleV2", "ParseExampleV2", "DebugInfo for ParseExampleV2:\n" + " param_names=[serialized, names, sparse_keys, dense_keys, " + "ragged_keys, dense_defaults, num_sparse, sparse_types, " + "ragged_value_types, ragged_split_types, dense_shapes, name]\n" + " defaults_tuple=('ParseExampleV2',)\n" + " attributes=[\n" + " {inferred_index=0, name=Tdense, type=list(type)},\n" + " {index=6, name=num_sparse, type=int},\n" + " {index=7, name=sparse_types, type=list(type)},\n" + " {index=8, name=ragged_value_types, type=list(type)},\n" + " {index=9, name=ragged_split_types, type=list(type)},\n" + " {index=10, name=dense_shapes, type=list(shape)},]\n" + " inputs=[\n" + " {index=0, name=serialized, is_list=0},\n" + " {index=1, name=names, is_list=0},\n" + " {index=2, name=sparse_keys, is_list=0},\n" + " {index=3, name=dense_keys, is_list=0},\n" + " {index=4, name=ragged_keys, is_list=0},\n" + " {index=5, name=dense_defaults, is_list=1},]\n" + " inputs_with_fixed_dtype=[\n" + " {index=0, dtype=DT_STRING, is_list=0},\n" + " {index=1, dtype=DT_STRING, is_list=0},\n" + " {index=2, dtype=DT_STRING, is_list=0},\n" + " {index=3, dtype=DT_STRING, is_list=0},\n" + " {index=4, dtype=DT_STRING, is_list=0},]\n" + " inputs_with_type_list_attrs=[\n" + " {type_list_attr=Tdense, tensor_list_params=[5], " + "ok_dtypes=[DT_FLOAT, DT_INT64, DT_STRING]},]\n" + " inferred_type_list_attrs=[Tdense]\n"), + # An op with a default dtype + ("BroadcastArgs", "BroadcastArgs", "DebugInfo for BroadcastArgs:\n" + " param_names=[s0, s1, name]\n" + " defaults_tuple=('BroadcastArgs',)\n" + " attributes=[\n" + " {inferred_index=0, name=T, type=type},]\n" + " inputs=[\n" + " {index=0, name=s0, is_list=0},\n" + " {index=1, name=s1, is_list=0},]\n" + " inputs_with_type_attr=[\n" + " {type_attr=T, default_dtype=DT_INT32, tensor_params=[0, 1], " + "ok_dtypes=[DT_INT32, DT_INT64]},]\n" + " inferred_type_attrs=[T]\n"), + ]) + def testInitializeFromRegisteredOp(self, op_name, debug_info): + api_info = self.makeConverterForGenOp(op_name) + self.assertEqual(api_info.DebugInfo().strip(), debug_info.strip()) + + # This test initializes a PythonAPIInfo from parameter specs, + # and then uses DebugInfo() to check that the internal state is correct. + @parameterized.named_parameters([ + ("NoParams", "NoParams", [], {}, {}, "DebugInfo for NoParams:\n" + " param_names=[]\n" + " defaults_tuple=()\n"), + ("OnlyNameParam", "OnlyNameParam", ["name"], {}, {}, + "DebugInfo for OnlyNameParam:\n" + " param_names=[name]\n" + " defaults_tuple=()\n"), + ("SomeBinaryOp", "SomeBinaryOp", ["x", "y"], dict(x="T", y="T"), + dict(T="type"), "DebugInfo for SomeBinaryOp:\n" + " param_names=[x, y]\n" + " defaults_tuple=()\n" + " attributes=[\n" + " {inferred_index=0, name=T, type=type},]\n" + " inputs=[\n" + " {index=0, name=x, is_list=0},\n" + " {index=1, name=y, is_list=0},]\n" + " inputs_with_type_attr=[\n" + " {type_attr=T, tensor_params=[0, 1]},]\n" + " inferred_type_attrs=[T]\n"), + ("AllAttributeTypes", "AllAttributeTypes", [ + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", + "o", "p" + ], {}, + dict( + a="any", + b="float", + c="int", + d="string", + e="bool", + f="type", + g="shape", + h="tensor", + i="list(any)", + j="list(float)", + k="list(int)", + l="list(string)", + m="list(bool)", + n="list(type)", + o="list(shape)", + p="list(tensor)"), "DebugInfo for AllAttributeTypes:\n" + " param_names=[a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p]\n" + " defaults_tuple=()\n" + " attributes=[\n" + " {index=0, name=a, type=any},\n" + " {index=1, name=b, type=float},\n" + " {index=2, name=c, type=int},\n" + " {index=3, name=d, type=string},\n" + " {index=4, name=e, type=bool},\n" + " {index=5, name=f, type=type},\n" + " {index=6, name=g, type=shape},\n" + " {index=7, name=h, type=tensor},\n" + " {index=8, name=i, type=list(any)},\n" + " {index=9, name=j, type=list(float)},\n" + " {index=10, name=k, type=list(int)},\n" + " {index=11, name=l, type=list(string)},\n" + " {index=12, name=m, type=list(bool)},\n" + " {index=13, name=n, type=list(type)},\n" + " {index=14, name=o, type=list(shape)},\n" + " {index=15, name=p, type=list(tensor)},]\n"), + ]) + def testInitializeFromParamSpecs(self, api_name, param_names, input_specs, + attr_specs, debug_info): + api_info = self.makeConverterFromParamSpecs(api_name, param_names, + input_specs, attr_specs) + self.assertEqual(api_info.DebugInfo().strip(), debug_info.strip()) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/python_api_info_wrapper.cc b/tensorflow/python/framework/python_api_info_wrapper.cc new file mode 100644 index 00000000000..483e475401f --- /dev/null +++ b/tensorflow/python/framework/python_api_info_wrapper.cc @@ -0,0 +1,75 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Note: This library is only used by python_api_info_test. It +// is not meant to be used in other circumstances. + +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" +#include "tensorflow/python/framework/python_api_info.h" + +namespace py = pybind11; + +namespace tensorflow { +namespace { + +void InitializeFromRegisteredOp(PythonAPIInfo* api_info, + const std::string& op_name) { + auto result = api_info->InitializeFromRegisteredOp(op_name); + if (!result.ok()) { + PyErr_SetString(PyExc_ValueError, result.ToString().c_str()); + throw py::error_already_set(); + } +} + +void InitializeFromParamSpecs( + PythonAPIInfo* api_info, + const std::map& input_specs, + const std::map& attr_specs, + const std::vector& param_names, py::handle defaults_tuple) { + auto result = api_info->InitializeFromParamSpecs( + input_specs, attr_specs, param_names, defaults_tuple.ptr()); + if (!result.ok()) { + PyErr_SetString(PyExc_ValueError, result.ToString().c_str()); + throw py::error_already_set(); + } +} + +std::string DebugInfo(PythonAPIInfo* api_info) { return api_info->DebugInfo(); } + +} // namespace +} // namespace tensorflow + +using PythonAPIInfo = tensorflow::PythonAPIInfo; +using InferredAttributes = tensorflow::PythonAPIInfo::InferredAttributes; + +PYBIND11_MODULE(_pywrap_python_api_info, m) { + py::class_(m, "PythonAPIInfo") + .def(py::init()) + .def("InitializeFromRegisteredOp", + &tensorflow::InitializeFromRegisteredOp) + .def("InitializeFromParamSpecs", &tensorflow::InitializeFromParamSpecs) + .def("DebugInfo", &tensorflow::DebugInfo) + .def("InferredTypeAttrs", + [](PythonAPIInfo* self) { return self->inferred_type_attrs(); }) + .def("InferredTypeListAttrs", + [](PythonAPIInfo* self) { return self->inferred_type_list_attrs(); }) + .def("InferredLengthAttrs", + [](PythonAPIInfo* self) { return self->inferred_length_attrs(); }); + py::class_(m, "InferredAttributes") + .def_readonly("types", &InferredAttributes::types) + .def_readonly("type_lists", &InferredAttributes::type_lists) + .def_readonly("lengths", &InferredAttributes::lengths); +} diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index ebe1427ba71..a6082788413 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -404,3 +404,6 @@ tensorflow::PythonAPIDispatcher [python_tensor_converter] # python_tensor_converter tensorflow::PythonTensorConverter + +[python_api_info] # python_api_info +tensorflow::PythonAPIInfo