diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index 0b6f974d962..b6e39a4df00 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -45,7 +45,8 @@ const int kRightMargin = 78;
 
 constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
 
-std::unordered_map<string, string> dtype_type {
+// Map from a dtype's enum string (e.g. "_dtypes.float32") to its dtype class (e.g. "_dtypes.Float32")
+const std::unordered_map<string, string> dtype_type {
       {"_dtypes.float16", "_dtypes.Float16"},
       {"_dtypes.half", "_dtypes.Half"},
       {"_dtypes.float32", "_dtypes.Float32"},
@@ -133,8 +134,8 @@ string TensorPBString(const TensorProto& pb) {
 class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
  public:
   GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
-                   const string& function_name, const bool type_annotate_op)
-      : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name, type_annotate_op) {
+                   const string& function_name, bool add_type_annotations)
+      : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name, add_type_annotations) {
     op_name_ = function_name_;
     absl::ConsumePrefix(&op_name_, "_");
   }
@@ -160,12 +161,12 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
   bool AddEagerFastPathAndGraphCode(const string& parameters,
                                     const std::vector<string>& output_sizes,
                                     const string& eager_not_allowed_error,
-                                    std::unordered_map<string, string>& type_annotations);
+                                    const std::unordered_map<string, string>& type_annotations);
   bool AddEagerFallbackCode(const string& parameters,
                             const std::vector<string>& output_sizes,
                             const string& num_outputs_expr,
                             const string& eager_not_allowed_error,
-                            std::unordered_map<string, string>& type_annotations);
+                            const std::unordered_map<string, string>& type_annotations);
   void AddEagerFastPathExecute();
 
   void AddEagerInferredAttrs(const string& indentation);
@@ -177,11 +178,11 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
 
   void AddRawOpExport(const string& parameters);
 
-  std::unordered_map<string, string> GetTypeAnnotationMap();
+  std::unordered_map<string, string> GetTypeAnnotations();
 
-  void GenerateTypeVars(std::unordered_map<string, string>& type_annotations);
+  void GenerateTypeVars(const std::unordered_map<string, string>& type_annotations);
 
