Merge pull request #41790 from lgeiger:fix-matmul-fusion

PiperOrigin-RevId: 341026586
Change-Id: Idb578a4cf48c82abaad2f811612c4df4c2f2752c
TensorFlower Gardener authored 2020-11-06 04:41:00 -08:00, committed by Lukas Geiger
parent 5e38165d02
commit 7a6fa7dbc0
6 changed files with 98 additions and 23 deletions
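The pattern being fixed folds a constant multiplier into the weights of a preceding fully connected or convolution op. Broadcast compatibility between the multiplier and the filter is not a sufficient condition for that fold: matmul(x, filter) * cst equals matmul(x, filter * cst) only when cst scales whole output channels, i.e. when every dimension of cst except the last one is 1. A minimal NumPy sketch of that property (illustrative values, not taken from the commit):

import numpy as np

x = np.array([[1., 2.]])            # input, shape [1, 2]
w = np.array([[3., 4.], [5., 6.]])  # weights, shape [2, 2]

m_ok = np.array([7., 8.])           # per-output-channel multiplier, shape [2]
m_bad = np.array([[7.], [8.]])      # shape [2, 1]: varies along a non-channel axis

# A per-channel multiplier can be pushed into the weights ...
assert np.allclose((x @ w) * m_ok, x @ (w * m_ok))

# ... but the [2, 1] multiplier cannot, even though it broadcasts against w.
assert not np.allclose((x @ w) * m_bad, x @ (w * m_bad))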


@@ -286,7 +286,7 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
   return %1 : tensor<4x2xf32>
 
-// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
 // CHECK: %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
 // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
 // CHECK: return %[[RES]] : tensor<4x2xf32>
@@ -384,7 +384,7 @@ func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<4x2xf32> {
   return %1 : tensor<4x2xf32>
 
-// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
 // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
 // CHECK: return %[[RES]] : tensor<4x2xf32>
 }
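The filter and multiplier constants of these tests sit outside the excerpt; assuming the filter is dense<[[1.0, 2.0], [3.0, 4.0]]> and the multiplier dense<[1.0, 2.0]> (the values both the old and the new CHECK constants imply), the corrected expectation scales rows of the tfl.fully_connected weights (one row per output unit) rather than columns:

import numpy as np

filter_ = np.array([[1., 2.], [3., 4.]])  # tfl.fully_connected weights: [num_units, depth]
m = np.array([1., 2.])                    # per-output-unit multiplier

old_check = filter_ * m                   # [[1, 4], [3, 8]]: scales the depth axis (wrong)
new_check = filter_ * m.reshape(-1, 1)    # [[1, 2], [6, 8]]: scales output units (right)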


@@ -438,26 +438,28 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
       return failure();
 
     if (fc_op.fused_activation_function() != "NONE") return failure();
 
-    // Broadcast the constant operand of Mul if it isn't compatible to the
-    // filter input. We only support broadcasting the operand along the depth
-    // dimension, when the operand's depth is 1.
-    Value new_const_val = constant_val;
-    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
-      auto original_shape = cst.getType().getShape();
-      llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
-                                                     original_shape.end());
-      normalized_shape.push_back(1);
-      auto new_cst = cst.reshape(RankedTensorType::get(
-          normalized_shape, cst.getType().getElementType()));
-      Type new_type = new_cst.getType();
-      if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
-        return failure();
-      }
-      auto new_op =
-          rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
-      new_const_val = new_op.getResult();
-    }
+    // Only fuse multiplier if all dimensions other than the depth dimension
+    // are equal to 1 since otherwise
+    // `matmul(x, filter) * cst != matmul(x, filter * cst)`
+    // even if `filter` and `cst` are broadcastable.
+    auto shape = cst.getType().getShape();
+    if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();
+
+    int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
+    // Expand and transpose the multiplier since weights are using the
+    // OHWI data format in TFLite.
+    int64_t normalized_shape[2] = {element_size, 1};
+    auto new_cst = cst.reshape(RankedTensorType::get(
+        normalized_shape, cst.getType().getElementType()));
+    Type new_type = new_cst.getType();
+    if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
+      return failure();
+    }
+
+    auto new_op =
+        rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
+    Value new_const_val = new_op.getResult();
 
     // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands,
     // TF::MulOp is used to fold the constant.
     // TODO(b/139192933): switch to the TFL constant folding
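The replacement code always reshapes the (effectively 1-D) multiplier into a column of shape {element_size, 1}, so it scales one whole weight row per output channel; the old path let the raw multiplier broadcast along the depth axis of the stored weights, which is what produced wrong numerics. A short Python sketch of the new guard and reshape (hypothetical helper name, not the converter's API):

import numpy as np

def try_fuse_multiplier(filter_, cst):
  """Returns the fused weights, or None if the multiplier cannot be fused."""
  shape = cst.shape
  # Reject multipliers that are not 1 everywhere except the last dimension.
  if any(dim != 1 for dim in shape[:-1]):
    return None
  element_size = shape[-1] if shape else 1
  # Reshape to a column so each output channel's row of the weights is scaled.
  return filter_ * cst.reshape(element_size, 1)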


@@ -78,5 +78,21 @@ bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b) {
   return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
 }
 
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape) {
+  if (elements_shape.empty()) return true;
+
+  for (auto dim : elements_shape.drop_back(1)) {
+    if (dim != 1) return false;
+  }
+  return true;
+}
+
+bool IsDimensionsDegenerateExceptLastOne(Attribute val) {
+  if (auto ranked_type = val.getType().dyn_cast<RankedTensorType>()) {
+    return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape());
+  }
+  return false;
+}
+
 }  // namespace TFL
 }  // namespace mlir
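For reference, a rough Python transcription of the shape-based overload added above; the Attribute overload additionally returns false for unranked tensors. Its truth table matches the should_fuse / should_not_fuse cases in the new converter test further down:

def is_dimensions_degenerate_except_last_one(shape):
  # A scalar (empty shape) or 1-D shape trivially qualifies.
  return all(dim == 1 for dim in shape[:-1])

assert is_dimensions_degenerate_except_last_one([])
assert is_dimensions_degenerate_except_last_one([2])
assert is_dimensions_degenerate_except_last_one([1, 2])
assert not is_dimensions_degenerate_except_last_one([2, 1])
assert not is_dimensions_degenerate_except_last_one([2, 2])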


@@ -70,6 +70,10 @@ inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {
 /// Returns whether the given `a` and `b` have broadcast-compatible
 /// types.
 bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b);
 
+// Returns true if every dimension of the attribute is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(mlir::Attribute val);
+// Returns true if every element of the shape is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape);
+
 }  // end namespace TFL
 }  // end namespace mlir


@@ -17,8 +17,12 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 
 def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
 
-def BroadcastableElements :
-    Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1)">>;
+// Only fuse multiplier if all dimensions other than the channel dimension
+// are equal to 1.
+def CanFuseMulAndConv2D :
+    Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1) && TFL::IsDimensionsDegenerateExceptLastOne($1)">>;
+
 def F32ElementsAttr : ElementsAttrBase<
     CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
 def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<mlir::TF::Conv2DOp>($0.getDefiningOp())">>;
 
@@ -40,7 +44,7 @@ def FuseMulAndConv2D :
                  (location $mul)),
              $strides, $use_cudnn, $padding, $explicit_padding, $data_format,
              $dilations, (location $conv)),
-            [(BroadcastableElements $filter_value, $mul_value), (HasOneUse $conv)]>;
+            [(CanFuseMulAndConv2D $filter_value, $mul_value), (HasOneUse $conv)]>;
 
 // This rule does the following pattern match and rewrite:
 //
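The renamed constraint applies the same requirement to the TF-level Conv2D pattern: with HWIO filters, a multiplier can only be folded into the filter when it is 1 in every dimension except the channel one. A small TensorFlow check mirroring the shapes used in the new converter test below (illustrative only, not part of the commit):

import numpy as np
import tensorflow as tf

x = tf.constant([1., 2.], shape=[1, 1, 2, 1])           # NHWC input
w = tf.constant([3., 4., 5., 6.], shape=[2, 1, 1, 2])   # HWIO filter

conv = lambda f: tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding='SAME')

m_ok = tf.constant([7., 8.], shape=[2])  # per-channel multiplier: fusable
assert np.allclose(conv(w) * m_ok, conv(w * m_ok))

# A [2, 1] multiplier varies along the width axis of the output; `w * m` would
# even change the filter's shape (to [2, 1, 2, 2]), so the fusion is rejected.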


@@ -1304,5 +1304,54 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
         str(error.exception))
 
+
+class AffineOpThenMulFusionTest(lite_v2_test_util.ModelTest):
+
+  @parameterized.named_parameters(('should_fuse_1d', [2], True),
+                                  ('should_fuse_1x2', [1, 2], True),
+                                  ('should_not_fuse_2x1', [2, 1], False),
+                                  ('should_not_fuse_2x2', [2, 2], False))
+  @test_util.run_v2_only
+  def testFullyConnectedFusion(self, multiplier_shape, can_fuse):
+    """Test fusion of matmul(x, w) * m into a fully_connected op."""
+
+    @tf.function
+    def func(x):
+      w = tf.constant([3., 4., 5., 6.], shape=[2, 2])
+      m_value = [7., 8.] if sum(multiplier_shape) < 4 else [7., 8., 9., 10.]
+      m = tf.constant(m_value, shape=multiplier_shape)
+      return tf.matmul(x, w) * m
+
+    input_data = tf.constant([1., 2.], shape=[1, 2])
+    self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
+
+  @parameterized.named_parameters(('should_fuse_1d', [2], True),
+                                  ('should_fuse_1x2', [1, 2], True),
+                                  ('should_not_fuse_2x1', [2, 1], False))
+  @test_util.run_v2_only
+  def testConvFusion(self, multiplier_shape, can_fuse):
+    """Test fusion of conv2d(x, w) * m into a conv2d op."""
+
+    @tf.function
+    def func(x):
+      w = tf.constant([3., 4., 5., 6.], shape=[2, 1, 1, 2])
+      m = tf.constant([7., 8.], shape=multiplier_shape)
+      return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') * m
+
+    input_data = tf.constant([1., 2.], shape=[1, 1, 2, 1])
+    self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
+
+  def _checkAffineFusion(self, func, input_data, expected_number_of_ops):
+    concrete_func = func.get_concrete_function(input_data)
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    tflite_model = converter.convert()
+
+    interpreter = Interpreter(model_content=tflite_model)
+    assert len(interpreter._get_ops_details()) == expected_number_of_ops
+
+    expected_value = func(input_data)
+    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
+    self.assertAllClose(expected_value.numpy(), actual_value)
+
+
 if __name__ == '__main__':
   test.main()