[TFLite/MLIR] Supports bf16 constant folding for rsqrt op.

PiperOrigin-RevId: 339784984
Change-Id: I3d17fecafa3f994256c6ff87df8e346442f3bb28
This commit is contained in:
A. Unique TensorFlower 2020-10-29 18:04:11 -07:00 committed by TensorFlower Gardener
parent a4a4fc8c04
commit 4a1156a49e
2 changed files with 30 additions and 5 deletions

View File

@ -306,6 +306,14 @@ inline bool IsF32ShapedType(Type t) {
return false;
}
// Returns true if it is a shaped type of bf16 elements.
inline bool IsBF16ShapedType(Type t) {
if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
return shaped_type.getElementType().isBF16();
}
return false;
}
// Performs const folding `calculate` with broadcast behavior on the two
// attributes `operand1` and `operand2` and returns the result if possible.
// The two operands are expected to both be scalar values.
@ -498,7 +506,7 @@ Attribute ConstFoldBinaryOp(
/// "tfl.logical_not".
Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
llvm::function_ref<APFloat(APFloat)> calculate) {
assert(IsF32ShapedType(result_type));
assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
auto result_shape_type = result_type.cast<ShapedType>();
if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
@ -1911,13 +1919,20 @@ OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
Type result_type = getType();
// Only constant fold for tensor of f32 is implemented.
if (!IsF32ShapedType(result_type)) return nullptr;
// Only constant fold for tensor of f32/bf16 is implemented.
if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
return nullptr;
auto compute = [](APFloat value) -> APFloat {
bool loseInfo;
const llvm::fltSemantics &original_float_semantics = value.getSemantics();
value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
&loseInfo);
float f = value.convertToFloat();
float result = 1.f / std::sqrt(f);
return APFloat(result);
APFloat result(1.f / std::sqrt(f));
result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
&loseInfo);
return result;
};
return ConstFoldUnaryOp(result_type, operands[0], compute);
}

View File

@ -577,3 +577,13 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @rsqrt_bf16
func @rsqrt_bf16() -> tensor<bf16> {
%cst = constant dense<4.0> : tensor<bf16>
%0 = "tfl.rsqrt"(%cst) : (tensor<bf16>) -> tensor<bf16>
return %0 : tensor<bf16>
// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
// CHECK: return %[[CST]]
}