Pass single structured arg to function

This commit is contained in:
rahul-kamat 2020-06-24 21:23:32 +00:00
parent 0089b010bf
commit eb500b17ac
2 changed files with 7 additions and 7 deletions

View File

@ -443,7 +443,7 @@ std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
for (const auto& arg : op_def_.input_arg()) {
// TODO(rahulkamat): Add type annotations to args that accept a sequence of Tensors
if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
type_annotations[arg.name()] = GetArgAnnotation(arg.type_attr(), arg.type(), type_annotations);
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
}
// TODO(rahulkamat): Add type annotations to handle return types of a sequence of Tensors.
@ -451,7 +451,7 @@ std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
if (op_def_.output_arg_size() == 1) {
const auto& arg = op_def_.output_arg(0);
if (arg.number_attr().empty() && arg.type_list_attr().empty())
type_annotations[arg.name()] = GetArgAnnotation(arg.type_attr(), arg.type(), type_annotations);
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
}
return type_annotations;
@ -1246,13 +1246,13 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
return GetPythonOpsImpl(ops, api_def_map, {});
}
string GetArgAnnotation(const string& arg_type_attr, DataType arg_type, const std::unordered_map<string, string>& type_annotations) {
if (!arg_type_attr.empty()) {
string GetArgAnnotation(const OpDef::ArgDef& arg, const std::unordered_map<string, string>& type_annotations) {
if (!arg.type_attr().empty()) {
// Get the correct TypeVar if arg maps to an attr
return "_ops.Tensor[" + type_annotations.at(arg_type_attr) + "]";
return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
} else {
// Get the dtype of the Tensor
const string py_dtype = python_op_gen_internal::DataTypeToPython(arg_type, "_dtypes.");
const string py_dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
}

View File

@ -54,7 +54,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len);
// Get the type annotation for an arg
// `arg` should be an input or output of an op
// `type_annotations` should contain attr names mapped to TypeVar names
string GetArgAnnotation(const string& arg_type_attr, DataType arg_type,
string GetArgAnnotation(const OpDef::ArgDef& arg,
const std::unordered_map<string, string>& type_annotations);
} // namespace tensorflow