[MLIR][CHLO] Use CHLO lowering for is_inf op

PiperOrigin-RevId: 355189054
Change-Id: I28304ff8ed9f564a9698fb5609c19d5d19956e86
This commit is contained in:
A. Unique TensorFlower 2021-02-02 09:52:11 -08:00 committed by TensorFlower Gardener
parent e63a25656b
commit 22bf3df6dc
3 changed files with 3 additions and 3 deletions

View File

@ -54,8 +54,8 @@ namespace {
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) \
sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(DigammaOp) \
sep fn(ErfOp) sep fn(ErfcOp) sep fn(LgammaOp) sep fn(SinhOp) \
sep fn(TanOp)
sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp) \
sep fn(SinhOp) sep fn(TanOp)
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {

View File

@ -1552,7 +1552,6 @@ void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
LowerExpm1Op,
LowerFakeQuantWithMinMaxArgs,
LowerFillOp,
LowerIsInfOp,
LowerIsNanOp,
LowerL2LossOp,
LowerMulNoNanOp,

View File

@ -603,6 +603,7 @@ foreach Mapping = [
[TF_ImagOp, HLO_ImagOp],
[TF_InvertOp, HLO_NotOp],
[TF_IsFiniteOp, HLO_IsFiniteOp],
[TF_IsInfOp, HLOClient_IsInfOp],
[TF_LgammaOp, HLOClient_LgammaOp],
[TF_LogOp, HLO_LogOp],
[TF_Log1pOp, HLO_Log1pOp],