diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 0e97cfc8b3f..fffd9f9de92 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -152,6 +152,12 @@ Status MlirFunctionOptimizationPass::Run( import_config.graph_as_function = true; import_config.control_outputs = *control_ret_node_names; import_config.upgrade_legacy = true; + // Disable shape inference during import as some TensorFlow op fails during + // shape inference with dynamic shaped operands. This in turn causes the + // import to fail. Shape inference during import is going to be removed and + // the shape inference pass is run early in the pass pipeline, shape inference + // during import is not necessary. + import_config.enable_shape_inference = false; auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def, import_config, &context); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/shape-attrs.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/shape-attrs.pbtxt new file mode 100644 index 00000000000..e1c5c1d7bf9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/shape-attrs.pbtxt @@ -0,0 +1,334 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false -tf-graph-as-function %s -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s + +node { + name: "args_0" + op: "_Arg" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "args_1" + op: "_Arg" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "index" + value { + i: 1 + } + } +} +node { + name: "args_2" + op: "_Arg" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "index" + value { + i: 2 + } + } +} +node { + name: "args_3" + op: "_Arg" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "index" + value { + i: 3 + } + } +} +node { + name: "args_4" + op: "_Arg" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "index" + value { + i: 4 + } + } +} +node { + name: "IteratorGetNext" + op: "IteratorGetNext" + input: "args_0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 1 + } + dim { + size: 8 + } + } + shape { + dim { + size: 2 + } + dim { + size: -1 + } + dim { + size: 16 + } + } + shape { + unknown_rank: true + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +node { + name: "IteratorGetNextSync" + op: "IteratorGetNextSync" + input: "args_0" + attr { + key: "output_shapes" + value { + list { + shape { + unknown_rank: true + } + shape { + dim { + size: 3 + } + dim { + size: 24 + } + } + shape { + dim { + size: -1 + } + dim { + size: 4 + } + dim { + size: 32 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + } + } + } +} +node { + name: "MultiDeviceIteratorGetNextFromShard" + op: "MultiDeviceIteratorGetNextFromShard" + input: "args_2" + input: "args_3" + input: "args_4" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 5 + } + dim { + size: 40 + } + } + shape { + unknown_rank: true + } + shape { + dim { + size: 6 + } + dim { + size: 48 + } + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_HALF + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} +node { + name: "InfeedDequeueTuple" + op: "InfeedDequeueTuple" + attr { + key: "shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + dim { + size: -1 + } + } + shape { + unknown_rank: true + } + shape { + dim { + size: 7 + } + dim { + size: 56 + } + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_UINT16 + type: DT_UINT32 + type: DT_UINT64 + } + } + } +} +node { + name: "InfeedDequeue_0" + op: "InfeedDequeue" + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 8 + } + dim { + size: -1 + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT8 + } + } +} +node { + name: "InfeedDequeue_1" + op: "InfeedDequeue" + attr { + key: "shape" + value { + shape { + dim { + size: 8 + } + dim { + size: 64 + } + } + } + } + attr { + key: "dtype" + value { + type: DT_UINT8 + } + } +} +node { + name: "InfeedDequeue_2" + op: "InfeedDequeue" + attr { + key: "shape" + value { + shape { + unknown_rank: true + } + } + } + attr { + key: "dtype" + value { + type: DT_BOOL + } + } +} + +# CHECK-DAG: tf.IteratorGetNext{{.+}}-> (tensor<1x8xbf16>, tensor<2x?x16xf32>, tensor<*xf64>) +# CHECK-DAG: tf.IteratorGetNextSync{{.+}}-> (tensor<*xi16>, tensor<3x24xi32>, tensor) +# CHECK-DAG: tf.MultiDeviceIteratorGetNextFromShard{{.+}}-> (tensor<5x40xf16>, tensor<*xcomplex>, tensor<6x48x?xcomplex>) +# CHECK-DAG: tf.InfeedDequeueTuple{{.+}}-> (tensor, tensor<*xui32>, tensor<7x56xui64>) +# CHECK-DAG: tf.InfeedDequeue{{.+}}-> tensor +# CHECK-DAG: tf.InfeedDequeue{{.+}}-> tensor<8x64xui8> +# CHECK-DAG: tf.InfeedDequeue{{.+}}-> tensor<*xi1> diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 64b3fc88892..4fd9498c432 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -980,6 +980,31 @@ StatusOr ImporterBase::InferOutputType(const Node& node, int idx, } } + auto type_from_array_attr = [&node, &idx, &builder]( + absl::string_view output_shape_attr, + absl::string_view element_type_attr) { + auto* output_shapes = node.attrs().Find(output_shape_attr); + auto* element_types = node.attrs().Find(element_type_attr); + const auto& output_shape = output_shapes->list().shape(idx); + const auto& element_type = element_types->list().type(idx); + return ConvertToMlirTensorType(output_shape, element_type, &builder); + }; + + if (node.type_string() == "IteratorGetNext" || + node.type_string() == "IteratorGetNextSync" || + node.type_string() == "MultiDeviceIteratorGetNextFromShard") + return type_from_array_attr("output_shapes", "output_types"); + + if (node.type_string() == "InfeedDequeueTuple") + return type_from_array_attr("shapes", "dtypes"); + + if (node.type_string() == "InfeedDequeue") { + assert(idx == 0); + const auto& output_shape = node.attrs().Find("shape")->shape(); + const auto& element_type = node.attrs().Find("dtype")->type(); + return ConvertToMlirTensorType(output_shape, element_type, &builder); + } + // Returns a simple, more conservative unranked tensor type. auto default_type = [&]() -> StatusOr { mlir::Type element_type;