Add shape inference function for tf.XlaSetDynamicDimensionSize and set to not allow constant folding.

PiperOrigin-RevId: 348474369
Change-Id: I9214c5b6acf82ed82acba8b72c3ae665c6975770
This commit is contained in:
Andy Ly 2020-12-21 08:52:00 -08:00 committed by TensorFlower Gardener
parent 2ccbbdb4b0
commit 0f99ddcfe3
2 changed files with 13 additions and 1 deletions

View File

@ -16604,7 +16604,7 @@ key: A unique identifier for this region used to match up host transfers.
TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>;
}
def TF_XlaSetDynamicDimensionSizeOp : TF_Op<"XlaSetDynamicDimensionSize", [NoSideEffect]> {
def TF_XlaSetDynamicDimensionSizeOp : TF_Op<"XlaSetDynamicDimensionSize", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect, TF_NoConstantFold]> {
let summary = "Make a static dimension into a xla bounded dynamic dimension.";
let description = [{

View File

@ -2969,6 +2969,18 @@ void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<XdivyWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
// XlaSetDynamicDimensionSizeOp
//===----------------------------------------------------------------------===//
LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.assign({operands.front().getType()});
return success();
}
} // namespace TF
} // namespace mlir