Add constant folder for tfl.pseudo_const
Also register materializeConstant for TFL dialect. Otherwise, a tfl.pseudo_const with a value of an opaque attribute will be folded into std.constant. PiperOrigin-RevId: 260570836
This commit is contained in:
parent
7cb9426454
commit
7191d4f3b6
@ -775,6 +775,17 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
|
||||
// Return the held attribute value.
|
||||
return value();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -782,5 +793,16 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
||||
|
||||
Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value,
|
||||
Type type, Location loc) {
|
||||
// If this is an opaque elements attribute or the result type doesn't match
|
||||
// the attribute type, then generate a tfl.pseudo_const.
|
||||
if (value.isa<OpaqueElementsAttr>() ||
|
||||
(value.isa<ElementsAttr>() && value.getType() != type))
|
||||
return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
||||
|
@ -36,6 +36,11 @@ namespace TFL {
|
||||
class TensorFlowLiteDialect : public Dialect {
|
||||
public:
|
||||
explicit TensorFlowLiteDialect(MLIRContext *context);
|
||||
|
||||
// Registered hook to materialize a constant operation from a given attribute
|
||||
// value with the desired resultant type.
|
||||
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
||||
Location loc) override;
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
@ -545,6 +545,8 @@ def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [NoSideEffect,
|
||||
let arguments = (ins ElementsAttr:$value);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution">;
|
||||
|
@ -304,7 +304,6 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: @rank
|
||||
func @rank() -> tensor<1xi32> {
|
||||
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
|
||||
@ -323,4 +322,13 @@ func @reshape() -> tensor<1x2xi32> {
|
||||
// CHECK: return [[cst]]
|
||||
%0 = "tfl.reshape"(%cst) : (tensor<2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @pseudo_const
|
||||
func @pseudo_const() -> tensor<i32> {
|
||||
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: return [[cst]]
|
||||
%0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
|
@ -142,7 +142,7 @@ func @const() -> tensor<2xi32> {
|
||||
return %0: tensor<2xi32>
|
||||
|
||||
// CHECK-LABEL: @const
|
||||
// CHECK: %0 = "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
}
|
||||
|
||||
func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
|
Loading…
Reference in New Issue
Block a user