Fix MatMul with transpose_a
PiperOrigin-RevId: 261840270
This commit is contained in:
parent
f8e323e0b2
commit
4623b733f7
@ -69,8 +69,6 @@ KNOWN_BUGS = {
|
||||
# TOCO doesn't support scalars as input.
|
||||
# Concat doesn't work with a single input tensor
|
||||
r"concat.*num_tensors=1": "67378344",
|
||||
# Transposition in MatMul is not fully supported.
|
||||
"fully_connected.*transpose_a=True": "67586970",
|
||||
# Softmax graphs are too complex.
|
||||
r"softmax.*dim=0": "67749831",
|
||||
# BatchToSpaceND only supports 4D tensors.
|
||||
@ -2178,6 +2176,12 @@ def make_fully_connected_tests(options):
|
||||
"transpose_a": [False],
|
||||
"transpose_b": [True],
|
||||
"constant_filter": [True, False],
|
||||
}, {
|
||||
"shape1": [[5, 3]],
|
||||
"shape2": [[5, 3]],
|
||||
"transpose_a": [True],
|
||||
"transpose_b": [False],
|
||||
"constant_filter": [True, False],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
|
||||
@ -66,25 +66,69 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
|
||||
const auto* matmul_op =
|
||||
static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());
|
||||
|
||||
// Handling transposition of the first input here isn't very simple because
|
||||
// we need to know the actual shape in order to produce a proper
|
||||
// TransposeOperator. However, the second input is supposed to be 2D, so we
|
||||
// can actually handle transposition of that matrix, which happens to be more
|
||||
// common anyway.
|
||||
auto refresh_matmul_iterator = [&model, &matmul_it, &matmul_op]() {
|
||||
matmul_it = std::find_if(model->operators.begin(), model->operators.end(),
|
||||
[matmul_op](const std::unique_ptr<Operator>& op) {
|
||||
return op.get() == matmul_op;
|
||||
});
|
||||
DCHECK_EQ(matmul_it->get(), matmul_op);
|
||||
};
|
||||
|
||||
string input_lhs = matmul_op->inputs[0];
|
||||
string input_rhs = matmul_op->inputs[1];
|
||||
|
||||
// Handle `transpose_a` with best effort: If the dimension of lhs is known,
|
||||
// insert a `Transpose` op.
|
||||
if (matmul_op->transpose_a) {
|
||||
AddMessageF(
|
||||
"Not replacing %s by a FullyConnected operator, because it has "
|
||||
"the transpose_a attribute",
|
||||
LogName(*matmul_op));
|
||||
return ::tensorflow::Status::OK();
|
||||
Array& lhs_array = model->GetArray(input_lhs);
|
||||
if (!lhs_array.has_shape()) {
|
||||
AddMessageF(
|
||||
"Not replacing %s by a FullyConnected operator, because it has "
|
||||
"the transpose_a attribute and LHS has no shape",
|
||||
LogName(*matmul_op));
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
int dimensions_count = lhs_array.shape().dimensions_count();
|
||||
if (dimensions_count < 2) {
|
||||
return ::tensorflow::errors::InvalidArgument(
|
||||
"Inputs of MatMul should have dimension >= 2. Got %d dimensions",
|
||||
dimensions_count);
|
||||
}
|
||||
|
||||
// Create a permutation vector to exchange the last 2 dimensions.
|
||||
// E.g. For 4D, create [0, 1, 3, 2].
|
||||
std::vector<int> perm;
|
||||
perm.reserve(dimensions_count);
|
||||
for (int i = 0; i < dimensions_count; ++i) {
|
||||
perm.push_back(i);
|
||||
}
|
||||
std::swap(perm[dimensions_count - 1], perm[dimensions_count - 2]);
|
||||
|
||||
auto* transpose_op = new TransposeOperator;
|
||||
transpose_op->inputs = {
|
||||
input_lhs,
|
||||
CreateInt32Array(
|
||||
model, AvailableArrayName(*model, input_lhs + "/transpose/perm"),
|
||||
perm)};
|
||||
transpose_op->outputs = {
|
||||
AvailableArrayName(*model, input_lhs + "/transpose")};
|
||||
model->GetOrCreateArray(transpose_op->outputs[0]);
|
||||
model->operators.emplace(matmul_it, transpose_op);
|
||||
// Sanity check
|
||||
DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_lhs));
|
||||
input_lhs = transpose_op->outputs[0];
|
||||
|
||||
refresh_matmul_iterator();
|
||||
}
|
||||
|
||||
// TODO(b/138662017): The following code assumes that RHS is 2D. This isn't
|
||||
// always true in TensorFlow.
|
||||
//
|
||||
// Reorder the axes on the second input. TensorFlow uses row-major ordering
|
||||
// on both inputs, however this is inefficient for the FullyConnected
|
||||
// operator. We'll transpose the second input to be in column-major order now
|
||||
// and let constant propagation optimize things (if possible).
|
||||
string input_lhs = matmul_op->inputs[0];
|
||||
string input_rhs = matmul_op->inputs[1];
|
||||
if (!matmul_op->transpose_b) {
|
||||
// Need to transpose input_rhs, by inserting a TransposeOperator.
|
||||
// First, check if there already is a TransposeOperator transposing that
|
||||
@ -108,6 +152,7 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
|
||||
model->operators.emplace(matmul_it, transpose_op);
|
||||
// Sanity check
|
||||
DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
|
||||
refresh_matmul_iterator();
|
||||
} else {
|
||||
AddMessageF(
|
||||
"While replacing %s by a FullyConnected operator, reused existing "
|
||||
@ -118,15 +163,6 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
|
||||
input_rhs = transpose_op->outputs[0];
|
||||
}
|
||||
|
||||
// Refresh iterator.
|
||||
matmul_it = model->operators.begin();
|
||||
for (; matmul_it != model->operators.end(); ++matmul_it) {
|
||||
if (matmul_it->get() == matmul_op) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
DCHECK_EQ(matmul_it->get(), matmul_op);
|
||||
|
||||
// Construct the new FullyConnectedOperator.
|
||||
auto* fc_op = new FullyConnectedOperator;
|
||||
fc_op->inputs = {input_lhs, input_rhs};
|
||||
@ -181,14 +217,7 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
|
||||
}
|
||||
|
||||
// We may have just invalidated matmul_it, so let's refresh it now.
|
||||
matmul_it = model->operators.begin();
|
||||
for (; matmul_it != model->operators.end(); ++matmul_it) {
|
||||
if (matmul_it->get() == matmul_op) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
CHECK(matmul_it != model->operators.end());
|
||||
CHECK(matmul_it->get() == matmul_op);
|
||||
refresh_matmul_iterator();
|
||||
} else {
|
||||
AddMessageF("Replacing %s by a FullyConnected operator",
|
||||
LogName(*matmul_op));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user