diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 861048a9a93..d593c0ec836 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -306,6 +306,14 @@ inline bool IsF32ShapedType(Type t) { return false; } +// Returns true if it is a shaped type of bf16 elements. +inline bool IsBF16ShapedType(Type t) { + if (auto shaped_type = t.dyn_cast_or_null()) { + return shaped_type.getElementType().isBF16(); + } + return false; +} + // Performs const folding `calculate` with broadcast behavior on the two // attributes `operand1` and `operand2` and returns the result if possible. // The two operands are expected to both be scalar values. @@ -498,7 +506,7 @@ Attribute ConstFoldBinaryOp( /// "tfl.logical_not". Attribute ConstFoldUnaryOp(Type result_type, Attribute operand, llvm::function_ref calculate) { - assert(IsF32ShapedType(result_type)); + assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type)); auto result_shape_type = result_type.cast(); if (auto dense_elements = operand.dyn_cast_or_null()) { @@ -1911,13 +1919,20 @@ OpFoldResult SqrtOp::fold(ArrayRef operands) { OpFoldResult RsqrtOp::fold(ArrayRef operands) { Type result_type = getType(); - // Only constant fold for tensor of f32 is implemented. - if (!IsF32ShapedType(result_type)) return nullptr; + // Only constant fold for tensor of f32/bf16 is implemented. + if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type)) + return nullptr; auto compute = [](APFloat value) -> APFloat { + bool loseInfo; + const llvm::fltSemantics &original_float_semantics = value.getSemantics(); + value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, + &loseInfo); float f = value.convertToFloat(); - float result = 1.f / std::sqrt(f); - return APFloat(result); + APFloat result(1.f / std::sqrt(f)); + result.convert(original_float_semantics, APFloat::rmNearestTiesToEven, + &loseInfo); + return result; }; return ConstFoldUnaryOp(result_type, operands[0], compute); } diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index ff7c47fb621..69009ae594b 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -577,3 +577,13 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> { // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32> // CHECK: return %[[CST]] } + +// CHECK-LABEL: @rsqrt_bf16 +func @rsqrt_bf16() -> tensor { + %cst = constant dense<4.0> : tensor + %0 = "tfl.rsqrt"(%cst) : (tensor) -> tensor + return %0 : tensor + +// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor +// CHECK: return %[[CST]] +}