diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 473eae43cd2..ecece1655ef 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -443,7 +443,7 @@ std::unordered_map GenEagerPythonOp::GetTypeAnnotations() { for (const auto& arg : op_def_.input_arg()) { // TODO(rahulkamat): Add type annotations to args that accept a sequence of Tensors if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue; - type_annotations[arg.name()] = GetArgAnnotation(arg.type_attr(), arg.type(), type_annotations); + type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations); } // TODO(rahulkamat): Add type annotations to handle return types of a sequence of Tensors. @@ -451,7 +451,7 @@ std::unordered_map GenEagerPythonOp::GetTypeAnnotations() { if (op_def_.output_arg_size() == 1) { const auto& arg = op_def_.output_arg(0); if (arg.number_attr().empty() && arg.type_list_attr().empty()) - type_annotations[arg.name()] = GetArgAnnotation(arg.type_attr(), arg.type(), type_annotations); + type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations); } return type_annotations; @@ -1246,13 +1246,13 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) { return GetPythonOpsImpl(ops, api_def_map, {}); } -string GetArgAnnotation(const string& arg_type_attr, DataType arg_type, const std::unordered_map& type_annotations) { - if (!arg_type_attr.empty()) { +string GetArgAnnotation(const OpDef::ArgDef& arg, const std::unordered_map& type_annotations) { + if (!arg.type_attr().empty()) { // Get the correct TypeVar if arg maps to an attr - return "_ops.Tensor[" + type_annotations.at(arg_type_attr) + "]"; + return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]"; } else { // Get the dtype of the Tensor - const string py_dtype = python_op_gen_internal::DataTypeToPython(arg_type, "_dtypes."); + const string py_dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes."); return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]"; } diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h index 75f04952d48..178e078a81b 100644 --- a/tensorflow/python/framework/python_op_gen.h +++ b/tensorflow/python/framework/python_op_gen.h @@ -54,7 +54,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len); // Get the type annotation for an arg // `arg` should be an input or output of an op // `type_annotations` should contain attr names mapped to TypeVar names -string GetArgAnnotation(const string& arg_type_attr, DataType arg_type, +string GetArgAnnotation(const OpDef::ArgDef& arg, const std::unordered_map& type_annotations); } // namespace tensorflow