diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-input-shapes.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-input-shapes.pbtxt new file mode 100644 index 00000000000..fc27e82d20e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-input-shapes.pbtxt @@ -0,0 +1,110 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s + +# Verify that the _input_shapes attribute of the FunctionDef is respected. +# This also checks that the output type is correctly inferred based on +# that. +#CHECK: func @identity_function0(%arg0: tensor) -> tensor + +node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + experimental_debug_info { + } +} +node { + name: "Placeholder_1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + experimental_debug_info { + } +} +node { + name: "If" + op: "If" + input: "Placeholder" + input: "Placeholder_1" + attr { + key: "Tcond" + value { + type: DT_BOOL + } + } + attr { + key: "Tin" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "else_branch" + value { + func { + name: "identity_function" + } + } + } + attr { + key: "then_branch" + value { + func { + name: "identity_function" + } + } + } + experimental_debug_info { + } +} +library { + function { + signature { + name: "identity_function" + input_arg { + name: "identity_input" + type: DT_INT32 + } + output_arg { + name: "identity_output" + type: DT_INT32 + } + } + ret { + key: "identity_output" + value: "identity_input" + } + attr { + key: "_input_shapes" + value { + list { + shape { + } + } + } + } + } +} +versions { + producer: 29 + min_consumer: 12 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 34cdc609164..17cc4d3282a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -476,7 +476,8 @@ Status ImporterBase::AddNodesToShapeRefiner() { auto it = specs_.inputs.find(node->name()); if (it != specs_.inputs.end()) { auto node_name = node->op_def().name(); - if (node_name != "Placeholder" && node_name != "LegacyFedInput") { + if (node_name != "Placeholder" && node_name != "LegacyFedInput" && + node_name != "_Arg") { // We do not handle the case where the input node has multple outputs if (node->num_outputs() > 1) { return errors::FailedPrecondition(absl::StrCat( @@ -845,10 +846,28 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr)); } - // Converts the graph to a MLIR function and adds it to the module. Uses the - // default node spec without any inputs or outputs as the function graph has - // special '_Arg' and '_Retval' ops for argument and return values. + // Converts the graph to a MLIR function and adds it to the module. + // We populate the NodeSpec so that all the _Arg ops get their shape + // added correctly. NodeSpecs specs; + for (const auto& name_and_value : func_def->attr()) { + if (name_and_value.first == "_input_shapes") { + auto& list = name_and_value.second.list(); + auto& signature = func_def->signature(); + for (int i = 0; i < list.shape_size(); i++) { + auto& input_arg = signature.input_arg(i); + auto& array_info = specs.inputs[input_arg.name()]; + array_info.imported_dtype = input_arg.type(); + array_info.shape = list.shape(i); + // TODO(b/140464702): These fields should not be exposed here. + // Seems like a layering violation. Initialize them anyway. + array_info.final_dtype = input_arg.type(); + array_info.min_value = 0.0; + array_info.max_value = 0.0; + } + } + } + ImporterBase child_importer(graph_flib_, debug_info_, specs, module_, tf_name_to_mlir_name_); TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph)); @@ -1399,16 +1418,22 @@ StatusOr ImporterBase::InferLibFunctionType( const FunctionBody& fbody) { mlir::Builder builder(context_); + // The FunctionBody contains a graph with a single-output _Arg node for each + // function argument and a single-input _Retval node for each function return + // value. + // + // We already populated the ShapeRefiner with all the information about the + // shapes of these graph edges, so we just query it to build the corresponding + // MLIR function type signature. + llvm::SmallVector arg_types; arg_types.reserve(fbody.arg_types.size()); - for (auto dataType : fbody.arg_types) { - mlir::Type element_type; - TF_RETURN_IF_ERROR( - ::tensorflow::ConvertDataType(dataType, builder, &element_type)); - // TODO(hinsu): Derive shape of function arguments based on shapes available - // at call sites of this function. That way it is possible to have a - // partially known shape in some cases instead of unranked tensor types. - arg_types.push_back(builder.getTensorType(element_type)); + for (auto arg : fbody.arg_nodes) { + // Find node in the graph using the node id instead of using `arg` directly + // because the graph has been cloned. + auto* node = graph_->FindNodeId(arg->id()); + TF_ASSIGN_OR_RETURN(auto type, InferOutputType(*node, /*idx=*/0, builder)); + arg_types.push_back(type); } llvm::SmallVector ret_types; @@ -1417,9 +1442,6 @@ StatusOr ImporterBase::InferLibFunctionType( // Find node in the graph using the node id instead of using `ret` directly // because the graph has been cloned. auto* node = graph_->FindNodeId(ret->id()); - - // Return type of the function is type of the only input of the respective - // return node in the function. TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder)); ret_types.push_back(type); }