diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py
index 643a36a906e..53e286f0765 100644
--- a/tensorflow/lite/testing/generate_examples_lib.py
+++ b/tensorflow/lite/testing/generate_examples_lib.py
@@ -69,8 +69,6 @@ KNOWN_BUGS = {
     # TOCO doesn't support scalars as input.
     # Concat doesn't work with a single input tensor
     r"concat.*num_tensors=1": "67378344",
-    # Transposition in MatMul is not fully supported.
-    "fully_connected.*transpose_a=True": "67586970",
     # Softmax graphs are too complex.
     r"softmax.*dim=0": "67749831",
     # BatchToSpaceND only supports 4D tensors.
@@ -2178,6 +2176,12 @@ def make_fully_connected_tests(options):
       "transpose_a": [False],
       "transpose_b": [True],
       "constant_filter": [True, False],
+  }, {
+      "shape1": [[5, 3]],
+      "shape2": [[5, 3]],
+      "transpose_a": [True],
+      "transpose_b": [False],
+      "constant_filter": [True, False],
   }]
 
   def build_graph(parameters):
diff --git a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index 1aa30bcf1f3..ac95d609e91 100644
--- a/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -66,25 +66,69 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
   const auto* matmul_op =
       static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());
 
-  // Handling transposition of the first input here isn't very simple because
-  // we need to know the actual shape in order to produce a proper
-  // TransposeOperator. However, the second input is supposed to be 2D, so we
-  // can actually handle transposition of that matrix, which happens to be more
-  // common anyway.
+  auto refresh_matmul_iterator = [&model, &matmul_it, &matmul_op]() {
+    matmul_it = std::find_if(model->operators.begin(), model->operators.end(),
+                             [matmul_op](const std::unique_ptr<Operator>& op) {
+                               return op.get() == matmul_op;
+                             });
+    DCHECK_EQ(matmul_it->get(), matmul_op);
+  };
+
+  string input_lhs = matmul_op->inputs[0];
+  string input_rhs = matmul_op->inputs[1];
+
+  // Handle `transpose_a` with best effort: If the dimension of lhs is known,
+  // insert a `Transpose` op.
   if (matmul_op->transpose_a) {
-    AddMessageF(
-        "Not replacing %s by a FullyConnected operator, because it has "
-        "the transpose_a attribute",
-        LogName(*matmul_op));
-    return ::tensorflow::Status::OK();
+    Array& lhs_array = model->GetArray(input_lhs);
+    if (!lhs_array.has_shape()) {
+      AddMessageF(
+          "Not replacing %s by a FullyConnected operator, because it has "
+          "the transpose_a attribute and LHS has no shape",
+          LogName(*matmul_op));
+      return ::tensorflow::Status::OK();
+    }
+
+    int dimensions_count = lhs_array.shape().dimensions_count();
+    if (dimensions_count < 2) {
+      return ::tensorflow::errors::InvalidArgument(
+          "Inputs of MatMul should have dimension >= 2. Got %d dimensions",
+          dimensions_count);
+    }
+
+    // Create a permutation vector to exchange the last 2 dimensions.
+    // E.g. For 4D, create [0, 1, 3, 2].
+    std::vector<int> perm;
+    perm.reserve(dimensions_count);
+    for (int i = 0; i < dimensions_count; ++i) {
+      perm.push_back(i);
+    }
+    std::swap(perm[dimensions_count - 1], perm[dimensions_count - 2]);
+
+    auto* transpose_op = new TransposeOperator;
+    transpose_op->inputs = {
+        input_lhs,
+        CreateInt32Array(
+            model, AvailableArrayName(*model, input_lhs + "/transpose/perm"),
+            perm)};
+    transpose_op->outputs = {
+        AvailableArrayName(*model, input_lhs + "/transpose")};
+    model->GetOrCreateArray(transpose_op->outputs[0]);
+    model->operators.emplace(matmul_it, transpose_op);
+    // Sanity check
+    DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_lhs));
+    input_lhs = transpose_op->outputs[0];
+
+    refresh_matmul_iterator();
   }
 
+  // TODO(b/138662017): The following code assumes that RHS is 2D. This isn't
+  // always true in TensorFlow.
+  //
   // Reorder the axes on the second input. TensorFlow uses row-major ordering
   // on both inputs, however this is inefficient for the FullyConnected
   // operator. We'll transpose the second input to be in column-major order now
   // and let constant propagation optimize things (if possible).
-  string input_lhs = matmul_op->inputs[0];
-  string input_rhs = matmul_op->inputs[1];
   if (!matmul_op->transpose_b) {
     // Need to transpose input_rhs, by inserting a TransposeOperator.
     // First, check if there already is a TransposeOperator transposing that
@@ -108,6 +152,7 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
       model->operators.emplace(matmul_it, transpose_op);
       // Sanity check
       DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
+      refresh_matmul_iterator();
     } else {
       AddMessageF(
           "While replacing %s by a FullyConnected operator, reused existing "
@@ -118,15 +163,6 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
     input_rhs = transpose_op->outputs[0];
   }
 
-  // Refresh iterator.
-  matmul_it = model->operators.begin();
-  for (; matmul_it != model->operators.end(); ++matmul_it) {
-    if (matmul_it->get() == matmul_op) {
-      break;
-    }
-  }
-  DCHECK_EQ(matmul_it->get(), matmul_op);
-
   // Construct the new FullyConnectedOperator.
   auto* fc_op = new FullyConnectedOperator;
   fc_op->inputs = {input_lhs, input_rhs};
@@ -181,14 +217,7 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
     }
 
     // We may have just invalidated matmul_it, so let's refresh it now.
-    matmul_it = model->operators.begin();
-    for (; matmul_it != model->operators.end(); ++matmul_it) {
-      if (matmul_it->get() == matmul_op) {
-        break;
-      }
-    }
-    CHECK(matmul_it != model->operators.end());
-    CHECK(matmul_it->get() == matmul_op);
+    refresh_matmul_iterator();
   } else {
     AddMessageF("Replacing %s by a FullyConnected operator",
                 LogName(*matmul_op));
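Note: the rewrite above relies on the identity that MatMul with transpose_a=True equals a plain MatMul whose LHS has had its last two dimensions swapped by a Transpose with perm = [0, ..., n-1, n-2]. The following is a minimal sketch checking that identity, not part of the patch; it assumes TensorFlow 2.x eager mode and uses NumPy only to build test data.

import numpy as np
import tensorflow as tf

# Build a small LHS/RHS pair matching the new 2-D test case (shape1=[5, 3]).
lhs = np.random.rand(5, 3).astype(np.float32)
rhs = np.random.rand(5, 4).astype(np.float32)

# Permutation that swaps the last two dimensions, e.g. [1, 0] for 2-D,
# [0, 1, 3, 2] for 4-D, mirroring the perm vector built in the C++ pass.
perm = list(range(lhs.ndim))
perm[-2], perm[-1] = perm[-1], perm[-2]

expected = tf.matmul(lhs, rhs, transpose_a=True)
rewritten = tf.matmul(tf.transpose(lhs, perm=perm), rhs)

np.testing.assert_allclose(expected.numpy(), rewritten.numpy(), rtol=1e-5)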