diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-input-shapes.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-input-shapes.pbtxt new file mode 100644 index 00000000000..61ccf82af77 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-input-shapes.pbtxt @@ -0,0 +1,39 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s + +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } +} + +node { + name: "func0" + op: "func_name" + input: "input" +} + +library { + function { + signature { + name: "func_name" + input_arg { + name: "arg0" + type: DT_BOOL + } + } + ret { + key: "retval0" + value: "arg0" + } + attr: { + key: "_input_shapes" + value: { + } + } + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index d41dca5e8e9..ffe3e1e6ee8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -1292,17 +1292,23 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { if (name_and_value.first == "_input_shapes") { auto& list = name_and_value.second.list(); auto& signature = func_def->signature(); - if (list.shape_size() != signature.input_arg_size()) { + // Some models have "_input_shapes" attribute, but with its value empty + if (list.shape_size() > 0 && + list.shape_size() != signature.input_arg_size()) { return errors::FailedPrecondition( "Number of input arguments must be equal to the length of " "_input_shapes attribute in function '", StringRefToView(func_name), "'."); } - for (int i = 0; i < list.shape_size(); i++) { + for (int i = 0; i < signature.input_arg_size(); i++) { auto& input_arg = signature.input_arg(i); auto& array_info = specs.inputs[input_arg.name()]; array_info.imported_dtype = input_arg.type(); - array_info.shape = list.shape(i); + // set to unranked for empty "_input_shapes" attribute + if (list.shape_size() > 0) + array_info.shape = list.shape(i); + else + array_info.shape.set_unknown_rank(true); } } }