Pass single structured arg to function
This commit is contained in:
parent
0089b010bf
commit
eb500b17ac
@ -443,7 +443,7 @@ std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
|
||||
for (const auto& arg : op_def_.input_arg()) {
|
||||
// TODO(rahulkamat): Add type annotations to args that accept a sequence of Tensors
|
||||
if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
|
||||
type_annotations[arg.name()] = GetArgAnnotation(arg.type_attr(), arg.type(), type_annotations);
|
||||
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
|
||||
}
|
||||
|
||||
// TODO(rahulkamat): Add type annotations to handle return types of a sequence of Tensors.
|
||||
@ -451,7 +451,7 @@ std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
|
||||
if (op_def_.output_arg_size() == 1) {
|
||||
const auto& arg = op_def_.output_arg(0);
|
||||
if (arg.number_attr().empty() && arg.type_list_attr().empty())
|
||||
type_annotations[arg.name()] = GetArgAnnotation(arg.type_attr(), arg.type(), type_annotations);
|
||||
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
|
||||
}
|
||||
|
||||
return type_annotations;
|
||||
@ -1246,13 +1246,13 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
||||
return GetPythonOpsImpl(ops, api_def_map, {});
|
||||
}
|
||||
|
||||
string GetArgAnnotation(const string& arg_type_attr, DataType arg_type, const std::unordered_map<string, string>& type_annotations) {
|
||||
if (!arg_type_attr.empty()) {
|
||||
string GetArgAnnotation(const OpDef::ArgDef& arg, const std::unordered_map<string, string>& type_annotations) {
|
||||
if (!arg.type_attr().empty()) {
|
||||
// Get the correct TypeVar if arg maps to an attr
|
||||
return "_ops.Tensor[" + type_annotations.at(arg_type_attr) + "]";
|
||||
return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
|
||||
} else {
|
||||
// Get the dtype of the Tensor
|
||||
const string py_dtype = python_op_gen_internal::DataTypeToPython(arg_type, "_dtypes.");
|
||||
const string py_dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
|
||||
return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
|
||||
}
|
||||
|
||||
|
@ -54,7 +54,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len);
|
||||
// Get the type annotation for an arg
|
||||
// `arg` should be an input or output of an op
|
||||
// `type_annotations` should contain attr names mapped to TypeVar names
|
||||
string GetArgAnnotation(const string& arg_type_attr, DataType arg_type,
|
||||
string GetArgAnnotation(const OpDef::ArgDef& arg,
|
||||
const std::unordered_map<string, string>& type_annotations);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user