Allow empty value for "_input_shapes" attr of functions in graphdef library

PiperOrigin-RevId: 360279759
Change-Id: I14edacb9aad6d4b995ef231520acad2b752b029d
This commit is contained in:
Hongmin Fan 2021-03-01 14:35:43 -08:00 committed by TensorFlower Gardener
parent 21745f6a97
commit 6aca90258a
2 changed files with 48 additions and 3 deletions

View File

@ -0,0 +1,39 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s
node {
name: "input"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
}
node {
name: "func0"
op: "func_name"
input: "input"
}
library {
function {
signature {
name: "func_name"
input_arg {
name: "arg0"
type: DT_BOOL
}
}
ret {
key: "retval0"
value: "arg0"
}
attr: {
key: "_input_shapes"
value: {
}
}
}
}

View File

@ -1292,17 +1292,23 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
if (name_and_value.first == "_input_shapes") {
auto& list = name_and_value.second.list();
auto& signature = func_def->signature();
if (list.shape_size() != signature.input_arg_size()) {
// Some models have "_input_shapes" attribute, but with its value empty
if (list.shape_size() > 0 &&
list.shape_size() != signature.input_arg_size()) {
return errors::FailedPrecondition(
"Number of input arguments must be equal to the length of "
"_input_shapes attribute in function '",
StringRefToView(func_name), "'.");
}
for (int i = 0; i < list.shape_size(); i++) {
for (int i = 0; i < signature.input_arg_size(); i++) {
auto& input_arg = signature.input_arg(i);
auto& array_info = specs.inputs[input_arg.name()];
array_info.imported_dtype = input_arg.type();
array_info.shape = list.shape(i);
// set to unranked for empty "_input_shapes" attribute
if (list.shape_size() > 0)
array_info.shape = list.shape(i);
else
array_info.shape.set_unknown_rank(true);
}
}
}