Merge pull request #41790 from lgeiger:fix-matmul-fusion
PiperOrigin-RevId: 341026586
Change-Id: Idb578a4cf48c82abaad2f811612c4df4c2f2752c
This commit is contained in: parent 5e38165d02, commit 7a6fa7dbc0
@@ -286,7 +286,7 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
   return %1 : tensor<4x2xf32>
 
-// CHECK:  %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK:  %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
 // CHECK:  %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
 // CHECK:  %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
 // CHECK:  return %[[RES]] : tensor<4x2xf32>
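Note: the updated CHECK line reflects the corrected fusion direction. The multiplier scales the output channels of the TFLite fully_connected filter, which is stored as [output_channels, input_depth], so its rows must be scaled rather than its columns. A minimal numpy sketch, assuming (as the old and new expected constants suggest) that the test's filter is [[1, 2], [3, 4]] and the multiplier is [1, 2]; these values are inferred, not shown in this hunk:

import numpy as np

filter_ = np.array([[1., 2.], [3., 4.]])  # assumed original FC filter
m = np.array([1., 2.])                    # assumed multiplier

# Old (incorrect) fusion scaled the columns of the filter:
print(filter_ * m)                # [[1. 4.] [3. 8.]] -- the old expected constant
# Fixed fusion reshapes the multiplier to [2, 1] and scales the rows
# (output channels), matching the new expected constant:
print(filter_ * m.reshape(2, 1))  # [[1. 2.] [6. 8.]]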
@@ -384,7 +384,7 @@ func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> te
   return %1 : tensor<4x2xf32>
 
-// CHECK:  %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK:  %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
 // CHECK:  %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
 // CHECK:  return %[[RES]] : tensor<4x2xf32>
 }
@@ -438,26 +438,28 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
       return failure();
     if (fc_op.fused_activation_function() != "NONE") return failure();
 
-    // Broadcast the constant operand of Mul if it isn't compatible to the
-    // filter input. We only support broadcasting the operand along the depth
-    // dimension, when the operand's depth is 1.
-    Value new_const_val = constant_val;
-    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
-      auto original_shape = cst.getType().getShape();
-      llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
-                                                     original_shape.end());
-      normalized_shape.push_back(1);
-      auto new_cst = cst.reshape(RankedTensorType::get(
-          normalized_shape, cst.getType().getElementType()));
-      Type new_type = new_cst.getType();
-      if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
-        return failure();
-      }
-      auto new_op =
-          rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
-      new_const_val = new_op.getResult();
+    // Only fuse multiplier if all dimensions other than the depth dimension
+    // are equal to 1 since otherwise
+    // `matmul(x, filter) * cst != matmul(x, filter * cst)`
+    // even if `filter` and `cst` are broadcastable.
+    auto shape = cst.getType().getShape();
+    if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();
+
+    int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
+    // Expand and transpose the multiplier since weights are using the
+    // OHWI data format in TFLite.
+    int64_t normalized_shape[2] = {element_size, 1};
+    auto new_cst = cst.reshape(RankedTensorType::get(
+        normalized_shape, cst.getType().getElementType()));
+    Type new_type = new_cst.getType();
+    if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
+      return failure();
     }
+
+    auto new_op =
+        rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
+    Value new_const_val = new_op.getResult();
 
     // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands,
     // TF::MulOp is used to fold the constant.
     // TODO(b/139192933): switch to the TFL constant folding
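Note: the new guard encodes the algebraic requirement spelled out in the comment. A multiplier can only be folded into the weights when every dimension except the last (depth) one is 1; otherwise it scales rows of the matmul result (per example) rather than output channels, and folding it into the filter changes the answer. A small numpy sketch of both cases, using TF's [in, out] weight layout and illustrative values only:

import numpy as np

x = np.array([[1., 2.], [3., 4.]])
w = np.array([[5., 6.], [7., 8.]])

# Per-output-channel multiplier (shape [1, 2]): fusion is valid.
m_ok = np.array([[9., 10.]])
assert np.allclose((x @ w) * m_ok, x @ (w * m_ok))

# Multiplier with a non-degenerate leading dimension (shape [2, 1]):
# it scales rows of the result, so folding it into `w` gives a different answer.
m_bad = np.array([[9.], [10.]])
assert not np.allclose((x @ w) * m_bad, x @ (w * m_bad))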
@@ -78,5 +78,21 @@ bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b) {
   return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
 }
 
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape) {
+  if (elements_shape.empty()) return true;
+
+  for (auto dim : elements_shape.drop_back(1)) {
+    if (dim != 1) return false;
+  }
+  return true;
+}
+
+bool IsDimensionsDegenerateExceptLastOne(Attribute val) {
+  if (auto ranked_type = val.getType().dyn_cast<RankedTensorType>()) {
+    return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape());
+  }
+  return false;
+}
+
 } // namespace TFL
 } // namespace mlir
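Note: the helper accepts any shape whose leading dimensions are all 1 (including a scalar/empty shape) and rejects everything else. A rough Python restatement of the same predicate, for reference only:

def is_dimensions_degenerate_except_last_one(shape):
    """Python sketch of the C++ helper: every dim but the last must be 1."""
    return all(dim == 1 for dim in shape[:-1])

assert is_dimensions_degenerate_except_last_one([])        # scalar
assert is_dimensions_degenerate_except_last_one([2])       # 1-D
assert is_dimensions_degenerate_except_last_one([1, 2])    # leading 1s only
assert not is_dimensions_degenerate_except_last_one([2, 1])
assert not is_dimensions_degenerate_except_last_one([2, 2])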
@@ -70,6 +70,10 @@ inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {
 /// Returns whether the given `a` and `b` have broadcast-compatible
 /// types.
 bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b);
+// Returns true if every dimension of the attribute is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(mlir::Attribute val);
+// Returns true if every element is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape);
 
 } // end namespace TFL
 } // end namespace mlir
@@ -17,8 +17,12 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 
 def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
 def BroadcastableElements :
     Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1)">>;
+// Only fuse multiplier if all dimensions other than the channel dimension
+// are equal to 1.
+def CanFuseMulAndConv2D :
+    Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1) && TFL::IsDimensionsDegenerateExceptLastOne($1)">>;
 def F32ElementsAttr : ElementsAttrBase<
     CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
 def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<mlir::TF::Conv2DOp>($0.getDefiningOp())">>;
@@ -40,7 +44,7 @@ def FuseMulAndConv2D :
                      (location $mul)),
          $strides, $use_cudnn, $padding, $explicit_padding, $data_format,
          $dilations, (location $conv)),
-        [(BroadcastableElements $filter_value, $mul_value), (HasOneUse $conv)]>;
+        [(CanFuseMulAndConv2D $filter_value, $mul_value), (HasOneUse $conv)]>;
 
 // This rule does the following pattern match and rewrite:
 //
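Note: the tightened constraint mirrors the fully_connected case. A multiplier can only be pushed into the Conv2D filter when it is a per-output-channel scale, because per-channel scaling commutes with convolution. A short TensorFlow sketch of that algebra (not the converter pass itself), reusing the constants from the converter test further below:

import tensorflow as tf

x = tf.constant([1., 2.], shape=[1, 1, 2, 1])            # NHWC input
w = tf.constant([3., 4., 5., 6.], shape=[2, 1, 1, 2])    # HWIO filter
m = tf.constant([7., 8.], shape=[1, 2])                  # per-channel multiplier

unfused = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') * m
fused = tf.nn.conv2d(x, w * m, strides=[1, 1, 1, 1], padding='SAME')
# Both paths agree because `m` only scales the output-channel (last) axis.
print(tf.reduce_max(tf.abs(unfused - fused)).numpy())    # ~0.0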
@@ -1304,5 +1304,54 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
         str(error.exception))
 
 
+class AffineOpThenMulFusionTest(lite_v2_test_util.ModelTest):
+
+  @parameterized.named_parameters(('should_fuse_1d', [2], True),
+                                  ('should_fuse_1x2', [1, 2], True),
+                                  ('should_not_fuse_2x1', [2, 1], False),
+                                  ('should_not_fuse_2x2', [2, 2], False))
+  @test_util.run_v2_only
+  def testFullyConnectedFusion(self, multiplier_shape, can_fuse):
+    """Test fusion of (x ∗ w) * m into fullyconnected."""
+
+    @tf.function
+    def func(x):
+      w = tf.constant([3., 4., 5., 6.], shape=[2, 2])
+      m_value = [7., 8.] if sum(multiplier_shape) < 4 else [7., 8., 9., 10.]
+      m = tf.constant(m_value, shape=multiplier_shape)
+      return tf.matmul(x, w) * m
+
+    input_data = tf.constant([1., 2.], shape=[1, 2])
+    self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
+
+  @parameterized.named_parameters(('should_fuse_1d', [2], True),
+                                  ('should_fuse_1x2', [1, 2], True),
+                                  ('should_not_fuse_2x1', [2, 1], False))
+  @test_util.run_v2_only
+  def testConvFusion(self, multiplier_shape, can_fuse):
+    """Test fusion of (x ∗ w) * m into conv2d."""
+
+    @tf.function
+    def func(x):
+      w = tf.constant([3., 4., 5., 6.], shape=[2, 1, 1, 2])
+      m = tf.constant([7., 8.], shape=multiplier_shape)
+      return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') * m
+
+    input_data = tf.constant([1., 2.], shape=[1, 1, 2, 1])
+    self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
+
+  def _checkAffineFusion(self, func, input_data, expected_number_of_ops):
+    concrete_func = func.get_concrete_function(input_data)
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    tflite_model = converter.convert()
+
+    interpreter = Interpreter(model_content=tflite_model)
+    assert len(interpreter._get_ops_details()) == expected_number_of_ops
+
+    expected_value = func(input_data)
+    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
+    self.assertAllClose(expected_value.numpy(), actual_value)
+
+
 if __name__ == '__main__':
   test.main()
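Note: for the cases the test expects to fuse, the converted model should collapse to a single fused op, while a rejected fusion leaves a separate multiply behind; that is what the op-count check in _checkAffineFusion captures. A quick numpy check of the should_fuse_1x2 values from the test (w = [[3, 4], [5, 6]], m = [[7, 8]], x = [[1, 2]]):

import numpy as np

x = np.array([[1., 2.]])
w = np.array([[3., 4.], [5., 6.]])
m = np.array([[7., 8.]])

unfused = (x @ w) * m   # [[91., 128.]]
fused = x @ (w * m)     # same result once `m` is folded into the weights
assert np.allclose(unfused, fused)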