Remove tf.Placeholder.input/tfl.pseudo_input ops and associated dependencies.
PiperOrigin-RevId: 279394545
Change-Id: I7100f3025073ed74f45b757f6e321bc6e7ede5c8

parent 514cf2d96e
commit 5f5f9cf191
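The change itself is mechanical but wide-reaching: wherever a function argument used to be threaded through an identity-like tfl.pseudo_input (or tf.Placeholder.input) op, consumers now read the block argument directly, and the TFLite exporter derives subgraph inputs from the function signature instead. A minimal before/after sketch of the pattern being deleted (the function and op choice are illustrative, not taken from this commit):

    // Before: the argument is wrapped in a pseudo-input identity op.
    func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
      %0 = "tfl.pseudo_input"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
      %1 = "tfl.abs"(%0) : (tensor<4xf32>) -> tensor<4xf32>
      return %1 : tensor<4xf32>
    }

    // After: ops consume the block argument directly.
    func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
      %0 = "tfl.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
      return %0 : tensor<4xf32>
    }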
@@ -221,18 +221,11 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
   }
 }
 
-static bool IsInput(Operation* op) {
-  return isa<tfl::InputOp>(op) ||
-         op->getName().getStringRef() == "tf.Placeholder.input";
-}
-
 static bool IsConst(Operation* op) {
   return isa<mlir::ConstantOp>(op) || isa<mlir::TF::ConstOp>(op) ||
          isa<tfl::ConstOp>(op) || isa<tfl::QConstOp>(op);
 }
 
-static bool IsConstOrInput(Operation* op) { return IsConst(op) || IsInput(op); }
-
 template <typename T>
 static bool HasValidTFLiteType(Value* value, T& error_handler) {
   // None type is allowed to represent unspecified operands.
@@ -957,8 +950,8 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
       }
     }
 
-    // Skip constant and input ops as they don't represent a TFLite operator.
-    if (IsConstOrInput(&inst)) continue;
+    // Skip constant ops as they don't represent a TFLite operator.
+    if (IsConst(&inst)) continue;
 
     // Fetch operand and result tensor indices.
     std::vector<int32_t> operands;
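With inputs no longer materialized as ops, the serializer's skip list shrinks to constants only. A sketch of the kind of function BuildSubGraph now consumes, where the subgraph's input tensors come straight from the signature (op choice illustrative):

    func @main(%arg0: tensor<1x8xf32>) -> tensor<1x8xf32> {
      %0 = "tfl.relu"(%arg0) : (tensor<1x8xf32>) -> tensor<1x8xf32>
      return %0 : tensor<1x8xf32>
    }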
@@ -1172,24 +1172,6 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
   let hasOptions = 0;
 }
 
-
-// NoSideEffect trait is not added to the op intentionally to prevent it from
-// getting removed if the input is unused. The generated FlatBuffer needs to
-// have a tensor along with the metadata for each of the subgraph inputs.
-def TFL_InputOp : Op<TFL_Dialect, "pseudo_input", [SameOperandsAndResultType]> {
-  let summary = "Input pseudo operator";
-
-  let description = [{
-    Takes one of the function arguments as input and returns it as result. This
-    is a NOP and is used to attach attributes such as tensor name to function
-    arguments.
-  }];
-
-  let arguments = (ins AnyTensor:$input);
-
-  let results = (outs AnyTensor:$output);
-}
-
 def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
     // central_value = min_value / 2 + (max_value - 1) / 2 + 1
     // zero_point = central_value
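As the deleted comment notes, TFL_InputOp existed only to pin metadata such as a tensor name onto a function argument, which is why it deliberately lacked NoSideEffect. With the op gone, that metadata has to live on the function itself; the importer test in the next hunk relies on the tf.entry_function attribute for input/output names. A sketch of that attribute form (values illustrative):

    func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
        attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
      %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4xf32>
      return %0 : tensor<4xf32>
    }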
@@ -41,13 +41,10 @@ versions {
 
 # MLIR-LABEL: func @main(%arg0: tensor<4x!quant.uniform<u8:f32, 0.015686274509803921:128>>, %arg1: tensor<4x!quant.uniform<u8:f32, 0.023529411764705882:128>>) -> tensor<4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
 # MLIR-NEXT: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
-# MLIR-NEXT:   %[[input0:.*]] = "tfl.pseudo_input"(%arg1) : (tensor<4x!quant.uniform<u8:f32, 0.023529411764705882:128>>) -> tensor<4x!quant.uniform<u8:f32, 0.023529411764705882:128>>
-# MLIR-NEXT:   %[[input1:.*]] = "tfl.pseudo_input"(%arg0) : (tensor<4x!quant.uniform<u8:f32, 0.015686274509803921:128>>) -> tensor<4x!quant.uniform<u8:f32, 0.015686274509803921:128>>
-# MLIR-NEXT:   %[[add:.*]] = "tfl.add"(%[[input1]], %[[input0]]) {fused_activation_function = "NONE"} : (tensor<4x!quant.uniform<u8:f32, 0.015686274509803921:128>>, tensor<4x!quant.uniform<u8:f32, 0.023529411764705882:128>>) -> tensor<4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
+# MLIR-NEXT:   %[[add:.*]] = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<4x!quant.uniform<u8:f32, 0.015686274509803921:128>>, tensor<4x!quant.uniform<u8:f32, 0.023529411764705882:128>>) -> tensor<4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
 # MLIR-NEXT:   return %[[add]] : tensor<4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
 # MLIR-NEXT: }
 
 
 # CHECK-LABEL: {
 # CHECK-NEXT:   version: 3,
 # CHECK-NEXT:   operator_codes: [ {
@@ -168,22 +168,6 @@ func @const() -> tensor<2xi32> {
 // CHECK: "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
 }
 
-func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
-  %0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor<f32>) -> tensor<f32>
-  return %0: tensor<f32>
-
-// CHECK-LABEL: @placeholder
-// CHECK:  %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
-}
-
-func @placeholder_int(%arg0: tensor<i32>) -> tensor<i32> {
-  %0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor<i32>) -> tensor<i32>
-  return %0: tensor<i32>
-
-// CHECK-LABEL: @placeholder_int
-// CHECK-NEXT:  "tfl.pseudo_input"(%arg0) : (tensor<i32>) -> tensor<i32>
-}
-
 func @shape(%arg0: tensor<?x1001xf32>) -> tensor<2xi32> {
   %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<?x1001xf32>) -> tensor<2xi32>
   %1 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT"} : (tensor<?x1001xf32>) -> tensor<2xi32>
@@ -5,11 +5,10 @@
 
 func @main(tensor<4xf32>) -> tensor<4xf32> {
 ^bb0(%arg0: tensor<4xf32>):
-  %0 = "tfl.pseudo_input" (%arg0) {name = "Input"} : (tensor<4xf32>) -> tensor<4xf32>
-  %1 = "tfl.pseudo_const" () {name = "Const", value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
-  %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE", name = "mul"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %0 = "tfl.pseudo_const" () {name = "Const", value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
+  %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE", name = "mul"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // tf.div is the result of conversion to a Flex TF op
-  %3 = "tf.Div"(%2, %1) {name = "div"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  %4 = "tfl.exp"(%3) {name = "exp"} : (tensor<4xf32>) -> tensor<4xf32>
-  return %4 : tensor<4xf32>
+  %2 = "tf.Div"(%1, %0) {name = "div"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %3 = "tfl.exp"(%2) {name = "exp"} : (tensor<4xf32>) -> tensor<4xf32>
+  return %3 : tensor<4xf32>
 }
@@ -2,8 +2,7 @@
 
 func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
 ^bb0(%arg0: tensor<3x2xi32>):
-  %0 = "tfl.pseudo_input" (%arg0) {name = "Input"} : (tensor<3x2xi32>) -> tensor<3x2xi32>
   // CHECK: error: 'unknown_op' op dialect is not registered
-  %1 = "unknown_op"(%0) : (tensor<3x2xi32>) -> tensor<3x2xi32>
-  return %1 : tensor<3x2xi32>
+  %0 = "unknown_op"(%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32>
+  return %0 : tensor<3x2xi32>
 }
@@ -301,38 +301,30 @@ func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x
 // CHECK-LABEL: @FuseFullyConnectedAddUnit
 func @FuseFullyConnectedAddUnit(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
   %cst = constant unit
-  %0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32>
-  %1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32>
   %cst2 = constant dense<2.0> : tensor<40x40xf32>
 
-  %2 = "tfl.fully_connected" (%0, %1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
-  %3 = "tfl.add"(%2, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
+  %1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
 
-  return %3 : tensor<40x40xf32>
+  return %1 : tensor<40x40xf32>
 
   // CHECK: %cst = constant dense<2.000000e+00> : tensor<40x40xf32>
-  // CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32>
-  // CHECK: %1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32>
-  // CHECK: %2 = "tfl.fully_connected"(%0, %1, %cst)
-  // CHECK: return %2
+  // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %cst)
+  // CHECK: return %[[fc]]
 }
 
 // CHECK-LABEL: @FuseFullyConnectedAddConst
 func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
   %cst = constant dense<3.0> : tensor<40x40xf32>
-  %0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input")
-  %1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input")
   %cst2 = constant dense<2.0> : tensor<40x40xf32>
 
-  %2 = "tfl.fully_connected" (%0, %1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
-  %3 = "tfl.add"(%2, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
+  %1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
 
-  return %3 : tensor<40x40xf32>
+  return %1 : tensor<40x40xf32>
 
   // CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
-  // CHECK: %[[cst_0:.*]] = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32>
-  // CHECK: %[[cst_1:.*1]] = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32>
-  // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%[[cst_0]], %[[cst_1]], %[[cst]])
+  // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
   // CHECK: return %[[fc]]
 }
 
@@ -2,10 +2,8 @@
 
 // CHECK-LABEL: RemoveUnused
 func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
-  %in = "tfl.pseudo_input"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-  %cst = "tfl.pseudo_input"(%arg1) : (tensor<i32>) -> tensor<i32>
-  %0 = "tfl.quantize"(%in) {qtype = tensor<4x!quant.uniform<u8:f32, 1.0>>} : (tensor<4xf32>) -> tensor<4x!quant.uniform<u8:f32, 1.0>>
-  %1:4 = "tfl.split"(%cst, %0) {num_splits = 4 : i32} : (tensor<i32>, tensor<4x!quant.uniform<u8:f32, 1.0>>)
+  %0 = "tfl.quantize"(%arg0) {qtype = tensor<4x!quant.uniform<u8:f32, 1.0>>} : (tensor<4xf32>) -> tensor<4x!quant.uniform<u8:f32, 1.0>>
+  %1:4 = "tfl.split"(%arg1, %0) {num_splits = 4 : i32} : (tensor<i32>, tensor<4x!quant.uniform<u8:f32, 1.0>>)
   -> (tensor<2x!quant.uniform<u8:f32, 1.0>>, tensor<2x!quant.uniform<u8:f32, 1.0>>,tensor<2x!quant.uniform<u8:f32, 1.0>>, tensor<2x!quant.uniform<u8:f32, 1.0>>)
   %2 = "tfl.dequantize"(%1#0) : (tensor<2x!quant.uniform<u8:f32, 1.0>>) -> tensor<2xf32>
   %3 = "tfl.dequantize"(%1#1) : (tensor<2x!quant.uniform<u8:f32, 1.0>>) -> tensor<2xf32>
@@ -17,49 +15,41 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,t
   return %2, %3 : tensor<2xf32>, tensor<2xf32>
 
-// CHECK-NEXT: %[[in1:.*]] = "tfl.pseudo_input"(%arg0)
-// CHECK-NEXT: %[[in2:.*]] = "tfl.pseudo_input"(%arg1)
-// CHECK-NEXT: %[[split:.*]]:4 = "tfl.split"(%[[in2]], %[[in1]])
+// CHECK-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %arg0)
 // CHECK-NEXT: return %[[split]]#0, %[[split]]#1
 }
 
 func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
   %cst = constant dense<[1, 1001]> : tensor<2xi32>
-  %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32>
-  %1 = "tfl.quantize"(%0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
-  %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
-  %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>
-  %4 = "tfl.conv_2d"(%1, %2, %3) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
-  %5 = "tfl.reshape"(%4, %cst) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>, tensor<2xi32>) -> tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>
-  %6 = "tfl.softmax"(%5) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
-  %7 = "tfl.dequantize"(%6) : (tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<1x1001xf32>
-  return %7 : tensor<1x1001xf32>
+  %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
+  %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
+  %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>
+  %3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
+  %4 = "tfl.reshape"(%3, %cst) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>, tensor<2xi32>) -> tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>
+  %5 = "tfl.softmax"(%4) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
+  %6 = "tfl.dequantize"(%5) : (tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<1x1001xf32>
+  return %6 : tensor<1x1001xf32>
 }
 
 func @main2(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf32> {
-  %0 = "tfl.pseudo_input"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4xf32>
-  %1 = "tfl.quantize"(%0) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
-  %2 = "tfl.pseudo_input"(%arg1) : (tensor<2x4xf32>) -> tensor<2x4xf32>
-  %3 = "tfl.quantize"(%2) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
-  %4 = tfl.add %1, %3 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
-  %5 = "tfl.dequantize"(%4) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4xf32>
-  return %5 : tensor<2x4xf32>
+  %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
+  %1 = "tfl.quantize"(%arg1) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
+  %2 = tfl.add %0, %1 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
+  %3 = "tfl.dequantize"(%2) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4xf32>
+  return %3 : tensor<2x4xf32>
 }
 
 // CHECK: func @main(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
-// CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
 // CHECK-NEXT: %[[cst:.*]] = constant dense<[1, 1001]> : tensor<2xi32>
-// CHECK-NEXT: %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
-// CHECK-NEXT: %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>}
-// CHECK-NEXT: %3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
-// CHECK-NEXT: %4 = "tfl.reshape"(%3, %[[cst]]) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>, tensor<2xi32>)
-// CHECK-NEXT: %5 = "tfl.softmax"(%4) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>)
-// CHECK-NEXT: return %5 : tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
+// CHECK-NEXT: %[[q_cst_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
+// CHECK-NEXT: %[[q_cst_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>}
+// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[q_cst_0]], %[[q_cst_1]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
+// CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[cst]]) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>, tensor<2xi32>)
+// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>)
+// CHECK-NEXT: return %[[softmax]] : tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
 // CHECK-NEXT:}
 
 // CHECK: func @main2(%arg0: tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>, %arg1: tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>> {
-// CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg1) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
-// CHECK-NEXT: %1 = "tfl.pseudo_input"(%arg0) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
-// CHECK-NEXT: %2 = tfl.add %1, %0 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
-// CHECK-NEXT: return %2 : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
+// CHECK-NEXT: %[[add:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
+// CHECK-NEXT: return %[[add]] : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
 // CHECK-NEXT:}
@@ -2,20 +2,13 @@
 
 // CHECK-LABEL: quantize_float_placeholder_only
 func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>) {
-  %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
-  %1 = "tfl.pseudo_input"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
-  %2 = "tfl.pseudo_input"(%arg2) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-
-  return %0, %1, %2: tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>
+  return %arg0, %arg1, %arg2: tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>
 
-// CHECK-NEXT: %[[in:.*]] = "tfl.pseudo_input"(%arg0)
-// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[in]])
+// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0)
 // CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
-// CHECK-NEXT: %[[in_1:.*]] = "tfl.pseudo_input"(%arg1)
-// CHECK-NEXT: %[[in_0:.*]] = "tfl.pseudo_input"(%arg2)
-// CHECK-NEXT: %[[q_0:.*]] = "tfl.quantize"(%[[in_0]])
+// CHECK-NEXT: %[[q_0:.*]] = "tfl.quantize"(%arg2)
 // CHECK-NEXT: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
-// CHECK-NEXT: %[[dq]], %[[in_1]], %[[dq_0]]
+// CHECK-NEXT: %[[dq]], %arg1, %[[dq_0]]
 }
 
 // CHECK-LABEL: DequantizeAndQuantize
@@ -2,30 +2,26 @@
 
 func @testSingleLstm(%arg0: tensor<4 x f32>) -> tensor<4xf32> {
   // CHECK-LABEL: testSingleLstm
-  // CHECK:  %[[INPUT:[a-z0-9]*]] = "tfl.pseudo_input"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK:  %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
   // CHECK:  %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
-  // CHECK:  %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[CST_0]], %[[CST_1]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]]) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK:  %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 
-  %0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32>
-  %1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
-  %2 = "tfl.unidirectional_sequence_lstm"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %1, %1, %0, %0, %0, %0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  return %2 : tensor<4xf32>
+  %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %1 : tensor<4xf32>
 }
 
 func @testMultipleLstms(%arg0: tensor<4 x f32>) -> tensor<4xf32> {
   // CHECK-LABEL: testMultipleLstms
-  // CHECK:  %[[INPUT:[a-z0-9]*]] = "tfl.pseudo_input"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK:  %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
   // CHECK:  %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
-  // CHECK:  %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[CST_0]], %[[CST_1]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]]) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK:  %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   // CHECK:  %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
   // CHECK:  %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
-  // CHECK:  %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[CST_2]], %[[CST_3]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]]) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK:  %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 
-  %0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32>
-  %1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
-  %2 = "tfl.unidirectional_sequence_lstm"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %1, %1, %0, %0, %0, %0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  %3 = "tfl.unidirectional_sequence_lstm"(%2, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %1, %1, %0, %0, %0, %0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  return %3 : tensor<4xf32>
+  %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %2 : tensor<4xf32>
 }
@@ -1,21 +1,18 @@
 // RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-whitelist="bar,foobar" %s | FileCheck %s --dump-input-on-failure
 
 func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
-  %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x4xf32>) -> tensor<1x4xf32>
-  return %0 : tensor<1x4xf32>
+  return %arg0 : tensor<1x4xf32>
 }
 
 func @bar(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf32> {
-  %0 = "tfl.pseudo_input"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4xf32>
-  return %0 : tensor<2x4xf32>
+  return %arg0 : tensor<2x4xf32>
 }
 
 func @foobar(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
-  %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x4xf32>) -> tensor<1x4xf32>
-  return %0 : tensor<1x4xf32>
+  return %arg0 : tensor<1x4xf32>
 }
 
 // CHECK-DAG: func @main
 // CHECK-DAG: func @foobar
 // CHECK-NOT: func @foo
 // CHECK-NOT: func @bar
@@ -159,8 +159,6 @@ def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
 def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
 def : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
 def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
-def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type),
-          (TFL_InputOp $inputs)>;
 
 //===----------------------------------------------------------------------===//
 // Binary ops patterns.
@@ -69,50 +69,22 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
     // argument in the list.
     auto* arg = bb.getArgument(0);
 
-    auto remove_quantize_op = [&](QuantizeOp quantize_op, InputOp input_op) {
+    auto remove_quantize_op = [&](QuantizeOp quantize_op) {
       auto quantize_output = quantize_op.output();
       auto quantize_type = quantize_output->getType();
       input_types.push_back(quantize_type);
      auto* new_arg = bb.addArgument(quantize_type);
-
-      if (input_op) {
-        // Make a copy of input op with quantized input and output type.
-        auto new_input =
-            builder.create<InputOp>(input_op.getLoc(), quantize_type, new_arg);
-        quantize_output->replaceAllUsesWith(new_input);
-      } else {
-        quantize_output->replaceAllUsesWith(new_arg);
-      }
-
+      quantize_output->replaceAllUsesWith(new_arg);
       quantize_op.erase();
-      if (input_op) input_op.erase();
       arg->dropAllUses();
       bb.eraseArgument(0);
     };
 
-    // This is looking for a pattern: arg -> tfl.pseudo_input -> tfl.quantize
-    // or arg -> tfl.quantize
-    if (arg->hasOneUse()) {
-      if (llvm::isa<QuantizeOp>(*arg->user_begin())) {
-        auto quantize_op = llvm::cast<QuantizeOp>(*arg->user_begin());
-        remove_quantize_op(quantize_op, /*input_op=*/nullptr);
-        continue;
-      } else if (llvm::isa<InputOp>(*arg->user_begin())) {
-        // TODO(lyandy): Remove arg -> tfl.pseudo_input -> tfl.quantize once
-        // tfl.pseudo_input are not generated.
-        auto input_op = llvm::cast<InputOp>(*arg->user_begin());
-        Value* input_output = input_op.output();
-        // We can drop the quantization adaptor only when the pseudo input op
-        // has one user and it is the quantize op. Otherwise, we have to keep
-        // the adaptor and allow the floating point inputs.
-        if (input_output->hasOneUse() &&
-            llvm::isa<QuantizeOp>(*input_output->user_begin())) {
-          auto quantize_op =
-              llvm::cast<QuantizeOp>(*input_output->user_begin());
-          remove_quantize_op(quantize_op, input_op);
-          continue;
-        }
-      }
-    }
+    // This is looking for a pattern: arg -> tfl.quantize
+    if (arg->hasOneUse() && llvm::isa<QuantizeOp>(*arg->user_begin())) {
+      auto quantize_op = llvm::cast<QuantizeOp>(*arg->user_begin());
+      remove_quantize_op(quantize_op);
+      continue;
+    }
 
     // Make a copy of current argument and append it to the end of the list if
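After this simplification the pass only has to recognize the two-op pattern arg -> tfl.quantize; the three-op arg -> tfl.pseudo_input -> tfl.quantize form can no longer be produced. The effect of remove_quantize_op is to hoist the quantized type into the function signature, roughly as follows (a sketch with an illustrative element type, not taken from the tests in this commit):

    // Before the pass: a float argument feeding a quantize op.
    func @main(%arg0: tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 1.0>> {
      %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x4x!quant.uniform<u8:f32, 1.0>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 1.0>>
      return %0 : tensor<2x4x!quant.uniform<u8:f32, 1.0>>
    }

    // After: the argument itself carries the quantized type and the adaptor op is gone.
    func @main(%arg0: tensor<2x4x!quant.uniform<u8:f32, 1.0>>) -> tensor<2x4x!quant.uniform<u8:f32, 1.0>> {
      return %arg0 : tensor<2x4x!quant.uniform<u8:f32, 1.0>>
    }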
@@ -155,19 +155,9 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
 
   for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
     BlockArgument* arg = func.getArgument(i);
-    if (arg->hasOneUse() && llvm::isa<TFL::InputOp>(*arg->getUsers().begin())) {
-      // TODO(lyandy): Remove arg -> tfl.pseudo_input -> tfl.quantize once
-      // tfl.pseudo_input are not generated.
-      Operation* input = *arg->getUsers().begin();
-      auto input_op = llvm::cast<TFL::InputOp>(input);
-      add_quantize_op(input_op.getLoc(), input_op.input()->getType(),
-                      input->getBlock(), ++Block::iterator(input_op),
-                      input_op.output(), i);
-    } else {
-      auto* arg_block = arg->getOwner();
-      add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
-                      std::next(arg_block->begin(), i), arg, i);
-    }
+    auto* arg_block = arg->getOwner();
+    add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
+                    std::next(arg_block->begin(), i), arg, i);
   }
 
   return false;
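With the InputOp special case removed, every argument takes the same path: a quantize/dequantize pair is materialized right after the argument in its owning block, and downstream ops keep consuming float values through the dequantize. A sketch of the resulting IR for one argument (scale and zero point are illustrative):

    func @main(%arg0: tensor<f32>) -> tensor<f32> {
      %0 = "tfl.quantize"(%arg0) {qtype = tensor<!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<f32>) -> tensor<!quant.uniform<u8:f32, 7.812500e-03:128>>
      %1 = "tfl.dequantize"(%0) : (tensor<!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<f32>
      return %1 : tensor<f32>
    }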
@@ -241,33 +241,6 @@ underlying graph, and executes each of the partitioned subgraphs as a function.
   }];
 }
 
-// In MLIR, the 'tf.Placeholder.input' instruction is used to capture attributes
-// of function arguments.
-// Note: NoSideEffect trait is not added intentionally to preserve the captured
-// attributes even if the input is unused.
-def TF_PlaceholderInputOp : TF_Op<"Placeholder.input",
-                                  [SameOperandsAndResultType]> {
-  let summary = "PlaceholderInput op";
-
-  let description = [{
-    Inserts a placeholder for a tensor that will be always fed.
-  }];
-
-  let arguments = (ins
-    TF_Tensor:$arg,
-
-    OptionalAttr<F32Attr>:$min,
-    OptionalAttr<F32Attr>:$max,
-    OptionalAttr<TF_IntTypeAttr>:$type
-  );
-
-  let results = (outs
-    TF_Tensor:$output
-  );
-
-  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
-}
-
 def TF_PlaceholderOp : TF_Op<"Placeholder", [NoSideEffect]> {
   let summary = "Placeholder op";
 
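Note the asymmetry with the op that survives: tf.Placeholder takes no operand, while the removed tf.Placeholder.input wrapped an existing argument purely to carry its optional min/max/type hints. A sketch of the removed usage (attribute values illustrative):

    %0 = "tf.Placeholder.input"(%arg0) {min = -1.0 : f32, max = 1.0 : f32} : (tensor<f32>) -> tensor<f32>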