Merge pull request #41790 from lgeiger:fix-matmul-fusion

PiperOrigin-RevId: 341026586
Change-Id: Idb578a4cf48c82abaad2f811612c4df4c2f2752c
This commit is contained in:
TensorFlower Gardener 2020-11-06 04:41:00 -08:00 committed by Lukas Geiger
parent 5e38165d02
commit 7a6fa7dbc0
6 changed files with 98 additions and 23 deletions

View File

@ -286,7 +286,7 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
return %1 : tensor<4x2xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %[[RES]] : tensor<4x2xf32>
@ -384,7 +384,7 @@ func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> te
return %1 : tensor<4x2xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
// CHECK: return %[[RES]] : tensor<4x2xf32>
}

View File

@ -438,26 +438,28 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
return failure();
if (fc_op.fused_activation_function() != "NONE") return failure();
// Broadcast the constant operand of Mul if it isn't compatible to the
// filter input. We only support broadcasting the operand along the depth
// dimension, when the operand's depth is 1.
Value new_const_val = constant_val;
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
auto original_shape = cst.getType().getShape();
llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
original_shape.end());
normalized_shape.push_back(1);
auto new_cst = cst.reshape(RankedTensorType::get(
normalized_shape, cst.getType().getElementType()));
Type new_type = new_cst.getType();
if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
return failure();
}
auto new_op =
rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
new_const_val = new_op.getResult();
// Only fuse multiplier if all dimensions other than the depth dimension
// are equal to 1 since otherwise
// `matmul(x, filter) * cst != matmul(x, filter * cst)`
// even if `filter` and `cst` are be broadcastable.
auto shape = cst.getType().getShape();
if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();
int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
// Expand and transpose the multiplier since weights are using the
// OHWI data format in TFLite.
int64_t normalized_shape[2] = {element_size, 1};
auto new_cst = cst.reshape(RankedTensorType::get(
normalized_shape, cst.getType().getElementType()));
Type new_type = new_cst.getType();
if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
return failure();
}
auto new_op =
rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
Value new_const_val = new_op.getResult();
// Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands,
// TF::MulOp is used to fold the constant.
// TODO(b/139192933): switch to the TFL constant folding

View File

@ -78,5 +78,21 @@ bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b) {
return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
}
bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape) {
if (elements_shape.empty()) return true;
for (auto dim : elements_shape.drop_back(1)) {
if (dim != 1) return false;
}
return true;
}
bool IsDimensionsDegenerateExceptLastOne(Attribute val) {
if (auto ranked_type = val.getType().dyn_cast<RankedTensorType>()) {
return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape());
}
return false;
}
} // namespace TFL
} // namespace mlir

View File

@ -70,6 +70,10 @@ inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {
/// Returns whether the given `a` and `b` have broadcast-compatible
/// types.
bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b);
// Returns true if every dimension of the attribute is 1 except the last one.
bool IsDimensionsDegenerateExceptLastOne(mlir::Attribute val);
// Returns true if every element is 1 except the last one.
bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape);
} // end namespace TFL
} // end namespace mlir

View File

@ -17,8 +17,12 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
def BroadcastableElements :
Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1)">>;
// Only fuse multiplier if all dimensions other than the channel dimension
// are equal to 1.
def CanFuseMulAndConv2D :
Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1) && TFL::IsDimensionsDegenerateExceptLastOne($1)">>;
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<mlir::TF::Conv2DOp>($0.getDefiningOp())">>;
@ -40,7 +44,7 @@ def FuseMulAndConv2D :
(location $mul)),
$strides, $use_cudnn, $padding, $explicit_padding, $data_format,
$dilations, (location $conv)),
[(BroadcastableElements $filter_value, $mul_value), (HasOneUse $conv)]>;
[(CanFuseMulAndConv2D $filter_value, $mul_value), (HasOneUse $conv)]>;
// This rule does the following pattern match and rewrite:
//

View File

@ -1304,5 +1304,54 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
str(error.exception))
class AffineOpThenMulFusionTest(lite_v2_test_util.ModelTest):
@parameterized.named_parameters(('should_fuse_1d', [2], True),
('should_fuse_1x2', [1, 2], True),
('should_not_fuse_2x1', [2, 1], False),
('should_not_fuse_2x2', [2, 2], False))
@test_util.run_v2_only
def testFullyConnectedFusion(self, multiplier_shape, can_fuse):
"""Test fusion of (x w) * m into fullyconnected."""
@tf.function
def func(x):
w = tf.constant([3., 4., 5., 6.], shape=[2, 2])
m_value = [7., 8.] if sum(multiplier_shape) < 4 else [7., 8., 9., 10.]
m = tf.constant(m_value, shape=multiplier_shape)
return tf.matmul(x, w) * m
input_data = tf.constant([1., 2.], shape=[1, 2])
self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
@parameterized.named_parameters(('should_fuse_1d', [2], True),
('should_fuse_1x2', [1, 2], True),
('should_not_fuse_2x1', [2, 1], False))
@test_util.run_v2_only
def testConvFusion(self, multiplier_shape, can_fuse):
"""Test fusion of (x w) * m into conv2d."""
@tf.function
def func(x):
w = tf.constant([3., 4., 5., 6.], shape=[2, 1, 1, 2])
m = tf.constant([7., 8.], shape=multiplier_shape)
return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') * m
input_data = tf.constant([1., 2.], shape=[1, 1, 2, 1])
self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
def _checkAffineFusion(self, func, input_data, expected_number_of_ops):
concrete_func = func.get_concrete_function(input_data)
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
tflite_model = converter.convert()
interpreter = Interpreter(model_content=tflite_model)
assert len(interpreter._get_ops_details()) == expected_number_of_ops
expected_value = func(input_data)
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
self.assertAllClose(expected_value.numpy(), actual_value)
if __name__ == '__main__':
test.main()