Add constant folder for tfl.pseudo_const

Also register materializeConstant for TFL dialect. Otherwise, a tfl.pseudo_const with a value of an opaque attribute will be folded into std.constant.

PiperOrigin-RevId: 260570836
This commit is contained in:
Jing Pu 2019-07-29 13:42:55 -07:00 committed by TensorFlower Gardener
parent 7cb9426454
commit 7191d4f3b6
5 changed files with 40 additions and 3 deletions

View File

@ -775,6 +775,17 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// ConstOp
//===----------------------------------------------------------------------===//
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
// Return the held attribute value.
return value();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
@ -782,5 +793,16 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
Attribute value,
Type type, Location loc) {
// If this is an opaque elements attribute or the result type doesn't match
// the attribute type, then generate a tfl.pseudo_const.
if (value.isa<OpaqueElementsAttr>() ||
(value.isa<ElementsAttr>() && value.getType() != type))
return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
return nullptr;
}
} // namespace TFL
} // namespace mlir

View File

@ -36,6 +36,11 @@ namespace TFL {
class TensorFlowLiteDialect : public Dialect {
public:
explicit TensorFlowLiteDialect(MLIRContext *context);
// Registered hook to materialize a constant operation from a given attribute
// value with the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;
};
#define GET_OP_CLASSES

View File

@ -545,6 +545,8 @@ def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [NoSideEffect,
let arguments = (ins ElementsAttr:$value);
let results = (outs AnyTensor:$output);
let hasFolder = 1;
}
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution">;

View File

@ -304,7 +304,6 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
// CHECK: return %0
}
// CHECK-LABEL: @rank
func @rank() -> tensor<1xi32> {
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
@ -323,4 +322,13 @@ func @reshape() -> tensor<1x2xi32> {
// CHECK: return [[cst]]
%0 = "tfl.reshape"(%cst) : (tensor<2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
}
// CHECK-LABEL: @pseudo_const
func @pseudo_const() -> tensor<i32> {
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32>
// CHECK: return [[cst]]
%0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
return %0 : tensor<i32>
}

View File

@ -142,7 +142,7 @@ func @const() -> tensor<2xi32> {
return %0: tensor<2xi32>
// CHECK-LABEL: @const
// CHECK: %0 = "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
}
func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {