Allow empty value for "_input_shapes" attr of functions in graphdef library
PiperOrigin-RevId: 360279759 Change-Id: I14edacb9aad6d4b995ef231520acad2b752b029d
This commit is contained in:
parent
21745f6a97
commit
6aca90258a
@ -0,0 +1,39 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node {
|
||||
name: "func0"
|
||||
op: "func_name"
|
||||
input: "input"
|
||||
}
|
||||
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "func_name"
|
||||
input_arg {
|
||||
name: "arg0"
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "retval0"
|
||||
value: "arg0"
|
||||
}
|
||||
attr: {
|
||||
key: "_input_shapes"
|
||||
value: {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1292,17 +1292,23 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
|
||||
if (name_and_value.first == "_input_shapes") {
|
||||
auto& list = name_and_value.second.list();
|
||||
auto& signature = func_def->signature();
|
||||
if (list.shape_size() != signature.input_arg_size()) {
|
||||
// Some models have "_input_shapes" attribute, but with its value empty
|
||||
if (list.shape_size() > 0 &&
|
||||
list.shape_size() != signature.input_arg_size()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Number of input arguments must be equal to the length of "
|
||||
"_input_shapes attribute in function '",
|
||||
StringRefToView(func_name), "'.");
|
||||
}
|
||||
for (int i = 0; i < list.shape_size(); i++) {
|
||||
for (int i = 0; i < signature.input_arg_size(); i++) {
|
||||
auto& input_arg = signature.input_arg(i);
|
||||
auto& array_info = specs.inputs[input_arg.name()];
|
||||
array_info.imported_dtype = input_arg.type();
|
||||
array_info.shape = list.shape(i);
|
||||
// set to unranked for empty "_input_shapes" attribute
|
||||
if (list.shape_size() > 0)
|
||||
array_info.shape = list.shape(i);
|
||||
else
|
||||
array_info.shape.set_unknown_rank(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user