[MLIR:TF] Add a tf2tf lowering for SqrtGrad
Make GetScalarOfFloatType work for complex types too, so the lowering can also handle them transparently.

PiperOrigin-RevId: 355643729
Change-Id: I7859a544f81e3e60050c4f86baea5733d5aa91ec
Commit 5eb029f68e (parent 39772f66fd)
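The new lowering relies on the square-root gradient identity: with y = sqrt(x), dy/dx = 1 / (2 * sqrt(x)) = 0.5 / y, so the incoming gradient dy is rewritten as dy * 0.5 / y. As a standalone sanity check of that identity (this sketch is not part of the change and only uses standard C++ <cmath>/<complex>; names like SqrtGradLowered are illustrative), the following program compares the lowered formula against the analytic derivative for a real and a complex input:

    // Standalone check of SqrtGrad(y, dy) = dy * 0.5 / y (illustrative, not from the commit).
    #include <cmath>
    #include <complex>
    #include <cstdio>

    // Mirrors the lowered op sequence: tf.Div(tf.Mul(dy, 0.5), y).
    template <typename T>
    T SqrtGradLowered(T y, T dy) {
      return dy * T(0.5) / y;
    }

    int main() {
      // Real input: x = 4, y = sqrt(x) = 2, analytic gradient = dy / (2 * sqrt(x)) = 0.25.
      float y = std::sqrt(4.0f);
      std::printf("float: lowered=%f analytic=%f\n",
                  SqrtGradLowered(y, 1.0f), 1.0f / (2.0f * y));

      // Complex input, which the templated GetScalarOfType lets the lowering handle too.
      std::complex<float> yc = std::sqrt(std::complex<float>(3.0f, 4.0f));  // yc = 2 + 1i
      std::complex<float> dyc(1.0f, 0.0f);
      std::complex<float> lowered = SqrtGradLowered(yc, dyc);
      std::complex<float> analytic = dyc / (2.0f * yc);
      std::printf("complex: lowered=(%f,%f) analytic=(%f,%f)\n",
                  lowered.real(), lowered.imag(), analytic.real(), analytic.imag());
      return 0;
    }

Both printed pairs match; for x = 3 + 4i the gradient comes out as 0.2 - 0.1i. The FileCheck test in the first hunk below pins down the same rewrite at the IR level for tensor<*xcomplex<f32>>.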
@@ -233,6 +233,18 @@ func @rsqrt_grad_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
   return %0 : tensor<*xf32>
 }

+// CHECK-LABEL: func @sqrt_grad_unranked
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xcomplex<f32>>, %[[ARG1:.*]]: tensor<*xcomplex<f32>>)
+func @sqrt_grad_unranked(%arg0: tensor<*xcomplex<f32>>, %arg1: tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>> {
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<(5.000000e-01,0.000000e+00)> : tensor<complex<f32>>} : () -> tensor<complex<f32>>
+  // CHECK: %[[MUL:.*]] = "tf.Mul"(%arg1, %[[CST]]) : (tensor<*xcomplex<f32>>, tensor<complex<f32>>) -> tensor<*xcomplex<f32>>
+  // CHECK: %[[RET:.*]] = "tf.Div"(%[[MUL]], %arg0) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
+
+  %0 = "tf.SqrtGrad"(%arg0, %arg1) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
+  // CHECK: return %[[RET]]
+  return %0 : tensor<*xcomplex<f32>>
+}
+
 // %input has 1 batch dimension then 2 block dimensions then 1 remainder
 // dimension.
 // CHECK-LABEL: fourdim_space_to_batch_nd
@@ -69,8 +69,9 @@ static APFloat ConvertToAPFloat(double val, Type type) {
 }

 // Returns int, float, or complex DenseElementsAttr with scalar shape with the
-// given element type and the integer value.
-static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
+// given element type and the value.
+template <typename T>
+static DenseElementsAttr GetScalarOfType(Type ty, T raw_value) {
   RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
   if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
     FloatAttr attr = FloatAttr::get(float_ty, raw_value);
@@ -91,14 +92,6 @@ static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
   llvm_unreachable("unsupported type");
 }

-// Returns float DenseElementsAttr with scalar shape with the specified value.
-static DenseElementsAttr GetScalarOfFloatType(Type ty, double raw_value) {
-  auto float_ty = ty.cast<FloatType>();
-  FloatAttr attr = FloatAttr::get(float_ty, raw_value);
-  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
-  return DenseElementsAttr::get(scalar_ty, attr);
-}
-
 // Returns reduction indices to use while lowering tf.BiasAddGrad op to tf.Sum
 // op.
 DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank,
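The two hunks above show only the float branch of the now-templated GetScalarOfType; the complex handling the commit message refers to lies outside this excerpt. Purely as a hedged illustration (the includes, the DenseElementsAttr complex-splat overload, and the helper name GetScalarOfTypeSketch are assumptions, not copied from the actual file), a complex branch could look along these lines:

    // Illustrative sketch only -- the complex branch is not visible in this diff,
    // and the exact MLIR overloads used here are assumptions.
    #include <complex>

    #include "llvm/Support/ErrorHandling.h"
    #include "mlir/IR/BuiltinAttributes.h"  // DenseElementsAttr, FloatAttr
    #include "mlir/IR/BuiltinTypes.h"       // ComplexType, FloatType, RankedTensorType

    template <typename T>
    static mlir::DenseElementsAttr GetScalarOfTypeSketch(mlir::Type ty, T raw_value) {
      mlir::RankedTensorType scalar_ty = mlir::RankedTensorType::get({}, ty);
      if (auto float_ty = ty.dyn_cast_or_null<mlir::FloatType>()) {
        // Same float path as in the hunk above.
        return mlir::DenseElementsAttr::get(scalar_ty,
                                            mlir::FloatAttr::get(float_ty, raw_value));
      }
      if (auto complex_ty = ty.dyn_cast_or_null<mlir::ComplexType>()) {
        // Assumed complex path: splat a std::complex value of the matching width.
        mlir::Type element_ty = complex_ty.getElementType();
        if (element_ty.isF32())
          return mlir::DenseElementsAttr::get(
              scalar_ty, static_cast<std::complex<float>>(raw_value));
        if (element_ty.isF64())
          return mlir::DenseElementsAttr::get(
              scalar_ty, static_cast<std::complex<double>>(raw_value));
      }
      llvm_unreachable("unsupported element type");
    }

With a single helper of this shape, the TableGen wrappers GetScalarOfFloatType, GetScalarInfOfType, and GetScalarNanOfType only need to forward to GetScalarOfType, which is exactly what the @@ -22,14 hunk further down does.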
@@ -1412,9 +1405,9 @@ class LowerResizeNearestNeighbor : public RewritePattern {
         out_w_f32);

     Value zero_f32 = rewriter.create<ConstOp>(
-        loc, GetScalarOfFloatType(rewriter.getF32Type(), 0.0));
+        loc, GetScalarOfType(rewriter.getF32Type(), 0.0));
     Value one_f32 = rewriter.create<ConstOp>(
-        loc, GetScalarOfFloatType(rewriter.getF32Type(), 1.0));
+        loc, GetScalarOfType(rewriter.getF32Type(), 1.0));

     Value y_range = rewriter.create<RangeOp>(
         loc,
@@ -1566,6 +1559,7 @@ void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
       LowerSizeOp,
       LowerSoftmaxCrossEntropyWithLogitsOp,
       LowerSparseSoftmaxCrossEntropyWithLogitsOp,
+      LowerSqrtGradOp,
       LowerSquareOp,
       LowerSquaredDifferenceOpOnRealTensors,
       LowerSquaredDifferenceOpOneComplexTensors,
@@ -22,14 +22,14 @@ class GetScalarOfType<int value> : NativeCodeCall<
   "GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;

 class GetScalarOfFloatType<string value> : NativeCodeCall<
-  "GetScalarOfFloatType(getElementTypeOrSelf($0)," # value # ")">;
+  "GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;

 def GetScalarInfOfType : NativeCodeCall<
-  "GetScalarOfFloatType(getElementTypeOrSelf($0), "
+  "GetScalarOfType(getElementTypeOrSelf($0), "
   "std::numeric_limits<double>::infinity())">;

 def GetScalarNanOfType : NativeCodeCall<
-  "GetScalarOfFloatType(getElementTypeOrSelf($0), "
+  "GetScalarOfType(getElementTypeOrSelf($0), "
   "std::numeric_limits<double>::quiet_NaN())">;

 class GetI64ScalarElementsAttr<int value> :
@@ -320,6 +320,18 @@ def LowerSizeOp : Pat<
     ConstBoolAttrFalse
   )>;

+//===----------------------------------------------------------------------===//
+// Sqrt op patterns.
+//===----------------------------------------------------------------------===//
+
+// SqrtGrad(y, dy) = dy * 0.5 / y
+def LowerSqrtGradOp : Pat<
+  (TF_SqrtGradOp $y, $dy),
+  (TF_DivOp
+    (TF_MulOp $dy, (TF_ConstOp (GetScalarOfFloatType<"0.5"> $dy))),
+    $y
+  )>;
+
 //===----------------------------------------------------------------------===//
 // TanhGrad op patterns.
 //===----------------------------------------------------------------------===//
@@ -232,7 +232,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
     TypeID::get<TF::SpaceToBatchOp>(),
     TypeID::get<TF::SpaceToDepthOp>(),
     TypeID::get<TF::SparseToDenseOp>(),
-    TypeID::get<TF::SqrtGradOp>(),
     TypeID::get<TF::SquareOp>(),
     TypeID::get<TF::StatelessMultinomialOp>(),
     TypeID::get<TF::StatelessRandomGetAlgOp>(),