Fix MatMul with transpose_a

PiperOrigin-RevId: 261840270
Yu-Cheng Ling, 2019-08-05 22:10:12 -07:00, committed by TensorFlower Gardener
parent f8e323e0b2
commit 4623b733f7
2 changed files with 64 additions and 31 deletions


@@ -69,8 +69,6 @@ KNOWN_BUGS = {
     # TOCO doesn't support scalars as input.
     # Concat doesn't work with a single input tensor
     r"concat.*num_tensors=1": "67378344",
-    # Transposition in MatMul is not fully supported.
-    "fully_connected.*transpose_a=True": "67586970",
     # Softmax graphs are too complex.
     r"softmax.*dim=0": "67749831",
     # BatchToSpaceND only supports 4D tensors.
@@ -2178,6 +2176,12 @@ def make_fully_connected_tests(options):
       "transpose_a": [False],
       "transpose_b": [True],
       "constant_filter": [True, False],
+  }, {
+      "shape1": [[5, 3]],
+      "shape2": [[5, 3]],
+      "transpose_a": [True],
+      "transpose_b": [False],
+      "constant_filter": [True, False],
   }]
 
   def build_graph(parameters):
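
The re-enabled case deserves a sanity check: with shape1 = shape2 = [5, 3] and transpose_a=True, the product is [3, 5] x [5, 3] -> [3, 3], a valid multiply. A minimal sketch of the shape arithmetic under standard MatMul semantics (MatMulShape is a hypothetical helper, not part of the test harness):

    #include <cassert>
    #include <utility>
    #include <vector>

    // Hypothetical helper: output shape of a 2-D MatMul with transpose flags,
    // following standard TensorFlow MatMul semantics.
    std::vector<int> MatMulShape(std::vector<int> lhs, std::vector<int> rhs,
                                 bool transpose_a, bool transpose_b) {
      if (transpose_a) std::swap(lhs[0], lhs[1]);
      if (transpose_b) std::swap(rhs[0], rhs[1]);
      assert(lhs[1] == rhs[0]);  // Inner dimensions must agree.
      return {lhs[0], rhs[1]};
    }

    int main() {
      // The new test case: [5, 3]^T x [5, 3] = [3, 5] x [5, 3] -> [3, 3].
      auto out = MatMulShape({5, 3}, {5, 3}, /*transpose_a=*/true,
                             /*transpose_b=*/false);
      assert(out == (std::vector<int>{3, 3}));
      return 0;
    }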


@@ -66,25 +66,69 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
   const auto* matmul_op =
       static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());
 
-  // Handling transposition of the first input here isn't very simple because
-  // we need to know the actual shape in order to produce a proper
-  // TransposeOperator. However, the second input is supposed to be 2D, so we
-  // can actually handle transposition of that matrix, which happens to be more
-  // common anyway.
+  auto refresh_matmul_iterator = [&model, &matmul_it, &matmul_op]() {
+    matmul_it = std::find_if(model->operators.begin(), model->operators.end(),
+                             [matmul_op](const std::unique_ptr<Operator>& op) {
+                               return op.get() == matmul_op;
+                             });
+    DCHECK_EQ(matmul_it->get(), matmul_op);
+  };
+
+  string input_lhs = matmul_op->inputs[0];
+  string input_rhs = matmul_op->inputs[1];
+
+  // Handle `transpose_a` with best effort: If the dimension of lhs is known,
+  // insert a `Transpose` op.
   if (matmul_op->transpose_a) {
-    AddMessageF(
-        "Not replacing %s by a FullyConnected operator, because it has "
-        "the transpose_a attribute",
-        LogName(*matmul_op));
-    return ::tensorflow::Status::OK();
+    Array& lhs_array = model->GetArray(input_lhs);
+    if (!lhs_array.has_shape()) {
+      AddMessageF(
+          "Not replacing %s by a FullyConnected operator, because it has "
+          "the transpose_a attribute and LHS has no shape",
+          LogName(*matmul_op));
+      return ::tensorflow::Status::OK();
+    }
+
+    int dimensions_count = lhs_array.shape().dimensions_count();
+    if (dimensions_count < 2) {
+      return ::tensorflow::errors::InvalidArgument(
+          "Inputs of MatMul should have dimension >= 2. Got %d dimensions",
+          dimensions_count);
+    }
+
+    // Create a permutation vector to exchange the last 2 dimensions.
+    // E.g. For 4D, create [0, 1, 3, 2].
+    std::vector<int> perm;
+    perm.reserve(dimensions_count);
+    for (int i = 0; i < dimensions_count; ++i) {
+      perm.push_back(i);
+    }
+    std::swap(perm[dimensions_count - 1], perm[dimensions_count - 2]);
+
+    auto* transpose_op = new TransposeOperator;
+    transpose_op->inputs = {
+        input_lhs,
+        CreateInt32Array(
+            model, AvailableArrayName(*model, input_lhs + "/transpose/perm"),
+            perm)};
+    transpose_op->outputs = {
+        AvailableArrayName(*model, input_lhs + "/transpose")};
+    model->GetOrCreateArray(transpose_op->outputs[0]);
+    model->operators.emplace(matmul_it, transpose_op);
+    // Sanity check
+    DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_lhs));
+
+    input_lhs = transpose_op->outputs[0];
+    refresh_matmul_iterator();
   }
 
+  // TODO(b/138662017): The following code assumes that RHS is 2D. This isn't
+  // always true in TensorFlow.
+  //
   // Reorder the axes on the second input. TensorFlow uses row-major ordering
   // on both inputs, however this is inefficient for the FullyConnected
   // operator. We'll transpose the second input to be in column-major order now
   // and let constant propagation optimize things (if possible).
-  string input_lhs = matmul_op->inputs[0];
-  string input_rhs = matmul_op->inputs[1];
   if (!matmul_op->transpose_b) {
     // Need to transpose input_rhs, by inserting a TransposeOperator.
     // First, check if there already is a TransposeOperator transposing that
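
The permutation built above exchanges only the two innermost axes, so leading batch dimensions of a higher-rank LHS pass through untouched. A standalone sketch of the same logic (LastTwoAxesPerm and Permute are illustrative names, not toco APIs):

    #include <cassert>
    #include <cstddef>
    #include <utility>
    #include <vector>

    // Build the permutation [0, 1, ..., n-3, n-1, n-2] that swaps the last two
    // axes, then apply it to a shape. Mirrors the logic in the hunk above.
    std::vector<int> LastTwoAxesPerm(int dimensions_count) {
      std::vector<int> perm;
      perm.reserve(dimensions_count);
      for (int i = 0; i < dimensions_count; ++i) perm.push_back(i);
      std::swap(perm[dimensions_count - 1], perm[dimensions_count - 2]);
      return perm;
    }

    std::vector<int> Permute(const std::vector<int>& shape,
                             const std::vector<int>& perm) {
      std::vector<int> out(shape.size());
      for (std::size_t i = 0; i < shape.size(); ++i) out[i] = shape[perm[i]];
      return out;
    }

    int main() {
      // 4-D example from the comment: the permutation is [0, 1, 3, 2].
      assert(LastTwoAxesPerm(4) == (std::vector<int>{0, 1, 3, 2}));
      // A [2, 7, 5, 3] LHS becomes [2, 7, 3, 5]; batch dims stay in place.
      assert(Permute({2, 7, 5, 3}, LastTwoAxesPerm(4)) ==
             (std::vector<int>{2, 7, 3, 5}));
      return 0;
    }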
@@ -108,6 +152,7 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
       model->operators.emplace(matmul_it, transpose_op);
       // Sanity check
       DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
+      refresh_matmul_iterator();
     } else {
       AddMessageF(
           "While replacing %s by a FullyConnected operator, reused existing "
@@ -118,15 +163,6 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
     input_rhs = transpose_op->outputs[0];
   }
-
-  // Refresh iterator.
-  matmul_it = model->operators.begin();
-  for (; matmul_it != model->operators.end(); ++matmul_it) {
-    if (matmul_it->get() == matmul_op) {
-      break;
-    }
-  }
-  DCHECK_EQ(matmul_it->get(), matmul_op);
 
   // Construct the new FullyConnectedOperator.
   auto* fc_op = new FullyConnectedOperator;
   fc_op->inputs = {input_lhs, input_rhs};
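
The deleted loop and the refresh_matmul_iterator lambda that replaces it address the same hazard: model->operators is a vector of unique_ptr, so the emplace calls above invalidate matmul_it, and the MatMul must be re-found by raw-pointer identity before it is used again. A minimal sketch of the invalidation-and-refresh pattern, with simplified types:

    #include <algorithm>
    #include <cassert>
    #include <memory>
    #include <vector>

    struct Op { int id; };

    int main() {
      std::vector<std::unique_ptr<Op>> ops;
      ops.push_back(std::make_unique<Op>(Op{0}));
      ops.push_back(std::make_unique<Op>(Op{1}));

      Op* target = ops[1].get();  // The operator we still care about.
      auto it = ops.begin() + 1;

      // Inserting before `it` invalidates it (and any later iterator), so the
      // element must be located again by identity, which is exactly what
      // refresh_matmul_iterator does with std::find_if.
      ops.emplace(it, std::make_unique<Op>(Op{2}));
      it = std::find_if(ops.begin(), ops.end(),
                        [target](const std::unique_ptr<Op>& op) {
                          return op.get() == target;
                        });
      assert(it != ops.end() && (*it)->id == 1);
      return 0;
    }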
@@ -181,14 +217,7 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
     }
 
     // We may have just invalidated matmul_it, so let's refresh it now.
-    matmul_it = model->operators.begin();
-    for (; matmul_it != model->operators.end(); ++matmul_it) {
-      if (matmul_it->get() == matmul_op) {
-        break;
-      }
-    }
-    CHECK(matmul_it != model->operators.end());
-    CHECK(matmul_it->get() == matmul_op);
+    refresh_matmul_iterator();
   } else {
     AddMessageF("Replacing %s by a FullyConnected operator",
                 LogName(*matmul_op));
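
Taken together, the transformation now covers every flag combination. The sketch below is my summary of the control flow in this diff, not code from it: transpose_a inserts a Transpose on the LHS when its shape is known (otherwise the MatMul is left alone, as before), and transpose_b=False inserts a Transpose on the RHS so the weights reach FullyConnected in column-major order.

    #include <cassert>
    #include <string>
    #include <vector>

    // Illustrative only: which ops the rewrite emits per flag combination.
    std::vector<std::string> RewritePlan(bool transpose_a, bool transpose_b,
                                         bool lhs_shape_known) {
      std::vector<std::string> ops;
      if (transpose_a) {
        if (!lhs_shape_known) return {"MatMul"};  // Left unresolved, as before.
        ops.push_back("Transpose(lhs)");
      }
      if (!transpose_b) ops.push_back("Transpose(rhs)");  // To column-major.
      ops.push_back("FullyConnected");
      return ops;
    }

    int main() {
      // The previously rejected case now lowers to two Transposes plus FC.
      assert(RewritePlan(true, false, true) ==
             (std::vector<std::string>{"Transpose(lhs)", "Transpose(rhs)",
                                       "FullyConnected"}));
      return 0;
    }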