Remove tf.Placeholder.input/tfl.pseudo_input ops and associated dependencies.

PiperOrigin-RevId: 279394545
Change-Id: I7100f3025073ed74f45b757f6e321bc6e7ede5c8
Authored by Andy Ly on 2019-11-08 14:43:51 -08:00; committed by TensorFlower Gardener.
parent 514cf2d96e
commit 5f5f9cf191
15 changed files with 71 additions and 216 deletions
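In effect, ops that previously consumed a tfl.pseudo_input wrapper now use the function argument directly. A minimal before/after sketch in MLIR (the function name, shapes, and the tfl.abs op are illustrative, not taken from this commit):

Before:
    func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
      // Identity wrapper whose only purpose was to carry input metadata.
      %0 = "tfl.pseudo_input"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
      %1 = "tfl.abs"(%0) : (tensor<4xf32>) -> tensor<4xf32>
      return %1 : tensor<4xf32>
    }

After:
    func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
      // Ops consume the block argument directly.
      %0 = "tfl.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
      return %0 : tensor<4xf32>
    }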

(next file)

@@ -221,18 +221,11 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
}
}
static bool IsInput(Operation* op) {
return isa<tfl::InputOp>(op) ||
op->getName().getStringRef() == "tf.Placeholder.input";
}
static bool IsConst(Operation* op) {
return isa<mlir::ConstantOp>(op) || isa<mlir::TF::ConstOp>(op) ||
isa<tfl::ConstOp>(op) || isa<tfl::QConstOp>(op);
}
static bool IsConstOrInput(Operation* op) { return IsConst(op) || IsInput(op); }
template <typename T>
static bool HasValidTFLiteType(Value* value, T& error_handler) {
// None type is allowed to represent unspecified operands.
@@ -957,8 +950,8 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
}
}
// Skip constant and input ops as they don't represent a TFLite operator.
if (IsConstOrInput(&inst)) continue;
// Skip constant ops as they don't represent a TFLite operator.
if (IsConst(&inst)) continue;
// Fetch operand and result tensor indices.
std::vector<int32_t> operands;

(next file)

@@ -1172,24 +1172,6 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
let hasOptions = 0;
}
// NoSideEffect trait is not added to the op intentionally to prevent it from
// getting removed if the input is unused. The generated FlatBuffer needs to
// have a tensor along with the metadata for each of the subgraph inputs.
def TFL_InputOp : Op<TFL_Dialect, "pseudo_input", [SameOperandsAndResultType]> {
let summary = "Input pseudo operator";
let description = [{
Takes one of the function arguments as input and returns it as result. This
is a NOP and is used to attach attributes such as tensor name to function
arguments.
}];
let arguments = (ins AnyTensor:$input);
let results = (outs AnyTensor:$output);
}
def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
// central_value = min_value / 2 + (max_value - 1) / 2 + 1
// zero_point = central_value

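Before its removal, the pseudo-input op was a pure identity whose only job was to carry metadata such as a tensor name, e.g. (attribute value hypothetical):

    %0 = "tfl.pseudo_input"(%arg0) {name = "Input"} : (tensor<4xf32>) -> tensor<4xf32>

That metadata now lives on the function itself (for example, the inputs list of the tf.entry_function attribute), so the wrapper op is no longer needed.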
(next file)

@@ -41,13 +41,10 @@ versions {
# MLIR-LABEL: func @main(%arg0: tensor<4x!quant.uniform<u8:f32, 0.015686274509803921:128>>, %arg1: tensor<4x!quant.uniform<u8:f32, 0.023529411764705882:128>>) -> tensor<4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
# MLIR-NEXT: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {
# MLIR-NEXT: %[[input0:.*]] = "tfl.pseudo_input"(%arg1) : (tensor<4x!quant.uniform<u8:f32, 0.023529411764705882:128>>) -> tensor<4x!quant.uniform<u8:f32, 0.023529411764705882:128>>
# MLIR-NEXT: %[[input1:.*]] = "tfl.pseudo_input"(%arg0) : (tensor<4x!quant.uniform<u8:f32, 0.015686274509803921:128>>) -> tensor<4x!quant.uniform<u8:f32, 0.015686274509803921:128>>
# MLIR-NEXT: %[[add:.*]] = "tfl.add"(%[[input1]], %[[input0]]) {fused_activation_function = "NONE"} : (tensor<4x!quant.uniform<u8:f32, 0.015686274509803921:128>>, tensor<4x!quant.uniform<u8:f32, 0.023529411764705882:128>>) -> tensor<4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
# MLIR-NEXT: %[[add:.*]] = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<4x!quant.uniform<u8:f32, 0.015686274509803921:128>>, tensor<4x!quant.uniform<u8:f32, 0.023529411764705882:128>>) -> tensor<4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
# MLIR-NEXT: return %[[add]] : tensor<4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
# MLIR-NEXT: }
# CHECK-LABEL: {
# CHECK-NEXT: version: 3,
# CHECK-NEXT: operator_codes: [ {

(next file)

@@ -168,22 +168,6 @@ func @const() -> tensor<2xi32> {
// CHECK: "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
}
func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
// CHECK-LABEL: @placeholder
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
}
func @placeholder_int(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "tf.Placeholder.input"(%arg0) {name = "Input"} : (tensor<i32>) -> tensor<i32>
return %0: tensor<i32>
// CHECK-LABEL: @placeholder_int
// CHECK-NEXT: "tfl.pseudo_input"(%arg0) : (tensor<i32>) -> tensor<i32>
}
func @shape(%arg0: tensor<?x1001xf32>) -> tensor<2xi32> {
%0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<?x1001xf32>) -> tensor<2xi32>
%1 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT"} : (tensor<?x1001xf32>) -> tensor<2xi32>

(next file)

@@ -5,11 +5,10 @@
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
%0 = "tfl.pseudo_input" (%arg0) {name = "Input"} : (tensor<4xf32>) -> tensor<4xf32>
%1 = "tfl.pseudo_const" () {name = "Const", value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE", name = "mul"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "tfl.pseudo_const" () {name = "Const", value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
%1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE", name = "mul"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// tf.div is the result of conversion to a Flex TF op
%3 = "tf.Div"(%2, %1) {name = "div"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%4 = "tfl.exp"(%3) {name = "exp"} : (tensor<4xf32>) -> tensor<4xf32>
return %4 : tensor<4xf32>
%2 = "tf.Div"(%1, %0) {name = "div"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%3 = "tfl.exp"(%2) {name = "exp"} : (tensor<4xf32>) -> tensor<4xf32>
return %3 : tensor<4xf32>
}

(next file)

@@ -2,8 +2,7 @@
func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
^bb0(%arg0: tensor<3x2xi32>):
%0 = "tfl.pseudo_input" (%arg0) {name = "Input"} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// CHECK: error: 'unknown_op' op dialect is not registered
%1 = "unknown_op"(%0) : (tensor<3x2xi32>) -> tensor<3x2xi32>
return %1 : tensor<3x2xi32>
%0 = "unknown_op"(%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32>
return %0 : tensor<3x2xi32>
}

(next file)

@@ -301,38 +301,30 @@ func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x
// CHECK-LABEL: @FuseFullyConnectedAddUnit
func @FuseFullyConnectedAddUnit(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
%0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32>
%1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32>
%cst2 = constant dense<2.0> : tensor<40x40xf32>
%2 = "tfl.fully_connected" (%0, %1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
%3 = "tfl.add"(%2, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
return %3 : tensor<40x40xf32>
return %1 : tensor<40x40xf32>
// CHECK: %cst = constant dense<2.000000e+00> : tensor<40x40xf32>
// CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32>
// CHECK: %1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32>
// CHECK: %2 = "tfl.fully_connected"(%0, %1, %cst)
// CHECK: return %2
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %cst)
// CHECK: return %[[fc]]
}
// CHECK-LABEL: @FuseFullyConnectedAddConst
func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant dense<3.0> : tensor<40x40xf32>
%0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input")
%1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input")
%cst2 = constant dense<2.0> : tensor<40x40xf32>
%2 = "tfl.fully_connected" (%0, %1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
%3 = "tfl.add"(%2, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
return %3 : tensor<40x40xf32>
return %1 : tensor<40x40xf32>
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
// CHECK: %[[cst_0:.*]] = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32>
// CHECK: %[[cst_1:.*1]] = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%[[cst_0]], %[[cst_1]], %[[cst]])
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK: return %[[fc]]
}

(next file)

@@ -2,10 +2,8 @@
// CHECK-LABEL: RemoveUnused
func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,tensor<2xf32>) {
%in = "tfl.pseudo_input"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
%cst = "tfl.pseudo_input"(%arg1) : (tensor<i32>) -> tensor<i32>
%0 = "tfl.quantize"(%in) {qtype = tensor<4x!quant.uniform<u8:f32, 1.0>>} : (tensor<4xf32>) -> tensor<4x!quant.uniform<u8:f32, 1.0>>
%1:4 = "tfl.split"(%cst, %0) {num_splits = 4 : i32} : (tensor<i32>, tensor<4x!quant.uniform<u8:f32, 1.0>>)
%0 = "tfl.quantize"(%arg0) {qtype = tensor<4x!quant.uniform<u8:f32, 1.0>>} : (tensor<4xf32>) -> tensor<4x!quant.uniform<u8:f32, 1.0>>
%1:4 = "tfl.split"(%arg1, %0) {num_splits = 4 : i32} : (tensor<i32>, tensor<4x!quant.uniform<u8:f32, 1.0>>)
-> (tensor<2x!quant.uniform<u8:f32, 1.0>>, tensor<2x!quant.uniform<u8:f32, 1.0>>,tensor<2x!quant.uniform<u8:f32, 1.0>>, tensor<2x!quant.uniform<u8:f32, 1.0>>)
%2 = "tfl.dequantize"(%1#0) : (tensor<2x!quant.uniform<u8:f32, 1.0>>) -> tensor<2xf32>
%3 = "tfl.dequantize"(%1#1) : (tensor<2x!quant.uniform<u8:f32, 1.0>>) -> tensor<2xf32>
@@ -17,49 +15,41 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,t
return %2, %3 : tensor<2xf32>, tensor<2xf32>
// CHECK-NEXT: %[[in1:.*]] = "tfl.pseudo_input"(%arg0)
// CHECK-NEXT: %[[in2:.*]] = "tfl.pseudo_input"(%arg1)
// CHECK-NEXT: %[[split:.*]]:4 = "tfl.split"(%[[in2]], %[[in1]])
// CHECK-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %arg0)
// CHECK-NEXT: return %[[split]]#0, %[[split]]#1
}
func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
%cst = constant dense<[1, 1001]> : tensor<2xi32>
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
%2 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>
%4 = "tfl.conv_2d"(%1, %2, %3) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
%5 = "tfl.reshape"(%4, %cst) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>, tensor<2xi32>) -> tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>
%6 = "tfl.softmax"(%5) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
%7 = "tfl.dequantize"(%6) : (tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<1x1001xf32>
return %7 : tensor<1x1001xf32>
%0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
%1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
%2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>
%3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
%4 = "tfl.reshape"(%3, %cst) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>, tensor<2xi32>) -> tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>
%5 = "tfl.softmax"(%4) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>) -> tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
%6 = "tfl.dequantize"(%5) : (tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>) -> tensor<1x1001xf32>
return %6 : tensor<1x1001xf32>
}
func @main2(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf32> {
%0 = "tfl.pseudo_input"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
%2 = "tfl.pseudo_input"(%arg1) : (tensor<2x4xf32>) -> tensor<2x4xf32>
%3 = "tfl.quantize"(%2) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
%4 = tfl.add %1, %3 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
%5 = "tfl.dequantize"(%4) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4xf32>
return %5 : tensor<2x4xf32>
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
%1 = "tfl.quantize"(%arg1) {qtype = tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
%2 = tfl.add %0, %1 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
%3 = "tfl.dequantize"(%2) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4xf32>
return %3 : tensor<2x4xf32>
}
// CHECK: func @main(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
// CHECK-NEXT: %[[cst:.*]] = constant dense<[1, 1001]> : tensor<2xi32>
// CHECK-NEXT: %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
// CHECK-NEXT: %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>}
// CHECK-NEXT: %3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
// CHECK-NEXT: %4 = "tfl.reshape"(%3, %[[cst]]) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>, tensor<2xi32>)
// CHECK-NEXT: %5 = "tfl.softmax"(%4) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>)
// CHECK-NEXT: return %5 : tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
// CHECK-NEXT: %[[q_cst_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
// CHECK-NEXT: %[[q_cst_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<0> : tensor<32xi32>}
// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[q_cst_0]], %[[q_cst_1]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
// CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[cst]]) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>, tensor<2xi32>)
// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform<u8:f32, 0.023528476789885875>>)
// CHECK-NEXT: return %[[softmax]] : tensor<1x1001x!quant.uniform<u8:f32, 3.906250e-03>>
// CHECK-NEXT:}
// CHECK: func @main2(%arg0: tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>, %arg1: tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>> {
// CHECK-NEXT: %0 = "tfl.pseudo_input"(%arg1) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT: %1 = "tfl.pseudo_input"(%arg0) : (tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>) -> tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT: %2 = tfl.add %1, %0 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT: return %2 : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT: %[[add:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT: return %[[add]] : tensor<2x4x!quant.uniform<u8:f32, 0.49803921568627452>>
// CHECK-NEXT:}

(next file)

@@ -2,20 +2,13 @@
// CHECK-LABEL: quantize_float_placeholder_only
func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>) {
%0 = "tfl.pseudo_input"(%arg0) : (tensor<f32>) -> tensor<f32>
%1 = "tfl.pseudo_input"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%2 = "tfl.pseudo_input"(%arg2) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %arg0, %arg1, %arg2: tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>
return %0, %1, %2: tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>
// CHECK-NEXT: %[[in:.*]] = "tfl.pseudo_input"(%arg0)
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[in]])
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0)
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: %[[in_1:.*]] = "tfl.pseudo_input"(%arg1)
// CHECK-NEXT: %[[in_0:.*]] = "tfl.pseudo_input"(%arg2)
// CHECK-NEXT: %[[q_0:.*]] = "tfl.quantize"(%[[in_0]])
// CHECK-NEXT: %[[q_0:.*]] = "tfl.quantize"(%arg2)
// CHECK-NEXT: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK-NEXT: %[[dq]], %[[in_1]], %[[dq_0]]
// CHECK-NEXT: %[[dq]], %arg1, %[[dq_0]]
}
// CHECK-LABEL: DequantizeAndQuantize

(next file)

@@ -2,30 +2,26 @@
func @testSingleLstm(%arg0: tensor<4 x f32>) -> tensor<4xf32> {
// CHECK-LABEL: testSingleLstm
// CHECK: %[[INPUT:[a-z0-9]*]] = "tfl.pseudo_input"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[CST_0]], %[[CST_1]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]]) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32>
%1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%2 = "tfl.unidirectional_sequence_lstm"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %1, %1, %0, %0, %0, %0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %2 : tensor<4xf32>
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
}
func @testMultipleLstms(%arg0: tensor<4 x f32>) -> tensor<4xf32> {
// CHECK-LABEL: testMultipleLstms
// CHECK: %[[INPUT:[a-z0-9]*]] = "tfl.pseudo_input"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[CST_0]], %[[CST_1]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]]) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[CST_2]], %[[CST_3]], %[[INPUT]], %[[INPUT]], %[[INPUT]], %[[INPUT]]) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32>
%1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%2 = "tfl.unidirectional_sequence_lstm"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %1, %1, %0, %0, %0, %0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%3 = "tfl.unidirectional_sequence_lstm"(%2, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %0, %1, %1, %0, %0, %0, %0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %3 : tensor<4xf32>
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %2 : tensor<4xf32>
}

(next file)

@@ -1,21 +1,18 @@
// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-whitelist="bar,foobar" %s | FileCheck %s --dump-input-on-failure
func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1x4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
return %arg0 : tensor<1x4xf32>
}
func @bar(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf32> {
%0 = "tfl.pseudo_input"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4xf32>
return %0 : tensor<2x4xf32>
return %arg0 : tensor<2x4xf32>
}
func @foobar(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
%0 = "tfl.pseudo_input"(%arg0) : (tensor<1x4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
return %arg0 : tensor<1x4xf32>
}
// CHECK-DAG: func @main
// CHECK-DAG: func @foobar
// CHECK-NOT: func @foo
// CHECK-NOT: func @bar

(next file)

@@ -159,8 +159,6 @@ def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
def : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
def : Pat<(TF_PlaceholderInputOp $inputs, $min, $max, $type),
(TFL_InputOp $inputs)>;
//===----------------------------------------------------------------------===//
// Binary ops patterns.

(next file)

@@ -69,50 +69,22 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
// argument in the list.
auto* arg = bb.getArgument(0);
auto remove_quantize_op = [&](QuantizeOp quantize_op, InputOp input_op) {
auto remove_quantize_op = [&](QuantizeOp quantize_op) {
auto quantize_output = quantize_op.output();
auto quantize_type = quantize_output->getType();
input_types.push_back(quantize_type);
auto* new_arg = bb.addArgument(quantize_type);
if (input_op) {
// Make a copy of input op with quantized input and output type.
auto new_input =
builder.create<InputOp>(input_op.getLoc(), quantize_type, new_arg);
quantize_output->replaceAllUsesWith(new_input);
} else {
quantize_output->replaceAllUsesWith(new_arg);
}
quantize_output->replaceAllUsesWith(new_arg);
quantize_op.erase();
if (input_op) input_op.erase();
arg->dropAllUses();
bb.eraseArgument(0);
};
// This is looking for a pattern: arg -> tfl.pseudo_input -> tfl.quantize
// or arg -> tfl.quantize
if (arg->hasOneUse()) {
if (llvm::isa<QuantizeOp>(*arg->user_begin())) {
auto quantize_op = llvm::cast<QuantizeOp>(*arg->user_begin());
remove_quantize_op(quantize_op, /*input_op=*/nullptr);
continue;
} else if (llvm::isa<InputOp>(*arg->user_begin())) {
// TODO(lyandy): Remove arg -> tfl.pseudo_input -> tfl.quantize once
// tfl.pseudo_input are not generated.
auto input_op = llvm::cast<InputOp>(*arg->user_begin());
Value* input_output = input_op.output();
// We can drop the quantization adaptor only when the pseudo input op
// has one user and it is the quantize op. Otherwise, we have to keep
// the adaptor and allow the floating point inputs.
if (input_output->hasOneUse() &&
llvm::isa<QuantizeOp>(*input_output->user_begin())) {
auto quantize_op =
llvm::cast<QuantizeOp>(*input_output->user_begin());
remove_quantize_op(quantize_op, input_op);
continue;
}
}
// This is looking for a pattern: arg -> tfl.quantize
if (arg->hasOneUse() && llvm::isa<QuantizeOp>(*arg->user_begin())) {
auto quantize_op = llvm::cast<QuantizeOp>(*arg->user_begin());
remove_quantize_op(quantize_op);
continue;
}
// Make a copy of current argument and append it to the end of the list if

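The simplified rewrite only needs to recognize a bare block argument feeding a quantize op. As a sketch (types illustrative), an entry such as

    func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
      %0 = "tfl.quantize"(%arg0) {qtype = tensor<4x!quant.uniform<u8:f32, 1.0>>} : (tensor<4xf32>) -> tensor<4x!quant.uniform<u8:f32, 1.0>>
      ...

is rewritten so the argument itself takes the quantized type, all uses of %0 are redirected to the new argument, and the quantize op is erased:

    func @main(%arg0: tensor<4x!quant.uniform<u8:f32, 1.0>>) -> tensor<4xf32> {
      ...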
(next file)

@@ -155,19 +155,9 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
BlockArgument* arg = func.getArgument(i);
if (arg->hasOneUse() && llvm::isa<TFL::InputOp>(*arg->getUsers().begin())) {
// TODO(lyandy): Remove arg -> tfl.pseudo_input -> tfl.quantize once
// tfl.pseudo_input are not generated.
Operation* input = *arg->getUsers().begin();
auto input_op = llvm::cast<TFL::InputOp>(input);
add_quantize_op(input_op.getLoc(), input_op.input()->getType(),
input->getBlock(), ++Block::iterator(input_op),
input_op.output(), i);
} else {
auto* arg_block = arg->getOwner();
add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
std::next(arg_block->begin(), i), arg, i);
}
auto* arg_block = arg->getOwner();
add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
std::next(arg_block->begin(), i), arg, i);
}
return false;

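With the pseudo-input indirection gone, the quantize/dequantize adaptor is inserted directly on each block argument. For a float argument this produces the pair the tests above check for (scale and zero point hypothetical):

    %q = "tfl.quantize"(%arg0) {qtype = tensor<2x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
    %dq = "tfl.dequantize"(%q) : (tensor<2x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<2x3xf32>

Downstream ops then read %dq in place of %arg0.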
(next file)

@@ -241,33 +241,6 @@ underlying graph, and executes each of the partitioned subgraphs as a function.
}];
}
// In MLIR, the 'tf.Placeholder.input' instruction is used to capture attributes
// of function arguments.
// Note: NoSideEffect trait is not added intentionally to preserve the captured
// attributes even if the input is unused.
def TF_PlaceholderInputOp : TF_Op<"Placeholder.input",
[SameOperandsAndResultType]> {
let summary = "PlaceholderInput op";
let description = [{
Inserts a placeholder for a tensor that will be always fed.
}];
let arguments = (ins
TF_Tensor:$arg,
OptionalAttr<F32Attr>:$min,
OptionalAttr<F32Attr>:$max,
OptionalAttr<TF_IntTypeAttr>:$type
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_PlaceholderOp : TF_Op<"Placeholder", [NoSideEffect]> {
let summary = "Placeholder op";