-  void AddReturnTypeAnnotation(std::unordered_map<string, string>& type_annotations);
+  void AddReturnTypeAnnotation(const std::unordered_map<string, string>& type_annotations);
 
   void AddAttrForArg(const string& attr, int arg_index) {
     gtl::InsertIfNotPresent(&inferred_attrs_, attr,
@@ -214,8 +215,8 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
 };
 
 string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
-                        const string& function_name, const bool type_annotate_op) {
-  return GenEagerPythonOp(op_def, api_def, function_name, type_annotate_op).Code();
+                        const string& function_name, bool add_type_annotations) {
+  return GenEagerPythonOp(op_def, api_def, function_name, add_type_annotations).Code();
 }
 
 string GenEagerPythonOp::FlattenInputs(
@@ -347,8 +348,8 @@ string GenEagerPythonOp::Code() {
 
   std::unordered_map<string, string> type_annotations;
   // Only populate map for whitelisted ops
-  if (type_annotate_op_) {
-    type_annotations = GetTypeAnnotationMap();
+  if (add_type_annotations_) {
+    type_annotations = GetTypeAnnotations();
   }
 
   string parameters;
@@ -357,33 +358,28 @@ string GenEagerPythonOp::Code() {
     if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
     strings::StrAppend(&parameters, param.GetRenameTo());
 
-    // Add type annotations to param
     if (type_annotations.find(param.GetName()) != type_annotations.end()) {
-      strings::StrAppend(&parameters, ": ", type_annotations[param.GetName()]);
+      strings::StrAppend(&parameters, ": ", type_annotations.at(param.GetName()));
     }
   }
 
-  // Append to parameters and parameters_with_defaults because multiple functions
-  // are generated (op and fallback op)
   string parameters_with_defaults = parameters;
   for (const auto& param_and_default : params_with_default_) {
     if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
     if (!parameters_with_defaults.empty())
       strings::StrAppend(&parameters_with_defaults, ", ");
 
-    // Add type annotations to param_and_default
+    strings::StrAppend(&parameters, param_and_default.first.GetRenameTo());
+    strings::StrAppend(&parameters_with_defaults, param_and_default.first.GetRenameTo());
     if (type_annotations.find(param_and_default.first.GetName()) != type_annotations.end()) {
-      const string param_type = type_annotations[param_and_default.first.GetName()];
-      strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), ": ", param_type);
-      strings::StrAppend(&parameters_with_defaults,
-                         param_and_default.first.GetRenameTo(), ": ",
-                         param_type, " = ", param_and_default.second);
-      continue;
+      const string param_type = type_annotations.at(param_and_default.first.GetName());
+      // Append to parameters and parameters_with_defaults because multiple functions
+      // are generated by AddEagerFastPathAndGraphCode() and AddEagerFallbackCode()
+      strings::StrAppend(&parameters, ": ", param_type);
+      strings::StrAppend(&parameters_with_defaults, ":", param_type);
     }
 
-    strings::StrAppend(&parameters, param_and_default.first.GetRenameTo());
-    strings::StrAppend(&parameters_with_defaults,
-                       param_and_default.first.GetRenameTo(), "=",
+    strings::StrAppend(&parameters_with_defaults, "=",
                        param_and_default.second);
   }
 
@@ -428,9 +424,9 @@ string GenEagerPythonOp::Code() {
   return prelude_ + result_;
 }
 
-std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotationMap() {
+std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
   std::unordered_map<string, string> type_annotations;
-  // Mapping attrs to TypeVars
+  // Map attrs to TypeVars
   for (const auto& attr : op_def_.attr()) {
     if (attr.type() == "type") {
       const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name();
@@ -441,24 +437,26 @@ std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotationMap() {
     }
   }
 
-  // Mapping input Tensors to their types
+  // Map input Tensors to their types
   for (const auto& arg : op_def_.input_arg()) {
-    // Do not add type annotations to args that accept a sequence of Tensors
-    if (!arg.number_attr().empty()) continue;
+    // TODO(rahulkamat): Add type annotations to args that accept a sequence of Tensors
+    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
     type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
   }
 
-  // Mapping output Tensor to its type
+  // TODO(rahulkamat): Add type annotations for ops whose return type is a sequence of Tensors.
+  // Map output Tensor to its type
   if (op_def_.output_arg_size() == 1) {
     const auto& arg = op_def_.output_arg(0);
-    type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
+    if (arg.number_attr().empty() && arg.type_list_attr().empty())
+      type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
   }
 
   return type_annotations;
 }
 
 // Generate TypeVars using attrs
-void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type_annotations) {
+void GenEagerPythonOp::GenerateTypeVars(const std::unordered_map<string, string>& type_annotations) {
   bool added_typevar = false;
   for (const auto& attr : op_def_.attr()) {
     if (attr.type() == "type") {
@@ -466,12 +464,10 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type
       for (int t : attr.allowed_values().list().type()) {
         DataType dtype = static_cast<DataType>(t);
         const string py_dtype = python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
-        if (dtype_type.find(py_dtype) != dtype_type.end()) {
-          allowed_types.emplace_back(dtype_type[py_dtype]);
-        }
+        allowed_types.emplace_back(dtype_type.at(py_dtype));
       }
 
-      // If all dtypes are allowed, add them all
+      // If the attr does not restrict its allowed values, all dtypes are allowed
       if (allowed_types.empty()) {
         for (std::pair<string, string> map_dtype : dtype_type) {
           allowed_types.emplace_back(map_dtype.second);
@@ -486,7 +482,7 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type
         strings::StrAppend(&typevar_dtypes, *it);
       }
 
-      const string type_var_name = type_annotations[attr.name()];
+      const string type_var_name = type_annotations.at(attr.name());
       strings::StrAppend(&result_, type_var_name, " = TypeVar(\"", type_var_name, "\", ", typevar_dtypes,")\n");
       added_typevar = true;
     }
@@ -495,14 +491,15 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type
   if (added_typevar) strings::StrAppend(&result_, "\n");
 }
 
-// TODO(rahulkamat): Modify AddDefLine() to add return type annotation
-void GenEagerPythonOp::AddReturnTypeAnnotation(std::unordered_map<string, string>& type_annotations) {
+void GenEagerPythonOp::AddReturnTypeAnnotation(const std::unordered_map<string, string>& type_annotations) {
   if (op_def_.output_arg_size() == 1) {
     const auto& arg = op_def_.output_arg(0);
-    // Add type annotations to param
-    if (type_annotations.find(arg.name()) != type_annotations.end()) {
+    if (arg.number_attr().empty() && arg.type_list_attr().empty()) {
+      const string return_type = type_annotations.at(arg.name());
+      // TODO(rahulkamat): Modify AddDefLine() to add return type annotation to avoid
+      // erasing ":\n" from the end of the def line
       result_.erase(result_.length() - 2);
-      strings::StrAppend(&result_, " -> ", type_annotations[arg.name()], ":\n");
+      strings::StrAppend(&result_, " -> ", return_type, ":\n");
     }
   }
 }
@@ -829,8 +826,9 @@ void GenEagerPythonOp::AddEagerFunctionTeardown(
 
 bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
     const string& parameters, const std::vector<string>& output_sizes,
-    const string& eager_not_allowed_error, std::unordered_map<string, string>& type_annotations) {
-  if (type_annotate_op_) {
+    const string& eager_not_allowed_error,
+    const std::unordered_map<string, string>& type_annotations) {
+  if (add_type_annotations_) {
     GenerateTypeVars(type_annotations);
   }
   if (api_def_.visibility() == ApiDef::VISIBLE) {
@@ -839,7 +837,7 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
 
   AddExport();
   AddDefLine(function_name_, parameters);
-  if (type_annotate_op_) {
+  if (add_type_annotations_) {
     AddReturnTypeAnnotation(type_annotations);
   }
   AddDocStringDescription();
@@ -877,11 +875,11 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
 bool GenEagerPythonOp::AddEagerFallbackCode(
     const string& parameters, const std::vector<string>& output_sizes,
     const string& num_outputs_expr, const string& eager_not_allowed_error,
-    std::unordered_map<string, string>& type_annotations) {
+    const std::unordered_map<string, string>& type_annotations) {
   AddDefLine(
       strings::StrCat(function_name_, kEagerFallbackSuffix),
       strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
-  if (type_annotate_op_) {
+  if (add_type_annotations_) {
     AddReturnTypeAnnotation(type_annotations);
   }
   if (!eager_not_allowed_error.empty()) {
@@ -1133,7 +1131,7 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
 string GetPythonOpsImpl(const OpList& ops, const ApiDefMap& api_defs,
                         const std::vector<string>& hidden_ops,
                         const string& source_file_name = "",
-                        std::unordered_set<string> type_annotate_ops = {}) {
+                        const std::unordered_set<string> type_annotate_ops = {}) {
   string result;
   // Header
   // TODO(josh11b): Mention the library for which wrappers are being generated.
@@ -1211,10 +1209,11 @@ from typing import TypeVar
       continue;
     }
 
-    const bool type_annotate_op = type_annotate_ops.find(op_def.name()) != type_annotate_ops.end();
+    auto iter = type_annotate_ops.find(op_def.name());
+    bool add_type_annotations = iter != type_annotate_ops.end();
 
     strings::StrAppend(&result,
-                       GetEagerPythonOp(op_def, *api_def, function_name, type_annotate_op));
+                       GetEagerPythonOp(op_def, *api_def, function_name, add_type_annotations));
   }
 
   return result;
@@ -1225,14 +1224,14 @@ from typing import TypeVar
 string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                     const std::vector<string>& hidden_ops,
                     const string& source_file_name,
-                    std::unordered_set<string> type_annotate_ops) {
+                    const std::unordered_set<string> type_annotate_ops) {
   return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name, type_annotate_ops);
 }
 
 void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                     const std::vector<string>& hidden_ops,
                     const string& source_file_name,
-                    std::unordered_set<string> type_annotate_ops) {
+                    const std::unordered_set<string> type_annotate_ops) {
   printf("%s",
          GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name, type_annotate_ops).c_str());
 }
@@ -1245,16 +1244,14 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
   return GetPythonOpsImpl(ops, api_def_map, {});
 }
 
-string GetArgAnnotation(const auto& arg, std::unordered_map<string, string>& type_annotations) {
-  if (type_annotations.find(arg.type_attr()) != type_annotations.end()) {
+string GetArgAnnotation(const auto& arg, const std::unordered_map<string, string>& type_annotations) {
+  if (!arg.type_attr().empty()) {
     // Get the correct TypeVar if arg maps to an attr
-    return "_ops.Tensor[" + type_annotations[arg.type_attr()] + "]";
+    return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
   } else {
     // Get the dtype of the Tensor
     const string py_dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
-    if (dtype_type.find(py_dtype) != dtype_type.end()) {
-      return "_ops.Tensor[" + dtype_type[py_dtype] + "]";
-    }
+    return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
   }
 
   return "Any";
diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h
index 1a3b6c5e8f2..5dfc959b3ad 100644
--- a/tensorflow/python/framework/python_op_gen.h
+++ b/tensorflow/python/framework/python_op_gen.h
@@ -33,7 +33,7 @@ namespace tensorflow {
 string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                     const std::vector<string>& hidden_ops,
                     const string& source_file_name,
-                    std::unordered_set<string> type_annotate_ops);
+                    const std::unordered_set<string> type_annotate_ops);
 
 // Prints the output of GetPrintOps to stdout.
 // hidden_ops should be a list of Op names that should get a leading _
@@ -43,7 +43,7 @@ string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
 void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                     const std::vector<string>& hidden_ops,
                     const string& source_file_name,
-                    std::unordered_set<string> type_annotate_ops);
+                    const std::unordered_set<string> type_annotate_ops);
 
 // Get the python wrappers for a list of ops in a OpList.
 // `op_list_buf` should be a pointer to a buffer containing
@@ -55,7 +55,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len);
 // `arg` should be an input or output of an op
 // `type_annotations` should contain attr names mapped to TypeVar names
 string GetArgAnnotation(const auto& arg,
-                        std::unordered_map<string, string>& type_annotations);
+                        const std::unordered_map<string, string>& type_annotations);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc
index d0ef82857c4..adbdbbf06fb 100644
--- a/tensorflow/python/framework/python_op_gen_internal.cc
+++ b/tensorflow/python/framework/python_op_gen_internal.cc
@@ -513,11 +513,11 @@ const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
 }
 
 GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
-                         const string& function_name, const bool type_annotate_op)
+                         const string& function_name, bool add_type_annotations)
     : op_def_(op_def),
       api_def_(api_def),
       function_name_(function_name),
-      type_annotate_op_(type_annotate_op),
+      add_type_annotations_(add_type_annotations),
       num_outs_(op_def.output_arg_size()) {}
 
 GenPythonOp::~GenPythonOp() {}
diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h
index 5229bffc5d0..08d9b3c8a66 100644
--- a/tensorflow/python/framework/python_op_gen_internal.h
+++ b/tensorflow/python/framework/python_op_gen_internal.h
@@ -71,7 +71,7 @@ class ParamNames {
 class GenPythonOp {
  public:
   GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
-              const string& function_name, const bool type_annotate_op_);
+              const string& function_name, bool add_type_annotations_);
   virtual ~GenPythonOp();
 
   virtual string Code();
@@ -98,7 +98,7 @@ class GenPythonOp {
   const OpDef& op_def_;
   const ApiDef& api_def_;
   const string function_name_;
-  const bool type_annotate_op_;
+  bool add_type_annotations_;
   const int num_outs_;
 
   // Return value from Code() is prelude_ + result_.
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
index dcaea53100e..c3ef4202d2a 100644
--- a/tensorflow/python/framework/python_op_gen_main.cc
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -109,7 +109,7 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
                        const std::vector<string>& api_def_dirs,
                        const string& source_file_name,
                        bool op_list_is_whitelist,
-                       std::unordered_set<string> type_annotate_ops) {
+                       const std::unordered_set<string> type_annotate_ops) {
   OpList ops;
   OpRegistry::Global()->Export(false, &ops);
 
@@ -159,7 +159,7 @@ int main(int argc, char* argv[]) {
       argv[1], ",", tensorflow::str_util::SkipEmpty());
 
   // Add op name to this set to add type annotations
-  std::unordered_set<tensorflow::string> type_annotate_ops {
+  const std::unordered_set<tensorflow::string> type_annotate_ops {
   };
 
   if (argc == 2) {
diff --git a/tensorflow/python/framework/python_op_gen_test.cc b/tensorflow/python/framework/python_op_gen_test.cc
index cf6566ea7ae..5fff1a1d111 100644
--- a/tensorflow/python/framework/python_op_gen_test.cc
+++ b/tensorflow/python/framework/python_op_gen_test.cc
@@ -261,7 +261,7 @@ TEST(PythonOpGen, TypeAnnotateDefaultParams) {
 
   string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
 
-  const string params = "def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1: bool = False, var2: int = 0, name=None)";
+  const string params = "def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1:bool=False, var2:int=0, name=None)";
   const string params_fallback = "def foo_bar_eager_fallback(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1: bool, var2: int, name, ctx)";
   ExpectHasSubstr(code, params);
   ExpectHasSubstr(code, params_fallback);