Merge pull request #40557 from rahul-kamat:gen-ops-with-type-annotations
PiperOrigin-RevId: 318862240 Change-Id: Idbc6c43c0ab9418930298d4dc71c31a8abf5df65
This commit is contained in:
commit
a8eb45b4af
@ -45,6 +45,33 @@ const int kRightMargin = 78;
|
||||
|
||||
constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
|
||||
|
||||
// Maps C++ dtype enum values to Python DType classes
|
||||
const std::unordered_map<string, string> dtype_type{
|
||||
{"_dtypes.float16", "_dtypes.Float16"},
|
||||
{"_dtypes.half", "_dtypes.Half"},
|
||||
{"_dtypes.float32", "_dtypes.Float32"},
|
||||
{"_dtypes.float64", "_dtypes.Float64"},
|
||||
{"_dtypes.bfloat16", "_dtypes.BFloat16"},
|
||||
{"_dtypes.complex64", "_dtypes.Complex64"},
|
||||
{"_dtypes.complex128", "_dtypes.Complex128"},
|
||||
{"_dtypes.int8", "_dtypes.Int8"},
|
||||
{"_dtypes.uint8", "_dtypes.UInt8"},
|
||||
{"_dtypes.uint16", "_dtypes.UInt16"},
|
||||
{"_dtypes.uint32", "_dtypes.UInt32"},
|
||||
{"_dtypes.uint64", "_dtypes.UInt64"},
|
||||
{"_dtypes.int16", "_dtypes.Int16"},
|
||||
{"_dtypes.int32", "_dtypes.Int32"},
|
||||
{"_dtypes.int64", "_dtypes.Int64"},
|
||||
{"_dtypes.bool", "_dtypes.Bool"},
|
||||
{"_dtypes.string", "_dtypes.String"},
|
||||
{"_dtypes.qint8", "_dtypes.QInt8"},
|
||||
{"_dtypes.quint8", "_dtypes.QUInt8"},
|
||||
{"_dtypes.qint16", "_dtypes.QInt16"},
|
||||
{"_dtypes.quint16", "_dtypes.QUInt16"},
|
||||
{"_dtypes.qint32", "_dtypes.QInt32"},
|
||||
{"_dtypes.resource", "_dtypes.Resource"},
|
||||
{"_dtypes.variant", "_dtypes.Variant"}};
|
||||
|
||||
string AttrVarName(const string& attr_name,
|
||||
std::unordered_map<string, string>* attr_expressions) {
|
||||
const string var = strings::StrCat("_attr_", attr_name);
|
||||
@ -106,8 +133,9 @@ string TensorPBString(const TensorProto& pb) {
|
||||
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
|
||||
public:
|
||||
GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
|
||||
const string& function_name)
|
||||
: python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
|
||||
const string& function_name, bool add_type_annotations)
|
||||
: python_op_gen_internal::GenPythonOp(op_def, api_def, function_name,
|
||||
add_type_annotations) {
|
||||
op_name_ = function_name_;
|
||||
absl::ConsumePrefix(&op_name_, "_");
|
||||
}
|
||||
@ -130,13 +158,14 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
|
||||
const std::vector<string>& output_sizes,
|
||||
bool execute_record_gradient);
|
||||
|
||||
bool AddEagerFastPathAndGraphCode(const string& parameters,
|
||||
const std::vector<string>& output_sizes,
|
||||
const string& eager_not_allowed_error);
|
||||
bool AddEagerFallbackCode(const string& parameters,
|
||||
const std::vector<string>& output_sizes,
|
||||
const string& num_outputs_expr,
|
||||
const string& eager_not_allowed_error);
|
||||
bool AddEagerFastPathAndGraphCode(
|
||||
const string& parameters, const std::vector<string>& output_sizes,
|
||||
const string& eager_not_allowed_error,
|
||||
const std::unordered_map<string, string>& type_annotations);
|
||||
bool AddEagerFallbackCode(
|
||||
const string& parameters, const std::vector<string>& output_sizes,
|
||||
const string& num_outputs_expr, const string& eager_not_allowed_error,
|
||||
const std::unordered_map<string, string>& type_annotations);
|
||||
void AddEagerFastPathExecute();
|
||||
|
||||
void AddEagerInferredAttrs(const string& indentation);
|
||||
@ -148,6 +177,14 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
|
||||
|
||||
void AddRawOpExport(const string& parameters);
|
||||
|
||||
std::unordered_map<string, string> GetTypeAnnotations();
|
||||
|
||||
void GenerateTypeVars(
|
||||
const std::unordered_map<string, string>& type_annotations);
|
||||
|
||||
void AddReturnTypeAnnotation(
|
||||
const std::unordered_map<string, string>& type_annotations);
|
||||
|
||||
void AddAttrForArg(const string& attr, int arg_index) {
|
||||
gtl::InsertIfNotPresent(&inferred_attrs_, attr,
|
||||
op_def_.input_arg(arg_index).name());
|
||||
@ -179,8 +216,10 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
|
||||
};
|
||||
|
||||
string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
|
||||
const string& function_name) {
|
||||
return GenEagerPythonOp(op_def, api_def, function_name).Code();
|
||||
const string& function_name,
|
||||
bool add_type_annotations) {
|
||||
return GenEagerPythonOp(op_def, api_def, function_name, add_type_annotations)
|
||||
.Code();
|
||||
}
|
||||
|
||||
string GenEagerPythonOp::FlattenInputs(
|
||||
@ -311,19 +350,45 @@ string GenEagerPythonOp::Code() {
|
||||
param_names_.push_back(param_and_default.first);
|
||||
}
|
||||
|
||||
std::unordered_map<string, string> type_annotations;
|
||||
// Only populate map for whitelisted ops
|
||||
if (add_type_annotations_) {
|
||||
type_annotations = GetTypeAnnotations();
|
||||
}
|
||||
|
||||
string parameters;
|
||||
// Param can be an input or an attr
|
||||
for (const auto& param : params_no_default_) {
|
||||
if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
|
||||
strings::StrAppend(¶meters, param.GetRenameTo());
|
||||
|
||||
if (type_annotations.find(param.GetName()) != type_annotations.end()) {
|
||||
strings::StrAppend(¶meters, ": ",
|
||||
type_annotations.at(param.GetName()));
|
||||
}
|
||||
}
|
||||
|
||||
string parameters_with_defaults = parameters;
|
||||
for (const auto& param_and_default : params_with_default_) {
|
||||
if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
|
||||
if (!parameters_with_defaults.empty())
|
||||
strings::StrAppend(¶meters_with_defaults, ", ");
|
||||
|
||||
strings::StrAppend(¶meters, param_and_default.first.GetRenameTo());
|
||||
strings::StrAppend(¶meters_with_defaults,
|
||||
param_and_default.first.GetRenameTo(), "=",
|
||||
param_and_default.first.GetRenameTo());
|
||||
if (type_annotations.find(param_and_default.first.GetName()) !=
|
||||
type_annotations.end()) {
|
||||
const string param_type =
|
||||
type_annotations.at(param_and_default.first.GetName());
|
||||
// Append to parameters and parameters_with_defaults because multiple
|
||||
// functions are generated by AddEagerFastPathAndGraphCode() and
|
||||
// AddEagerFallbackCode()
|
||||
strings::StrAppend(¶meters, ": ", param_type);
|
||||
strings::StrAppend(¶meters_with_defaults, ":", param_type);
|
||||
}
|
||||
|
||||
strings::StrAppend(¶meters_with_defaults, "=",
|
||||
param_and_default.second);
|
||||
}
|
||||
|
||||
@ -356,18 +421,108 @@ string GenEagerPythonOp::Code() {
|
||||
string eager_not_allowed_error = GetEagerNotAllowedError();
|
||||
|
||||
if (!AddEagerFastPathAndGraphCode(parameters_with_defaults, output_sizes,
|
||||
eager_not_allowed_error)) {
|
||||
eager_not_allowed_error,
|
||||
type_annotations)) {
|
||||
return result_;
|
||||
}
|
||||
|
||||
if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
|
||||
eager_not_allowed_error)) {
|
||||
eager_not_allowed_error, type_annotations)) {
|
||||
return result_;
|
||||
}
|
||||
|
||||
return prelude_ + result_;
|
||||
}
|
||||
|
||||
std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
|
||||
std::unordered_map<string, string> type_annotations;
|
||||
// Map attrs to TypeVars
|
||||
for (const auto& attr : op_def_.attr()) {
|
||||
if (attr.type() == "type") {
|
||||
const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name();
|
||||
type_annotations[attr.name()] = type_var_name;
|
||||
} else if (attr.type() == "bool" || attr.type() == "float" ||
|
||||
attr.type() == "int" || attr.type() == "bytes") {
|
||||
type_annotations[attr.name()] = attr.type();
|
||||
} else if (attr.type() == "string") {
|
||||
type_annotations[attr.name()] = "str";
|
||||
}
|
||||
}
|
||||
|
||||
// Map input Tensors to their types
|
||||
for (const auto& arg : op_def_.input_arg()) {
|
||||
// TODO(rahulkamat): Add type annotations to args that accept a sequence of
|
||||
// Tensors
|
||||
if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
|
||||
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
|
||||
}
|
||||
|
||||
// TODO(rahulkamat): Add type annotations to handle return types of a sequence
|
||||
// of Tensors. Map output Tensor to its type
|
||||
if (op_def_.output_arg_size() == 1) {
|
||||
const auto& arg = op_def_.output_arg(0);
|
||||
if (arg.number_attr().empty() && arg.type_list_attr().empty())
|
||||
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
|
||||
}
|
||||
|
||||
return type_annotations;
|
||||
}
|
||||
|
||||
// Generate TypeVars using attrs
|
||||
void GenEagerPythonOp::GenerateTypeVars(
|
||||
const std::unordered_map<string, string>& type_annotations) {
|
||||
bool added_typevar = false;
|
||||
for (const auto& attr : op_def_.attr()) {
|
||||
if (attr.type() == "type") {
|
||||
std::vector<string> allowed_types;
|
||||
for (int t : attr.allowed_values().list().type()) {
|
||||
DataType dtype = static_cast<DataType>(t);
|
||||
const string py_dtype =
|
||||
python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
|
||||
allowed_types.emplace_back(dtype_type.at(py_dtype));
|
||||
}
|
||||
|
||||
// When a Tensor does not have any dtypes specified, all dtypes are
|
||||
// allowed
|
||||
if (allowed_types.empty()) {
|
||||
for (std::pair<string, string> map_dtype : dtype_type) {
|
||||
allowed_types.emplace_back(map_dtype.second);
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(allowed_types.begin(), allowed_types.end());
|
||||
|
||||
string typevar_dtypes;
|
||||
for (std::vector<string>::iterator it = allowed_types.begin();
|
||||
it != allowed_types.end(); ++it) {
|
||||
if (!typevar_dtypes.empty()) strings::StrAppend(&typevar_dtypes, ", ");
|
||||
strings::StrAppend(&typevar_dtypes, *it);
|
||||
}
|
||||
|
||||
const string type_var_name = type_annotations.at(attr.name());
|
||||
strings::StrAppend(&result_, type_var_name, " = TypeVar(\"",
|
||||
type_var_name, "\", ", typevar_dtypes, ")\n");
|
||||
added_typevar = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (added_typevar) strings::StrAppend(&result_, "\n");
|
||||
}
|
||||
|
||||
void GenEagerPythonOp::AddReturnTypeAnnotation(
|
||||
const std::unordered_map<string, string>& type_annotations) {
|
||||
if (op_def_.output_arg_size() == 1) {
|
||||
const auto& arg = op_def_.output_arg(0);
|
||||
if (arg.number_attr().empty() && arg.type_list_attr().empty()) {
|
||||
const string return_type = type_annotations.at(arg.name());
|
||||
// TODO(rahulkamat): Modify AddDefLine() to add return type annotation to
|
||||
// avoid erasing ":\n" from the end of the def line
|
||||
result_.erase(result_.length() - 2);
|
||||
strings::StrAppend(&result_, " -> ", return_type, ":\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GenEagerPythonOp::HandleGraphMode(
|
||||
const string& function_setup, const std::vector<string>& output_sizes) {
|
||||
strings::StrAppend(&result_, " # Add nodes to the TensorFlow graph.\n");
|
||||
@ -690,13 +845,20 @@ void GenEagerPythonOp::AddEagerFunctionTeardown(
|
||||
|
||||
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
||||
const string& parameters, const std::vector<string>& output_sizes,
|
||||
const string& eager_not_allowed_error) {
|
||||
const string& eager_not_allowed_error,
|
||||
const std::unordered_map<string, string>& type_annotations) {
|
||||
if (add_type_annotations_) {
|
||||
GenerateTypeVars(type_annotations);
|
||||
}
|
||||
if (api_def_.visibility() == ApiDef::VISIBLE) {
|
||||
strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
|
||||
}
|
||||
|
||||
AddExport();
|
||||
AddDefLine(function_name_, parameters);
|
||||
if (add_type_annotations_) {
|
||||
AddReturnTypeAnnotation(type_annotations);
|
||||
}
|
||||
AddDocStringDescription();
|
||||
AddDocStringArgs();
|
||||
AddDocStringInputs();
|
||||
@ -731,11 +893,14 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
||||
|
||||
bool GenEagerPythonOp::AddEagerFallbackCode(
|
||||
const string& parameters, const std::vector<string>& output_sizes,
|
||||
const string& num_outputs_expr, const string& eager_not_allowed_error) {
|
||||
const string& num_outputs_expr, const string& eager_not_allowed_error,
|
||||
const std::unordered_map<string, string>& type_annotations) {
|
||||
AddDefLine(
|
||||
strings::StrCat(function_name_, kEagerFallbackSuffix),
|
||||
strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
|
||||
|
||||
if (add_type_annotations_) {
|
||||
AddReturnTypeAnnotation(type_annotations);
|
||||
}
|
||||
if (!eager_not_allowed_error.empty()) {
|
||||
strings::StrAppend(&result_, " ", eager_not_allowed_error);
|
||||
return true;
|
||||
@ -982,9 +1147,10 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
|
||||
function_name_, "))\n");
|
||||
}
|
||||
|
||||
string GetPythonOpsImpl(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name = "") {
|
||||
string GetPythonOpsImpl(
|
||||
const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops, const string& source_file_name = "",
|
||||
const std::unordered_set<string> type_annotate_ops = {}) {
|
||||
string result;
|
||||
// Header
|
||||
// TODO(josh11b): Mention the library for which wrappers are being generated.
|
||||
@ -1018,6 +1184,7 @@ from tensorflow.python.util.deprecation import deprecated_endpoints
|
||||
from tensorflow.python.util import dispatch as _dispatch
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
from typing import TypeVar
|
||||
)");
|
||||
|
||||
for (const auto& op_def : ops.op()) {
|
||||
@ -1061,8 +1228,12 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
continue;
|
||||
}
|
||||
|
||||
auto iter = type_annotate_ops.find(op_def.name());
|
||||
bool add_type_annotations = iter != type_annotate_ops.end();
|
||||
|
||||
strings::StrAppend(&result,
|
||||
GetEagerPythonOp(op_def, *api_def, function_name));
|
||||
GetEagerPythonOp(op_def, *api_def, function_name,
|
||||
add_type_annotations));
|
||||
}
|
||||
|
||||
return result;
|
||||
@ -1072,15 +1243,19 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name) {
|
||||
return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name);
|
||||
const string& source_file_name,
|
||||
const std::unordered_set<string> type_annotate_ops) {
|
||||
return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
|
||||
type_annotate_ops);
|
||||
}
|
||||
|
||||
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name) {
|
||||
printf("%s",
|
||||
GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name).c_str());
|
||||
const string& source_file_name,
|
||||
const std::unordered_set<string> type_annotate_ops) {
|
||||
printf("%s", GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
|
||||
type_annotate_ops)
|
||||
.c_str());
|
||||
}
|
||||
|
||||
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
||||
@ -1091,4 +1266,20 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
||||
return GetPythonOpsImpl(ops, api_def_map, {});
|
||||
}
|
||||
|
||||
string GetArgAnnotation(
|
||||
const OpDef::ArgDef& arg,
|
||||
const std::unordered_map<string, string>& type_annotations) {
|
||||
if (!arg.type_attr().empty()) {
|
||||
// Get the correct TypeVar if arg maps to an attr
|
||||
return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
|
||||
} else {
|
||||
// Get the dtype of the Tensor
|
||||
const string py_dtype =
|
||||
python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
|
||||
return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
|
||||
}
|
||||
|
||||
return "Any";
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -16,7 +16,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -32,7 +34,8 @@ namespace tensorflow {
|
||||
// file where the ops' REGISTER_OP() calls reside.
|
||||
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name);
|
||||
const string& source_file_name,
|
||||
const std::unordered_set<string> type_annotate_ops);
|
||||
|
||||
// Prints the output of GetPrintOps to stdout.
|
||||
// hidden_ops should be a list of Op names that should get a leading _
|
||||
@ -41,7 +44,8 @@ string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
// where the ops' REGISTER_OP() calls reside.
|
||||
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name);
|
||||
const string& source_file_name,
|
||||
const std::unordered_set<string> type_annotate_ops);
|
||||
|
||||
// Get the python wrappers for a list of ops in a OpList.
|
||||
// `op_list_buf` should be a pointer to a buffer containing
|
||||
@ -49,6 +53,13 @@ void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
// length of that buffer.
|
||||
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len);
|
||||
|
||||
// Get the type annotation for an arg
|
||||
// `arg` should be an input or output of an op
|
||||
// `type_annotations` should contain attr names mapped to TypeVar names
|
||||
string GetArgAnnotation(
|
||||
const OpDef::ArgDef& arg,
|
||||
const std::unordered_map<string, string>& type_annotations);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
|
||||
|
@ -513,10 +513,11 @@ const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
|
||||
}
|
||||
|
||||
GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
|
||||
const string& function_name)
|
||||
const string& function_name, bool add_type_annotations)
|
||||
: op_def_(op_def),
|
||||
api_def_(api_def),
|
||||
function_name_(function_name),
|
||||
add_type_annotations_(add_type_annotations),
|
||||
num_outs_(op_def.output_arg_size()) {}
|
||||
|
||||
GenPythonOp::~GenPythonOp() {}
|
||||
|
@ -71,7 +71,7 @@ class ParamNames {
|
||||
class GenPythonOp {
|
||||
public:
|
||||
GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
|
||||
const string& function_name);
|
||||
const string& function_name, bool add_type_annotations_);
|
||||
virtual ~GenPythonOp();
|
||||
|
||||
virtual string Code();
|
||||
@ -98,6 +98,7 @@ class GenPythonOp {
|
||||
const OpDef& op_def_;
|
||||
const ApiDef& api_def_;
|
||||
const string function_name_;
|
||||
bool add_type_annotations_;
|
||||
const int num_outs_;
|
||||
|
||||
// Return value from Code() is prelude_ + result_.
|
||||
|
@ -108,7 +108,8 @@ string InferSourceFileName(const char* argv_zero) {
|
||||
void PrintAllPythonOps(const std::vector<string>& op_list,
|
||||
const std::vector<string>& api_def_dirs,
|
||||
const string& source_file_name,
|
||||
bool op_list_is_whitelist) {
|
||||
bool op_list_is_whitelist,
|
||||
const std::unordered_set<string> type_annotate_ops) {
|
||||
OpList ops;
|
||||
OpRegistry::Global()->Export(false, &ops);
|
||||
|
||||
@ -133,9 +134,11 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
|
||||
*pruned_ops.mutable_op()->Add() = op_def;
|
||||
}
|
||||
}
|
||||
PrintPythonOps(pruned_ops, api_def_map, {}, source_file_name);
|
||||
PrintPythonOps(pruned_ops, api_def_map, {}, source_file_name,
|
||||
type_annotate_ops);
|
||||
} else {
|
||||
PrintPythonOps(ops, api_def_map, op_list, source_file_name);
|
||||
PrintPythonOps(ops, api_def_map, op_list, source_file_name,
|
||||
type_annotate_ops);
|
||||
}
|
||||
}
|
||||
|
||||
@ -157,19 +160,25 @@ int main(int argc, char* argv[]) {
|
||||
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
|
||||
argv[1], ",", tensorflow::str_util::SkipEmpty());
|
||||
|
||||
// Add op name here to generate type annotations for it
|
||||
const std::unordered_set<tensorflow::string> type_annotate_ops{};
|
||||
|
||||
if (argc == 2) {
|
||||
tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
|
||||
false /* op_list_is_whitelist */);
|
||||
false /* op_list_is_whitelist */,
|
||||
type_annotate_ops);
|
||||
} else if (argc == 3) {
|
||||
std::vector<tensorflow::string> hidden_ops;
|
||||
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
|
||||
tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
|
||||
false /* op_list_is_whitelist */);
|
||||
false /* op_list_is_whitelist */,
|
||||
type_annotate_ops);
|
||||
} else if (argc == 4) {
|
||||
std::vector<tensorflow::string> op_list;
|
||||
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
|
||||
tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
|
||||
tensorflow::string(argv[3]) == "1");
|
||||
tensorflow::string(argv[3]) == "1",
|
||||
type_annotate_ops);
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
|
@ -23,20 +23,458 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(PythonOpGen, Basic) {
|
||||
void ExpectHasSubstr(const string& s, const string& expected) {
|
||||
EXPECT_TRUE(absl::StrContains(s, expected))
|
||||
<< "'Generated ops does not contain '" << expected << "'";
|
||||
}
|
||||
|
||||
void ExpectDoesNotHaveSubstr(const string& s, const string& expected) {
|
||||
EXPECT_FALSE(absl::StrContains(s, expected))
|
||||
<< "'Generated ops contains '" << expected << "'";
|
||||
}
|
||||
|
||||
void ExpectSubstrOrder(const string& s, const string& before,
|
||||
const string& after) {
|
||||
int before_pos = s.find(before);
|
||||
int after_pos = s.find(after);
|
||||
ASSERT_NE(std::string::npos, before_pos);
|
||||
ASSERT_NE(std::string::npos, after_pos);
|
||||
EXPECT_LT(before_pos, after_pos) << before << "' is not before '" << after;
|
||||
}
|
||||
|
||||
TEST(PythonOpGen, TypeAnnotateAllOps) {
|
||||
OpList ops;
|
||||
OpRegistry::Global()->Export(false, &ops);
|
||||
|
||||
ApiDefMap api_def_map(ops);
|
||||
|
||||
string code = GetPythonOps(ops, api_def_map, {}, "");
|
||||
std::unordered_set<string> type_annotate_ops;
|
||||
for (const auto& op : ops.op()) {
|
||||
type_annotate_ops.insert(op.name());
|
||||
}
|
||||
|
||||
EXPECT_TRUE(absl::StrContains(code, "def case"));
|
||||
string code = GetPythonOps(ops, api_def_map, {}, "", type_annotate_ops);
|
||||
|
||||
// TODO(mdan): Add tests to verify type annotations are correctly added.
|
||||
const string all_types =
|
||||
", _dtypes.BFloat16, _dtypes.Bool, _dtypes.Complex128, "
|
||||
"_dtypes.Complex64, "
|
||||
"_dtypes.Float16, _dtypes.Float32, _dtypes.Float64, _dtypes.Half, "
|
||||
"_dtypes.Int16, "
|
||||
"_dtypes.Int32, _dtypes.Int64, _dtypes.Int8, _dtypes.QInt16, "
|
||||
"_dtypes.QInt32, "
|
||||
"_dtypes.QInt8, _dtypes.QUInt16, _dtypes.QUInt8, _dtypes.Resource, "
|
||||
"_dtypes.String, "
|
||||
"_dtypes.UInt16, _dtypes.UInt32, _dtypes.UInt64, _dtypes.UInt8, "
|
||||
"_dtypes.Variant)";
|
||||
|
||||
const string fake_param_typevar =
|
||||
"TV_FakeParam_dtype = TypeVar(\"TV_FakeParam_dtype\"" + all_types;
|
||||
const string fake_param =
|
||||
"def fake_param_eager_fallback(dtype: TV_FakeParam_dtype, shape, name, "
|
||||
"ctx) -> _ops.Tensor[TV_FakeParam_dtype]:";
|
||||
const string fake_param_fallback =
|
||||
"def fake_param_eager_fallback(dtype: TV_FakeParam_dtype, shape, name, "
|
||||
"ctx) -> _ops.Tensor[TV_FakeParam_dtype]:";
|
||||
|
||||
ExpectHasSubstr(code, fake_param_typevar);
|
||||
ExpectHasSubstr(code, fake_param);
|
||||
ExpectHasSubstr(code, fake_param_fallback);
|
||||
|
||||
const string to_bool_typevar =
|
||||
"TV_ToBool_T = TypeVar(\"TV_ToBool_T\"" + all_types;
|
||||
const string to_bool_ =
|
||||
"def to_bool(input: _ops.Tensor[TV_ToBool_T], name=None) -> "
|
||||
"_ops.Tensor[_dtypes.Bool]:";
|
||||
const string to_bool_fallback =
|
||||
"def to_bool_eager_fallback(input: _ops.Tensor[TV_ToBool_T], name, ctx) "
|
||||
"-> _ops.Tensor[_dtypes.Bool]:";
|
||||
|
||||
ExpectHasSubstr(code, to_bool_typevar);
|
||||
ExpectHasSubstr(code, to_bool_);
|
||||
ExpectHasSubstr(code, to_bool_fallback);
|
||||
}
|
||||
|
||||
// TODO(mdan): Include more tests with synhtetic ops and api defs.
|
||||
TEST(PythonOpGen, TypeAnnotateSingleTypeTensor) {
|
||||
constexpr char kBaseOpDef[] = R"(
|
||||
op {
|
||||
name: "Bar"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_STRING
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type: DT_QINT8
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type: DT_BOOL
|
||||
}
|
||||
summary: "Summary for op Bar."
|
||||
description: "Description for op Bar."
|
||||
}
|
||||
)";
|
||||
|
||||
std::unordered_set<string> type_annotate_ops{"Bar"};
|
||||
|
||||
OpList op_defs;
|
||||
OpRegistry::Global()->Export(false, &op_defs);
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
|
||||
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
|
||||
|
||||
const string typed_bar =
|
||||
"def bar(x: _ops.Tensor[_dtypes.String], y: _ops.Tensor[_dtypes.QInt8], "
|
||||
"name=None) -> _ops.Tensor[_dtypes.Bool]:";
|
||||
ExpectHasSubstr(code, typed_bar);
|
||||
|
||||
const string untyped_bar = "def bar(x, y, name=None):";
|
||||
ExpectDoesNotHaveSubstr(code, untyped_bar);
|
||||
}
|
||||
|
||||
TEST(PythonOpGen, TypeAnnotateMultiTypeTensor) {
|
||||
constexpr char kBaseOpDef[] = R"(
|
||||
op {
|
||||
name: "Foo"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type_attr: "T2"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_UINT8
|
||||
type: DT_INT8
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T2"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_STRING
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Summary for op Foo."
|
||||
description: "Description for op Foo."
|
||||
}
|
||||
)";
|
||||
|
||||
std::unordered_set<string> type_annotate_ops{
|
||||
"Foo",
|
||||
};
|
||||
|
||||
OpList op_defs;
|
||||
OpRegistry::Global()->Export(false, &op_defs);
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
|
||||
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
|
||||
|
||||
const string typed_foo =
|
||||
"def foo(x: _ops.Tensor[TV_Foo_T], y: _ops.Tensor[TV_Foo_T2], name=None) "
|
||||
"-> _ops.Tensor[TV_Foo_T]:";
|
||||
ExpectHasSubstr(code, typed_foo);
|
||||
}
|
||||
|
||||
TEST(PythonOpGen, GenerateCorrectTypeVars) {
|
||||
constexpr char kBaseOpDef[] = R"(
|
||||
op {
|
||||
name: "Foo"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type_attr: "T2"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_UINT8
|
||||
type: DT_INT8
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T2"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_STRING
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Summary for op Foo."
|
||||
description: "Description for op Foo."
|
||||
}
|
||||
)";
|
||||
|
||||
std::unordered_set<string> type_annotate_ops{
|
||||
"Foo",
|
||||
};
|
||||
|
||||
OpList op_defs;
|
||||
OpRegistry::Global()->Export(false, &op_defs);
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
|
||||
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
|
||||
|
||||
const string typevars_foo = R"(
|
||||
TV_Foo_T = TypeVar("TV_Foo_T", _dtypes.Int8, _dtypes.UInt8)
|
||||
TV_Foo_T2 = TypeVar("TV_Foo_T2", _dtypes.Float32, _dtypes.Float64, _dtypes.String)
|
||||
)";
|
||||
|
||||
ExpectHasSubstr(code, typevars_foo);
|
||||
}
|
||||
|
||||
TEST(PythonOpGen, TypeAnnotateFallback) {
|
||||
constexpr char kBaseOpDef[] = R"(
|
||||
op {
|
||||
name: "Foo"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type_attr: "T2"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_UINT8
|
||||
type: DT_INT8
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T2"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_STRING
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Summary for op Foo."
|
||||
description: "Description for op Foo."
|
||||
}
|
||||
)";
|
||||
|
||||
std::unordered_set<string> type_annotate_ops{
|
||||
"Foo",
|
||||
};
|
||||
|
||||
OpList op_defs;
|
||||
OpRegistry::Global()->Export(false, &op_defs);
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
|
||||
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
|
||||
|
||||
const string typed_foo_fallback =
|
||||
"def foo_eager_fallback(x: _ops.Tensor[TV_Foo_T], y: "
|
||||
"_ops.Tensor[TV_Foo_T2], name, ctx) -> _ops.Tensor[TV_Foo_T]:";
|
||||
ExpectHasSubstr(code, typed_foo_fallback);
|
||||
}
|
||||
|
||||
TEST(PythonOpGen, GenerateTypeVarAboveOp) {
|
||||
constexpr char kBaseOpDef[] = R"(
|
||||
op {
|
||||
name: "Foo"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type_attr: "T2"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_UINT8
|
||||
type: DT_INT8
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T2"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_STRING
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Summary for op Foo."
|
||||
description: "Description for op Foo."
|
||||
}
|
||||
)";
|
||||
|
||||
std::unordered_set<string> type_annotate_ops{
|
||||
"Foo",
|
||||
};
|
||||
|
||||
OpList op_defs;
|
||||
OpRegistry::Global()->Export(false, &op_defs);
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
|
||||
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
|
||||
|
||||
const string typevar_foo = "TV_Foo_";
|
||||
const string def_foo = "def foo";
|
||||
ExpectSubstrOrder(code, typevar_foo, def_foo);
|
||||
}
|
||||
|
||||
TEST(PythonOpGen, TypeAnnotateDefaultParams) {
|
||||
constexpr char kBaseOpDef[] = R"(
|
||||
op {
|
||||
name: "FooBar"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type: DT_BOOL
|
||||
}
|
||||
attr {
|
||||
name: "t"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_HALF
|
||||
type: DT_INT8
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "var1"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "var2"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
summary: "Summary for op FooBar."
|
||||
description: "Description for op FooBar."
|
||||
}
|
||||
)";
|
||||
|
||||
std::unordered_set<string> type_annotate_ops{
|
||||
"FooBar",
|
||||
};
|
||||
|
||||
OpList op_defs;
|
||||
OpRegistry::Global()->Export(false, &op_defs);
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
|
||||
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
|
||||
|
||||
const string params =
|
||||
"def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, "
|
||||
"var1:bool=False, var2:int=0, name=None)";
|
||||
const string params_fallback =
|
||||
"def foo_bar_eager_fallback(x: _ops.Tensor[_dtypes.Float32], t: "
|
||||
"TV_FooBar_t, var1: bool, var2: int, name, ctx)";
|
||||
ExpectHasSubstr(code, params);
|
||||
ExpectHasSubstr(code, params_fallback);
|
||||
}
|
||||
|
||||
TEST(PythonOpGen, NoTypingSequenceTensors) {
|
||||
constexpr char kBaseOpDef[] = R"(
|
||||
op {
|
||||
name: "Baz"
|
||||
input_arg {
|
||||
name: "inputs"
|
||||
number_attr: "N"
|
||||
type_list_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output1"
|
||||
type: DT_BOOL
|
||||
}
|
||||
output_arg {
|
||||
name: "output2"
|
||||
type: DT_BOOL
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "bool"
|
||||
}
|
||||
attr {
|
||||
name: "N"
|
||||
type: "int"
|
||||
}
|
||||
summary: "Summary for op Baz."
|
||||
description: "Description for op Baz."
|
||||
}
|
||||
)";
|
||||
|
||||
std::unordered_set<string> type_annotate_ops{"Baz"};
|
||||
|
||||
OpList op_defs;
|
||||
OpRegistry::Global()->Export(false, &op_defs);
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
|
||||
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
|
||||
|
||||
const string baz_def_line = "def baz(inputs, name=None):";
|
||||
|
||||
ExpectHasSubstr(code, baz_def_line);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user