Fix MatMul with transpose_a
PiperOrigin-RevId: 261840270
This commit is contained in:
parent f8e323e0b2
commit 4623b733f7
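For context: tf.matmul carries transpose_a/transpose_b attributes, and before this change the converter bailed out whenever transpose_a=True was set (see the removed "Not replacing ... transpose_a attribute" message below). A minimal NumPy sketch, illustration only and not part of the change, of the rewrite the commit implements: emulate transpose_a by swapping the last two axes of the LHS with an explicit transpose, then running a plain matmul.

    # Illustration only: emulate MatMul(lhs, rhs, transpose_a=True) with an
    # explicit transpose of the LHS followed by a plain matmul.
    import numpy as np

    lhs = np.arange(15, dtype=np.float32).reshape(5, 3)   # [5, 3]
    rhs = np.arange(15, dtype=np.float32).reshape(5, 3)   # [5, 3]

    reference = np.matmul(lhs.T, rhs)                      # what transpose_a=True computes
    emulated = np.matmul(np.transpose(lhs, [1, 0]), rhs)   # Transpose op + plain MatMul
    assert np.array_equal(reference, emulated)             # both are [3, 3]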
@@ -69,8 +69,6 @@ KNOWN_BUGS = {
     # TOCO doesn't support scalars as input.
     # Concat doesn't work with a single input tensor
     r"concat.*num_tensors=1": "67378344",
-    # Transposition in MatMul is not fully supported.
-    "fully_connected.*transpose_a=True": "67586970",
     # Softmax graphs are too complex.
     r"softmax.*dim=0": "67749831",
     # BatchToSpaceND only supports 4D tensors.
@@ -2178,6 +2176,12 @@ def make_fully_connected_tests(options):
       "transpose_a": [False],
       "transpose_b": [True],
       "constant_filter": [True, False],
+  }, {
+      "shape1": [[5, 3]],
+      "shape2": [[5, 3]],
+      "transpose_a": [True],
+      "transpose_b": [False],
+      "constant_filter": [True, False],
   }]

   def build_graph(parameters):
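The added parameter set covers exactly the case the converter used to reject. A hypothetical sketch (TF 1.x style; the real graph construction is the build_graph function just below in this file) of the kind of graph these parameters exercise: a [5, 3] x [5, 3] MatMul that only shape-checks because transpose_a=True.

    # Hypothetical sketch of the graph exercised by the new parameter set.
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    input1 = tf.placeholder(dtype=tf.float32, shape=[5, 3], name="input1")
    input2 = tf.placeholder(dtype=tf.float32, shape=[5, 3], name="input2")
    out = tf.matmul(input1, input2, transpose_a=True, transpose_b=False)
    print(out.shape)  # (3, 3): [5, 3]^T x [5, 3]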
@@ -66,25 +66,69 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
   const auto* matmul_op =
       static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());

-  // Handling transposition of the first input here isn't very simple because
-  // we need to know the actual shape in order to produce a proper
-  // TransposeOperator. However, the second input is supposed to be 2D, so we
-  // can actually handle transposition of that matrix, which happens to be more
-  // common anyway.
+  auto refresh_matmul_iterator = [&model, &matmul_it, &matmul_op]() {
+    matmul_it = std::find_if(model->operators.begin(), model->operators.end(),
+                             [matmul_op](const std::unique_ptr<Operator>& op) {
+                               return op.get() == matmul_op;
+                             });
+    DCHECK_EQ(matmul_it->get(), matmul_op);
+  };
+
+  string input_lhs = matmul_op->inputs[0];
+  string input_rhs = matmul_op->inputs[1];
+
+  // Handle `transpose_a` with best effort: If the dimension of lhs is known,
+  // insert a `Transpose` op.
   if (matmul_op->transpose_a) {
-    AddMessageF(
-        "Not replacing %s by a FullyConnected operator, because it has "
-        "the transpose_a attribute",
-        LogName(*matmul_op));
-    return ::tensorflow::Status::OK();
+    Array& lhs_array = model->GetArray(input_lhs);
+    if (!lhs_array.has_shape()) {
+      AddMessageF(
+          "Not replacing %s by a FullyConnected operator, because it has "
+          "the transpose_a attribute and LHS has no shape",
+          LogName(*matmul_op));
+      return ::tensorflow::Status::OK();
+    }
+
+    int dimensions_count = lhs_array.shape().dimensions_count();
+    if (dimensions_count < 2) {
+      return ::tensorflow::errors::InvalidArgument(
+          "Inputs of MatMul should have dimension >= 2. Got %d dimensions",
+          dimensions_count);
+    }
+
+    // Create a permutation vector to exchange the last 2 dimensions.
+    // E.g. For 4D, create [0, 1, 3, 2].
+    std::vector<int> perm;
+    perm.reserve(dimensions_count);
+    for (int i = 0; i < dimensions_count; ++i) {
+      perm.push_back(i);
+    }
+    std::swap(perm[dimensions_count - 1], perm[dimensions_count - 2]);
+
+    auto* transpose_op = new TransposeOperator;
+    transpose_op->inputs = {
+        input_lhs,
+        CreateInt32Array(
+            model, AvailableArrayName(*model, input_lhs + "/transpose/perm"),
+            perm)};
+    transpose_op->outputs = {
+        AvailableArrayName(*model, input_lhs + "/transpose")};
+    model->GetOrCreateArray(transpose_op->outputs[0]);
+    model->operators.emplace(matmul_it, transpose_op);
+    // Sanity check
+    DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_lhs));
+    input_lhs = transpose_op->outputs[0];
+
+    refresh_matmul_iterator();
   }

+  // TODO(b/138662017): The following code assumes that RHS is 2D. This isn't
+  // always true in TensorFlow.
+  //
   // Reorder the axes on the second input. TensorFlow uses row-major ordering
   // on both inputs, however this is inefficient for the FullyConnected
   // operator. We'll transpose the second input to be in column-major order now
   // and let constant propagation optimize things (if possible).
-  string input_lhs = matmul_op->inputs[0];
-  string input_rhs = matmul_op->inputs[1];
   if (!matmul_op->transpose_b) {
     // Need to transpose input_rhs, by inserting a TransposeOperator.
     // First, check if there already is a TransposeOperator transposing that
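The Transpose inserted above uses a permutation that keeps all leading (batch) axes and swaps only the last two, so the rewrite is not limited to 2-D LHS tensors. A NumPy sketch, not from the change, of that permutation and its effect:

    # Build the same permutation the transformation constructs: identity order
    # with the last two axes swapped, e.g. [0, 1, 3, 2] for a 4-D lhs.
    import numpy as np

    lhs = np.random.rand(2, 3, 5, 4).astype(np.float32)   # batched [..., 5, 4]
    rhs = np.random.rand(5, 6).astype(np.float32)

    perm = list(range(lhs.ndim))
    perm[-1], perm[-2] = perm[-2], perm[-1]                # -> [0, 1, 3, 2]

    out = np.matmul(np.transpose(lhs, perm), rhs)          # [2, 3, 4, 5] x [5, 6]
    assert out.shape == (2, 3, 4, 6)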
@@ -108,6 +152,7 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
     model->operators.emplace(matmul_it, transpose_op);
     // Sanity check
     DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
+    refresh_matmul_iterator();
   } else {
     AddMessageF(
         "While replacing %s by a FullyConnected operator, reused existing "
@@ -118,15 +163,6 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
     input_rhs = transpose_op->outputs[0];
   }

-  // Refresh iterator.
-  matmul_it = model->operators.begin();
-  for (; matmul_it != model->operators.end(); ++matmul_it) {
-    if (matmul_it->get() == matmul_op) {
-      break;
-    }
-  }
-  DCHECK_EQ(matmul_it->get(), matmul_op);
-
   // Construct the new FullyConnectedOperator.
   auto* fc_op = new FullyConnectedOperator;
   fc_op->inputs = {input_lhs, input_rhs};
@@ -181,14 +217,7 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
     }

     // We may have just invalidated matmul_it, so let's refresh it now.
-    matmul_it = model->operators.begin();
-    for (; matmul_it != model->operators.end(); ++matmul_it) {
-      if (matmul_it->get() == matmul_op) {
-        break;
-      }
-    }
-    CHECK(matmul_it != model->operators.end());
-    CHECK(matmul_it->get() == matmul_op);
+    refresh_matmul_iterator();
   } else {
     AddMessageF("Replacing %s by a FullyConnected operator",
                 LogName(*matmul_op));