Fix previous build failure.
PiperOrigin-RevId: 306591715
Change-Id: I6a011b8ee753f7b541de5e3bf992d9eaab81caea
commit e456af46e8
parent db65d28dbb
@@ -475,6 +475,28 @@ func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32
   // CHECK: return %[[RES]]
 }
 
+// CHECK-LABEL: @FuseFullyConnectedRelu6
+func @FuseFullyConnectedRelu6(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "tfl.relu6"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %1 : tensor<1x128xf32>
+
+  // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
+  // CHECK-SAME: fused_activation_function = "RELU6"
+  // CHECK: return %[[RES]]
+}
+
+// CHECK-LABEL: @FuseFullyConnectedRelu1
+func @FuseFullyConnectedRelu1(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "tfl.relu_n1_to_1"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %1 : tensor<1x128xf32>
+
+  // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected"
+  // CHECK-SAME: fused_activation_function = "RELU_N1_TO_1"
+  // CHECK: return %[[RES]]
+}
+
 // CHECK-LABEL: @HardSwishPattern
 func @HardSwishPattern(%arg0: tensor<1xf32>) -> tensor<1xf32> {
   %three = constant dense<3.> : tensor<f32>
@@ -51,6 +51,9 @@ namespace TFL {
 //===----------------------------------------------------------------------===//
 // The actual Optimize Pass.
 namespace {
+constexpr char kRelu[] = "RELU";
+constexpr char kRelu6[] = "RELU6";
+constexpr char kRelu1[] = "RELU_N1_TO_1";
 
 bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
   if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
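Note on the three string constants added above: a C++ string literal cannot be used directly as a non-type template argument, but a named constexpr array at namespace scope can, since it is an object with linkage. That is why the activation names are hoisted into kRelu/kRelu6/kRelu1 before being passed to the templated pattern in the next hunk. A minimal self-contained sketch of the idiom (illustrative names, not part of this change):

#include <cstdio>

// A named constexpr array has linkage, so it is a valid non-type template
// argument (C++14 and later); the literal "RELU" by itself would not be.
constexpr char kRelu[] = "RELU";

template <char const *Act>
void PrintActivation() { std::printf("%s\n", Act); }

int main() {
  PrintActivation<kRelu>();      // prints: RELU
  // PrintActivation<"RELU">();  // ill-formed: string literal as template arg
  return 0;
}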
@@ -300,10 +303,11 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
 };
 
 // TODO(b/136285429): Move to tablegen when variadic is supported.
-struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
-  using OpRewritePattern<TFL::ReluOp>::OpRewritePattern;
+template <typename ReluXOp, char const *Act>
+struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
+  using OpRewritePattern<ReluXOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TFL::ReluOp relu_op,
+  LogicalResult matchAndRewrite(ReluXOp relu_op,
                                 PatternRewriter &rewriter) const override {
     Operation *input = relu_op.getOperand().getDefiningOp();
     if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
@@ -311,7 +315,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
     if (fully_connected_op.fused_activation_function() != "NONE")
       return failure();
 
-    auto new_activation_func = rewriter.getStringAttr("RELU");
+    auto new_activation_func = rewriter.getStringAttr(Act);
     auto new_weights_format =
         rewriter.getStringAttr(fully_connected_op.weights_format());
     auto new_keep_num_dims =
@@ -708,7 +712,10 @@ void Optimize::runOnFunction() {
   // we explore these potentially first and then fuse the binary ops with the
   // following ops in a second pattern match.
   TFL::populateWithGenerated(ctx, &patterns);
-  patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
+  patterns.insert<FuseFullyConnectedAndAdd,
+                  FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
+                  FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
+                  FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
                   FuseFullyConnectedAndMul>(ctx);
   applyPatternsAndFoldGreedily(func, patterns);
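The net effect of the refactor: the old code had a single concrete FuseFullyConnectedAndRelu pattern hard-wired to TFL::ReluOp and "RELU"; the templated FuseFullyConnectedAndReluX covers all three ReLU variants with one pattern, parameterized on the op type and the fused-activation string. A rough standalone analogue of the instantiation (simplified stand-in types, not the real MLIR API):

#include <cstdio>

constexpr char kRelu[] = "RELU";
constexpr char kRelu6[] = "RELU6";
constexpr char kRelu1[] = "RELU_N1_TO_1";

// Stand-ins for the TFL op types; the real pattern matches MLIR ops.
struct ReluOp {};
struct Relu6Op {};
struct Relu1Op {};

// One template replaces three near-identical rewrite patterns: the op type
// selects what to match, and Act supplies the attribute value to write.
template <typename ReluXOp, char const *Act>
struct FuseFullyConnectedAndReluX {
  void matchAndRewrite() const {
    std::printf("rewrite fused_activation_function = \"%s\"\n", Act);
  }
};

int main() {
  FuseFullyConnectedAndReluX<ReluOp, kRelu>{}.matchAndRewrite();
  FuseFullyConnectedAndReluX<Relu6Op, kRelu6>{}.matchAndRewrite();
  FuseFullyConnectedAndReluX<Relu1Op, kRelu1>{}.matchAndRewrite();
  return 0;
}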