[MLIR:TF] Add a tf2tf lowering for SqrtGrad

Make GetScalarOfFloatType work for complex types too, so the lowering can also
handle them transparently.

PiperOrigin-RevId: 355643729
Change-Id: I7859a544f81e3e60050c4f86baea5733d5aa91ec
This commit is contained in:
Benjamin Kramer 2021-02-04 09:29:03 -08:00 committed by TensorFlower Gardener
parent 39772f66fd
commit 5eb029f68e
4 changed files with 33 additions and 16 deletions

View File

@ -233,6 +233,18 @@ func @rsqrt_grad_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @sqrt_grad_unranked
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xcomplex<f32>>, %[[ARG1:.*]]: tensor<*xcomplex<f32>>)
func @sqrt_grad_unranked(%arg0: tensor<*xcomplex<f32>>, %arg1: tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>> {
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<(5.000000e-01,0.000000e+00)> : tensor<complex<f32>>} : () -> tensor<complex<f32>>
// CHECK: %[[MUL:.*]] = "tf.Mul"(%arg1, %[[CST]]) : (tensor<*xcomplex<f32>>, tensor<complex<f32>>) -> tensor<*xcomplex<f32>>
// CHECK: %[[RET:.*]] = "tf.Div"(%[[MUL]], %arg0) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
%0 = "tf.SqrtGrad"(%arg0, %arg1) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
// CHECK: return %[[RET]]
return %0 : tensor<*xcomplex<f32>>
}
// %input has 1 batch dimension then 2 block dimensions then 1 remainder
// dimension.
// CHECK-LABEL: fourdim_space_to_batch_nd

View File

@ -69,8 +69,9 @@ static APFloat ConvertToAPFloat(double val, Type type) {
}
// Returns int, float, or complex DenseElementsAttr with scalar shape with the
// given element type and the integer value.
static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
// given element type and the value.
template <typename T>
static DenseElementsAttr GetScalarOfType(Type ty, T raw_value) {
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
FloatAttr attr = FloatAttr::get(float_ty, raw_value);
@ -91,14 +92,6 @@ static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
llvm_unreachable("unsupported type");
}
// Returns float DenseElementsAttr with scalar shape with the specified value.
static DenseElementsAttr GetScalarOfFloatType(Type ty, double raw_value) {
auto float_ty = ty.cast<FloatType>();
FloatAttr attr = FloatAttr::get(float_ty, raw_value);
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
return DenseElementsAttr::get(scalar_ty, attr);
}
// Returns reduction indices to use while lowering tf.BiasAddGrad op to tf.Sum
// op.
DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank,
@ -1412,9 +1405,9 @@ class LowerResizeNearestNeighbor : public RewritePattern {
out_w_f32);
Value zero_f32 = rewriter.create<ConstOp>(
loc, GetScalarOfFloatType(rewriter.getF32Type(), 0.0));
loc, GetScalarOfType(rewriter.getF32Type(), 0.0));
Value one_f32 = rewriter.create<ConstOp>(
loc, GetScalarOfFloatType(rewriter.getF32Type(), 1.0));
loc, GetScalarOfType(rewriter.getF32Type(), 1.0));
Value y_range = rewriter.create<RangeOp>(
loc,
@ -1566,6 +1559,7 @@ void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
LowerSizeOp,
LowerSoftmaxCrossEntropyWithLogitsOp,
LowerSparseSoftmaxCrossEntropyWithLogitsOp,
LowerSqrtGradOp,
LowerSquareOp,
LowerSquaredDifferenceOpOnRealTensors,
LowerSquaredDifferenceOpOneComplexTensors,

View File

@ -22,14 +22,14 @@ class GetScalarOfType<int value> : NativeCodeCall<
"GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
class GetScalarOfFloatType<string value> : NativeCodeCall<
"GetScalarOfFloatType(getElementTypeOrSelf($0)," # value # ")">;
"GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
def GetScalarInfOfType : NativeCodeCall<
"GetScalarOfFloatType(getElementTypeOrSelf($0), "
"GetScalarOfType(getElementTypeOrSelf($0), "
"std::numeric_limits<double>::infinity())">;
def GetScalarNanOfType : NativeCodeCall<
"GetScalarOfFloatType(getElementTypeOrSelf($0), "
"GetScalarOfType(getElementTypeOrSelf($0), "
"std::numeric_limits<double>::quiet_NaN())">;
class GetI64ScalarElementsAttr<int value> :
@ -320,6 +320,18 @@ def LowerSizeOp : Pat<
ConstBoolAttrFalse
)>;
//===----------------------------------------------------------------------===//
// Sqrt op patterns.
//===----------------------------------------------------------------------===//
// SqrtGrad(y, dy) = dy * 0.5 / y
def LowerSqrtGradOp : Pat<
(TF_SqrtGradOp $y, $dy),
(TF_DivOp
(TF_MulOp $dy, (TF_ConstOp (GetScalarOfFloatType<"0.5"> $dy))),
$y
)>;
//===----------------------------------------------------------------------===//
// TanhGrad op patterns.
//===----------------------------------------------------------------------===//

View File

@ -232,7 +232,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::SpaceToBatchOp>(),
TypeID::get<TF::SpaceToDepthOp>(),
TypeID::get<TF::SparseToDenseOp>(),
TypeID::get<TF::SqrtGradOp>(),
TypeID::get<TF::SquareOp>(),
TypeID::get<TF::StatelessMultinomialOp>(),
TypeID::get<TF::StatelessRandomGetAlgOp>(),