diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 0b6f974d962..b6e39a4df00 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -45,7 +45,8 @@ const int kRightMargin = 78; constexpr char kEagerFallbackSuffix[] = "_eager_fallback"; -std::unordered_map<string, string> dtype_type { +// Dtype enums mapped to dtype classes which is the type of each dtype +const std::unordered_map<string, string> dtype_type { {"_dtypes.float16", "_dtypes.Float16"}, {"_dtypes.half", "_dtypes.Half"}, {"_dtypes.float32", "_dtypes.Float32"}, @@ -133,8 +134,8 @@ string TensorPBString(const TensorProto& pb) { class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { public: GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, - const string& function_name, const bool type_annotate_op) - : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name, type_annotate_op) { + const string& function_name, bool add_type_annotations) + : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name, add_type_annotations) { op_name_ = function_name_; absl::ConsumePrefix(&op_name_, "_"); } @@ -160,12 +161,12 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { bool AddEagerFastPathAndGraphCode(const string& parameters, const std::vector<string>& output_sizes, const string& eager_not_allowed_error, - std::unordered_map<string, string>& type_annotations); + const std::unordered_map<string, string>& type_annotations); bool AddEagerFallbackCode(const string& parameters, const std::vector<string>& output_sizes, const string& num_outputs_expr, const string& eager_not_allowed_error, - std::unordered_map<string, string>& type_annotations); + const std::unordered_map<string, string>& type_annotations); void AddEagerFastPathExecute(); void AddEagerInferredAttrs(const string& indentation); @@ -177,11 +178,11 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { void AddRawOpExport(const string& parameters); - std::unordered_map<string, string> GetTypeAnnotationMap(); + std::unordered_map<string, string> GetTypeAnnotations(); - void GenerateTypeVars(std::unordered_map<string, string>& type_annotations); + void GenerateTypeVars(const std::unordered_map<string, string>& type_annotations); - void AddReturnTypeAnnotation(std::unordered_map<string, string>& type_annotations); + void AddReturnTypeAnnotation(const std::unordered_map<string, string>& type_annotations); void AddAttrForArg(const string& attr, int arg_index) { gtl::InsertIfNotPresent(&inferred_attrs_, attr, @@ -214,8 +215,8 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { }; string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, - const string& function_name, const bool type_annotate_op) { - return GenEagerPythonOp(op_def, api_def, function_name, type_annotate_op).Code(); + const string& function_name, bool add_type_annotations) { + return GenEagerPythonOp(op_def, api_def, function_name, add_type_annotations).Code(); } string GenEagerPythonOp::FlattenInputs( @@ -347,8 +348,8 @@ string GenEagerPythonOp::Code() { std::unordered_map<string, string> type_annotations; // Only populate map for whitelisted ops - if (type_annotate_op_) { - type_annotations = GetTypeAnnotationMap(); + if (add_type_annotations_) { + type_annotations = GetTypeAnnotations(); } string parameters; @@ -357,33 +358,28 @@ string GenEagerPythonOp::Code() { if (!parameters.empty()) strings::StrAppend(¶meters, ", "); strings::StrAppend(¶meters, param.GetRenameTo()); - // Add type annotations to param if (type_annotations.find(param.GetName()) != type_annotations.end()) { - strings::StrAppend(¶meters, ": ", type_annotations[param.GetName()]); + strings::StrAppend(¶meters, ": ", type_annotations.at(param.GetName())); } } - // Append to parameters and parameters_with_defaults because multiple functions - // are generated (op and fallback op) string parameters_with_defaults = parameters; for (const auto& param_and_default : params_with_default_) { if (!parameters.empty()) strings::StrAppend(¶meters, ", "); if (!parameters_with_defaults.empty()) strings::StrAppend(¶meters_with_defaults, ", "); - // Add type annotations to param_and_default + strings::StrAppend(¶meters, param_and_default.first.GetRenameTo()); + strings::StrAppend(¶meters_with_defaults, param_and_default.first.GetRenameTo()); if (type_annotations.find(param_and_default.first.GetName()) != type_annotations.end()) { - const string param_type = type_annotations[param_and_default.first.GetName()]; - strings::StrAppend(¶meters, param_and_default.first.GetRenameTo(), ": ", param_type); - strings::StrAppend(¶meters_with_defaults, - param_and_default.first.GetRenameTo(), ": ", - param_type, " = ", param_and_default.second); - continue; + const string param_type = type_annotations.at(param_and_default.first.GetName()); + // Append to parameters and parameters_with_defaults because multiple functions + // are generated by AddEagerFastPathAndGraphCode() and AddEagerFallbackCode() + strings::StrAppend(¶meters, ": ", param_type); + strings::StrAppend(¶meters_with_defaults, ":", param_type); } - strings::StrAppend(¶meters, param_and_default.first.GetRenameTo()); - strings::StrAppend(¶meters_with_defaults, - param_and_default.first.GetRenameTo(), "=", + strings::StrAppend(¶meters_with_defaults, "=", param_and_default.second); } @@ -428,9 +424,9 @@ string GenEagerPythonOp::Code() { return prelude_ + result_; } -std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotationMap() { +std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() { std::unordered_map<string, string> type_annotations; - // Mapping attrs to TypeVars + // Map attrs to TypeVars for (const auto& attr : op_def_.attr()) { if (attr.type() == "type") { const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name(); @@ -441,24 +437,26 @@ std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotationMap() { } } - // Mapping input Tensors to their types + // Map input Tensors to their types for (const auto& arg : op_def_.input_arg()) { - // Do not add type annotations to args that accept a sequence of Tensors - if (!arg.number_attr().empty()) continue; + // TODO(rahulkamat): Add type annotations to args that accept a sequence of Tensors + if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue; type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations); } - // Mapping output Tensor to its type + // TODO(rahulkamat): Add type annotations to handle return types of a sequence of Tensors. + // Map output Tensor to its type if (op_def_.output_arg_size() == 1) { const auto& arg = op_def_.output_arg(0); - type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations); + if (arg.number_attr().empty() && arg.type_list_attr().empty()) + type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations); } return type_annotations; } // Generate TypeVars using attrs -void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type_annotations) { +void GenEagerPythonOp::GenerateTypeVars(const std::unordered_map<string, string>& type_annotations) { bool added_typevar = false; for (const auto& attr : op_def_.attr()) { if (attr.type() == "type") { @@ -466,12 +464,10 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type for (int t : attr.allowed_values().list().type()) { DataType dtype = static_cast<DataType>(t); const string py_dtype = python_op_gen_internal::DataTypeToPython(dtype, "_dtypes."); - if (dtype_type.find(py_dtype) != dtype_type.end()) { - allowed_types.emplace_back(dtype_type[py_dtype]); - } + allowed_types.emplace_back(dtype_type.at(py_dtype)); } - // If all dtypes are allowed, add them all + // When a Tensor does not have any dtypes specified, all dtypes are allowed if (allowed_types.empty()) { for (std::pair<string, string> map_dtype : dtype_type) { allowed_types.emplace_back(map_dtype.second); @@ -486,7 +482,7 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type strings::StrAppend(&typevar_dtypes, *it); } - const string type_var_name = type_annotations[attr.name()]; + const string type_var_name = type_annotations.at(attr.name()); strings::StrAppend(&result_, type_var_name, " = TypeVar(\"", type_var_name, "\", ", typevar_dtypes,")\n"); added_typevar = true; } @@ -495,14 +491,15 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type if (added_typevar) strings::StrAppend(&result_, "\n"); } -// TODO(rahulkamat): Modify AddDefLine() to add return type annotation -void GenEagerPythonOp::AddReturnTypeAnnotation(std::unordered_map<string, string>& type_annotations) { +void GenEagerPythonOp::AddReturnTypeAnnotation(const std::unordered_map<string, string>& type_annotations) { if (op_def_.output_arg_size() == 1) { const auto& arg = op_def_.output_arg(0); - // Add type annotations to param - if (type_annotations.find(arg.name()) != type_annotations.end()) { + if (arg.number_attr().empty() && arg.type_list_attr().empty()) { + const string return_type = type_annotations.at(arg.name()); + // TODO(rahulkamat): Modify AddDefLine() to add return type annotation to avoid + // erasing ":\n" from the end of the def line result_.erase(result_.length() - 2); - strings::StrAppend(&result_, " -> ", type_annotations[arg.name()], ":\n"); + strings::StrAppend(&result_, " -> ", return_type, ":\n"); } } } @@ -829,8 +826,9 @@ void GenEagerPythonOp::AddEagerFunctionTeardown( bool GenEagerPythonOp::AddEagerFastPathAndGraphCode( const string& parameters, const std::vector<string>& output_sizes, - const string& eager_not_allowed_error, std::unordered_map<string, string>& type_annotations) { - if (type_annotate_op_) { + const string& eager_not_allowed_error, + const std::unordered_map<string, string>& type_annotations) { + if (add_type_annotations_) { GenerateTypeVars(type_annotations); } if (api_def_.visibility() == ApiDef::VISIBLE) { @@ -839,7 +837,7 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode( AddExport(); AddDefLine(function_name_, parameters); - if (type_annotate_op_) { + if (add_type_annotations_) { AddReturnTypeAnnotation(type_annotations); } AddDocStringDescription(); @@ -877,11 +875,11 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode( bool GenEagerPythonOp::AddEagerFallbackCode( const string& parameters, const std::vector<string>& output_sizes, const string& num_outputs_expr, const string& eager_not_allowed_error, - std::unordered_map<string, string>& type_annotations) { + const std::unordered_map<string, string>& type_annotations) { AddDefLine( strings::StrCat(function_name_, kEagerFallbackSuffix), strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx")); - if (type_annotate_op_) { + if (add_type_annotations_) { AddReturnTypeAnnotation(type_annotations); } if (!eager_not_allowed_error.empty()) { @@ -1133,7 +1131,7 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) { string GetPythonOpsImpl(const OpList& ops, const ApiDefMap& api_defs, const std::vector<string>& hidden_ops, const string& source_file_name = "", - std::unordered_set<string> type_annotate_ops = {}) { + const std::unordered_set<string> type_annotate_ops = {}) { string result; // Header // TODO(josh11b): Mention the library for which wrappers are being generated. @@ -1211,10 +1209,11 @@ from typing import TypeVar continue; } - const bool type_annotate_op = type_annotate_ops.find(op_def.name()) != type_annotate_ops.end(); + auto iter = type_annotate_ops.find(op_def.name()); + bool add_type_annotations = iter != type_annotate_ops.end(); strings::StrAppend(&result, - GetEagerPythonOp(op_def, *api_def, function_name, type_annotate_op)); + GetEagerPythonOp(op_def, *api_def, function_name, add_type_annotations)); } return result; @@ -1225,14 +1224,14 @@ from typing import TypeVar string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs, const std::vector<string>& hidden_ops, const string& source_file_name, - std::unordered_set<string> type_annotate_ops) { + const std::unordered_set<string> type_annotate_ops) { return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name, type_annotate_ops); } void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs, const std::vector<string>& hidden_ops, const string& source_file_name, - std::unordered_set<string> type_annotate_ops) { + const std::unordered_set<string> type_annotate_ops) { printf("%s", GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name, type_annotate_ops).c_str()); } @@ -1245,16 +1244,14 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) { return GetPythonOpsImpl(ops, api_def_map, {}); } -string GetArgAnnotation(const auto& arg, std::unordered_map<string, string>& type_annotations) { - if (type_annotations.find(arg.type_attr()) != type_annotations.end()) { +string GetArgAnnotation(const auto& arg, const std::unordered_map<string, string>& type_annotations) { + if (!arg.type_attr().empty()) { // Get the correct TypeVar if arg maps to an attr - return "_ops.Tensor[" + type_annotations[arg.type_attr()] + "]"; + return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]"; } else { // Get the dtype of the Tensor const string py_dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes."); - if (dtype_type.find(py_dtype) != dtype_type.end()) { - return "_ops.Tensor[" + dtype_type[py_dtype] + "]"; - } + return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]"; } return "Any"; diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h index 1a3b6c5e8f2..5dfc959b3ad 100644 --- a/tensorflow/python/framework/python_op_gen.h +++ b/tensorflow/python/framework/python_op_gen.h @@ -33,7 +33,7 @@ namespace tensorflow { string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs, const std::vector<string>& hidden_ops, const string& source_file_name, - std::unordered_set<string> type_annotate_ops); + const std::unordered_set<string> type_annotate_ops); // Prints the output of GetPrintOps to stdout. // hidden_ops should be a list of Op names that should get a leading _ @@ -43,7 +43,7 @@ string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs, void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs, const std::vector<string>& hidden_ops, const string& source_file_name, - std::unordered_set<string> type_annotate_ops); + const std::unordered_set<string> type_annotate_ops); // Get the python wrappers for a list of ops in a OpList. // `op_list_buf` should be a pointer to a buffer containing @@ -55,7 +55,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len); // `arg` should be an input or output of an op // `type_annotations` should contain attr names mapped to TypeVar names string GetArgAnnotation(const auto& arg, - std::unordered_map<string, string>& type_annotations); + const std::unordered_map<string, string>& type_annotations); } // namespace tensorflow diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index d0ef82857c4..adbdbbf06fb 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -513,11 +513,11 @@ const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) { } GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def, - const string& function_name, const bool type_annotate_op) + const string& function_name, bool add_type_annotations) : op_def_(op_def), api_def_(api_def), function_name_(function_name), - type_annotate_op_(type_annotate_op), + add_type_annotations_(add_type_annotations), num_outs_(op_def.output_arg_size()) {} GenPythonOp::~GenPythonOp() {} diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h index 5229bffc5d0..08d9b3c8a66 100644 --- a/tensorflow/python/framework/python_op_gen_internal.h +++ b/tensorflow/python/framework/python_op_gen_internal.h @@ -71,7 +71,7 @@ class ParamNames { class GenPythonOp { public: GenPythonOp(const OpDef& op_def, const ApiDef& api_def, - const string& function_name, const bool type_annotate_op_); + const string& function_name, bool add_type_annotations_); virtual ~GenPythonOp(); virtual string Code(); @@ -98,7 +98,7 @@ class GenPythonOp { const OpDef& op_def_; const ApiDef& api_def_; const string function_name_; - const bool type_annotate_op_; + bool add_type_annotations_; const int num_outs_; // Return value from Code() is prelude_ + result_. diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc index dcaea53100e..c3ef4202d2a 100644 --- a/tensorflow/python/framework/python_op_gen_main.cc +++ b/tensorflow/python/framework/python_op_gen_main.cc @@ -109,7 +109,7 @@ void PrintAllPythonOps(const std::vector<string>& op_list, const std::vector<string>& api_def_dirs, const string& source_file_name, bool op_list_is_whitelist, - std::unordered_set<string> type_annotate_ops) { + const std::unordered_set<string> type_annotate_ops) { OpList ops; OpRegistry::Global()->Export(false, &ops); @@ -159,7 +159,7 @@ int main(int argc, char* argv[]) { argv[1], ",", tensorflow::str_util::SkipEmpty()); // Add op name to this set to add type annotations - std::unordered_set<tensorflow::string> type_annotate_ops { + const std::unordered_set<tensorflow::string> type_annotate_ops { }; if (argc == 2) { diff --git a/tensorflow/python/framework/python_op_gen_test.cc b/tensorflow/python/framework/python_op_gen_test.cc index cf6566ea7ae..5fff1a1d111 100644 --- a/tensorflow/python/framework/python_op_gen_test.cc +++ b/tensorflow/python/framework/python_op_gen_test.cc @@ -261,7 +261,7 @@ TEST(PythonOpGen, TypeAnnotateDefaultParams) { string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops); - const string params = "def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1: bool = False, var2: int = 0, name=None)"; + const string params = "def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1:bool=False, var2:int=0, name=None)"; const string params_fallback = "def foo_bar_eager_fallback(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1: bool, var2: int, name, ctx)"; ExpectHasSubstr(code, params); ExpectHasSubstr(code, params_fallback);