Add constant folder for tfl.pseudo_const
Also register materializeConstant for TFL dialect. Otherwise, a tfl.pseudo_const with a value of an opaque attribute will be folded into std.constant. PiperOrigin-RevId: 260570836
This commit is contained in:
parent
7cb9426454
commit
7191d4f3b6
@ -775,6 +775,17 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
assert(operands.empty() && "constant has no operands");
|
||||||
|
|
||||||
|
// Return the held attribute value.
|
||||||
|
return value();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -782,5 +793,16 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
||||||
|
|
||||||
|
Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
|
||||||
|
Attribute value,
|
||||||
|
Type type, Location loc) {
|
||||||
|
// If this is an opaque elements attribute or the result type doesn't match
|
||||||
|
// the attribute type, then generate a tfl.pseudo_const.
|
||||||
|
if (value.isa<OpaqueElementsAttr>() ||
|
||||||
|
(value.isa<ElementsAttr>() && value.getType() != type))
|
||||||
|
return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace TFL
|
} // namespace TFL
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@ -36,6 +36,11 @@ namespace TFL {
|
|||||||
class TensorFlowLiteDialect : public Dialect {
|
class TensorFlowLiteDialect : public Dialect {
|
||||||
public:
|
public:
|
||||||
explicit TensorFlowLiteDialect(MLIRContext *context);
|
explicit TensorFlowLiteDialect(MLIRContext *context);
|
||||||
|
|
||||||
|
// Registered hook to materialize a constant operation from a given attribute
|
||||||
|
// value with the desired resultant type.
|
||||||
|
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
||||||
|
Location loc) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
|
@ -545,6 +545,8 @@ def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [NoSideEffect,
|
|||||||
let arguments = (ins ElementsAttr:$value);
|
let arguments = (ins ElementsAttr:$value);
|
||||||
|
|
||||||
let results = (outs AnyTensor:$output);
|
let results = (outs AnyTensor:$output);
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution">;
|
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution">;
|
||||||
|
@ -304,7 +304,6 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
|
|||||||
// CHECK: return %0
|
// CHECK: return %0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: @rank
|
// CHECK-LABEL: @rank
|
||||||
func @rank() -> tensor<1xi32> {
|
func @rank() -> tensor<1xi32> {
|
||||||
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
|
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
|
||||||
@ -324,3 +323,12 @@ func @reshape() -> tensor<1x2xi32> {
|
|||||||
%0 = "tfl.reshape"(%cst) : (tensor<2xi32>) -> tensor<1x2xi32>
|
%0 = "tfl.reshape"(%cst) : (tensor<2xi32>) -> tensor<1x2xi32>
|
||||||
return %0 : tensor<1x2xi32>
|
return %0 : tensor<1x2xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @pseudo_const
|
||||||
|
func @pseudo_const() -> tensor<i32> {
|
||||||
|
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32>
|
||||||
|
// CHECK: return [[cst]]
|
||||||
|
%0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -142,7 +142,7 @@ func @const() -> tensor<2xi32> {
|
|||||||
return %0: tensor<2xi32>
|
return %0: tensor<2xi32>
|
||||||
|
|
||||||
// CHECK-LABEL: @const
|
// CHECK-LABEL: @const
|
||||||
// CHECK: %0 = "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
|
// CHECK: "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
|
func @placeholder(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
Loading…
Reference in New Issue
Block a user