[TFLite/MLIR] Supports bf16 constant folding for rsqrt op.
PiperOrigin-RevId: 339784984
Change-Id: I3d17fecafa3f994256c6ff87df8e346442f3bb28
parent a4a4fc8c04
commit 4a1156a49e
@@ -306,6 +306,14 @@ inline bool IsF32ShapedType(Type t) {
   return false;
 }
 
+// Returns true if it is a shaped type of bf16 elements.
+inline bool IsBF16ShapedType(Type t) {
+  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
+    return shaped_type.getElementType().isBF16();
+  }
+  return false;
+}
+
 // Performs const folding `calculate` with broadcast behavior on the two
 // attributes `operand1` and `operand2` and returns the result if possible.
 // The two operands are expected to both be scalar values.
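The helper added above mirrors the existing IsF32ShapedType check shown in the hunk header. As a minimal sketch of how the two checks combine in the folds further down (the wrapper name IsFoldableFloatShapedType is hypothetical, not part of this change):

// Hypothetical convenience wrapper, assuming the same MLIR Type API used above.
inline bool IsFoldableFloatShapedType(mlir::Type t) {
  // Constant folding in this file is only implemented for shaped f32 or bf16 results.
  return IsF32ShapedType(t) || IsBF16ShapedType(t);
}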
@@ -498,7 +506,7 @@ Attribute ConstFoldBinaryOp(
 /// "tfl.logical_not".
 Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
                            llvm::function_ref<APFloat(APFloat)> calculate) {
-  assert(IsF32ShapedType(result_type));
+  assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
   auto result_shape_type = result_type.cast<ShapedType>();
 
   if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
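The only change here relaxes ConstFoldUnaryOp's precondition from f32-only to f32-or-bf16 element types. A sketch of what now satisfies the assert, using an mlir::Builder to construct illustrative result types (the builder, context, and shapes are assumptions, not taken from the diff):

// Sketch only: both tensor types below pass the relaxed assert.
mlir::Builder builder(&context);  // assumes an mlir::MLIRContext named `context`
auto f32_ty = mlir::RankedTensorType::get({4}, builder.getF32Type());
auto bf16_ty = mlir::RankedTensorType::get({4}, builder.getBF16Type());
assert(IsF32ShapedType(f32_ty) || IsBF16ShapedType(f32_ty));    // passed before this change too
assert(IsF32ShapedType(bf16_ty) || IsBF16ShapedType(bf16_ty));  // passes only with this change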
@@ -1911,13 +1919,20 @@ OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
 
 OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
   Type result_type = getType();
-  // Only constant fold for tensor of f32 is implemented.
-  if (!IsF32ShapedType(result_type)) return nullptr;
+  // Only constant fold for tensor of f32/bf16 is implemented.
+  if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
+    return nullptr;
 
   auto compute = [](APFloat value) -> APFloat {
+    bool loseInfo;
+    const llvm::fltSemantics &original_float_semantics = value.getSemantics();
+    value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
+                  &loseInfo);
     float f = value.convertToFloat();
-    float result = 1.f / std::sqrt(f);
-    return APFloat(result);
+    APFloat result(1.f / std::sqrt(f));
+    result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
+                   &loseInfo);
+    return result;
   };
   return ConstFoldUnaryOp(result_type, operands[0], compute);
 }
@@ -577,3 +577,13 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
   // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
   // CHECK: return %[[CST]]
 }
+
+// CHECK-LABEL: @rsqrt_bf16
+func @rsqrt_bf16() -> tensor<bf16> {
+  %cst = constant dense<4.0> : tensor<bf16>
+  %0 = "tfl.rsqrt"(%cst) : (tensor<bf16>) -> tensor<bf16>
+  return %0 : tensor<bf16>
+
+  // CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
+  // CHECK: return %[[CST]]
+}