Integrate LLVM at llvm/llvm-project@202766947e
Updates LLVM usage to match [202766947edb](https://github.com/llvm/llvm-project/commit/202766947edb)

PiperOrigin-RevId: 329673065
Change-Id: I349e89c8322e8cec75d9ddc5aa2c1b7093ffeac2
Parent: d7d06e423b
Commit: da2c5d2647
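Nearly every hunk below applies the same adaptation: with this LLVM revision, MLIR's ODS-generated accessors for integer attributes return plain integer types rather than `llvm::APInt`, so the explicit `.getSExtValue()` / `.getZExtValue()` / `.getLimitedValue()` conversions are dropped, helpers that took `llvm::APInt` parameters now take plain integers, and `APInt` member calls such as `.isMask()` and `.isNullValue()` give way to free functions like `llvm::isMask_64` or ordinary integer tests. A minimal sketch of the before/after pattern; `AxisOpBefore`, `AxisOpAfter`, and the `axis` accessor are hypothetical stand-ins for the generated op classes, not code from this commit:

```cpp
#include <cstdint>

#include "llvm/ADT/APInt.h"

// Before the bump: the generated accessor wrapped the attribute value in an
// llvm::APInt, so callers had to convert explicitly.
struct AxisOpBefore {
  llvm::APInt axis() const { return llvm::APInt(/*numBits=*/64, /*val=*/3); }
};

// After the bump: the generated accessor returns the underlying integer.
struct AxisOpAfter {
  int64_t axis() const { return 3; }
};

int64_t GetAxis(const AxisOpBefore &op) {
  return op.axis().getSExtValue();  // explicit sign-extending conversion
}

int64_t GetAxis(const AxisOpAfter &op) {
  return op.axis();  // the conversion call is redundant and removed below
}
```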
@@ -172,7 +172,7 @@ static LogicalResult Verify(DotGeneralOp op) {
 /// Fold get_dimension_size when the said shape dimension is a constant.
 OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
   RankedTensorType type = operand().getType().cast<RankedTensorType>();
-  int32_t dim = dimension().getSExtValue();
+  int32_t dim = dimension();
   if (type.isDynamic(dim)) return {};
   // The result type is always is a 0-d i32 tensor.
   return DenseIntElementsAttr::get<int32_t>(
@@ -190,7 +190,7 @@ static LogicalResult Verify(IotaOp op) {
   if (shape.getRank() == 0)
     return op.emitOpError() << "does not support scalars.";

-  auto iota_dimension = op.iota_dimension().getSExtValue();
+  auto iota_dimension = op.iota_dimension();
   if (iota_dimension >= shape.getRank() || iota_dimension < 0)
     return op.emitOpError() << "iota dimension cannot go beyond the output "
                                "rank or be negative.";
@@ -212,8 +212,7 @@ struct IotaBroadcast : public OpRewritePattern<IotaOp> {
     auto iota_dimension = iota.iota_dimension();

     auto iota_type = RankedTensorType::get(
-        {result_ty.getDimSize(iota_dimension.getLimitedValue())},
-        result_ty.getElementType());
+        {result_ty.getDimSize(iota_dimension)}, result_ty.getElementType());

     auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
                                             rewriter.getI64IntegerAttr(0));
@@ -233,7 +232,7 @@ void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
 }

 OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
-  auto dimension = iota_dimension().getLimitedValue();
+  auto dimension = iota_dimension();
   auto result_ty = getResult().getType().cast<ShapedType>();
   if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
     Builder builder(getContext());
@@ -277,7 +276,7 @@ struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
     }

     auto iota_dimension = iota.iota_dimension();
-    auto iota_dimension_int = iota_dimension.getLimitedValue();
+    auto iota_dimension_int = iota_dimension;

     auto converted_shape = rewriter.create<IndexCastOp>(
         iota.getLoc(),
@@ -476,7 +475,7 @@ static LogicalResult Verify(DequantizeOp op) {
 //===----------------------------------------------------------------------===//

 static LogicalResult Verify(GetTupleElementOp op) {
-  auto indexVal = op.index().getZExtValue();
+  auto indexVal = op.index();
   auto operandType = op.getOperand().getType().cast<TupleType>();
   if (indexVal >= operandType.size()) {
     return op.emitOpError(
@@ -495,7 +494,7 @@ static LogicalResult Verify(GetTupleElementOp op) {
 OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
   if (auto tupleOp =
           dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
-    return tupleOp.getOperand(index().getLimitedValue());
+    return tupleOp.getOperand(index());
   }

   return {};
@@ -565,8 +564,8 @@ static LogicalResult Verify(AllToAllOp op) {
   // count.
   auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
   if (!type) return success();
-  auto split_dim_size = type.getDimSize(op.split_dimension().getSExtValue());
-  auto split_count = op.split_count().getSExtValue();
+  auto split_dim_size = type.getDimSize(op.split_dimension());
+  auto split_count = op.split_count();
   if (split_dim_size % split_count != 0) {
     return op.emitError() << "split dimension has size " << split_dim_size
                           << ", expected to be a multiple of split_count "
@@ -914,7 +913,7 @@ class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConcatenateOp op,
                                 PatternRewriter& rewriter) const override {
-    auto axis = op.dimension().getLimitedValue();
+    auto axis = op.dimension();
     llvm::SmallVector<Value, 6> new_operands;
     for (auto operand : op.getOperands()) {
       auto ty = operand.getType().cast<ShapedType>();
@@ -994,7 +993,7 @@ void ConcatenateOp::getCanonicalizationPatterns(
 template <typename T>
 static Attribute foldConcatenateHelper(ConcatenateOp* op,
                                        ArrayRef<Attribute> operands) {
-  auto axis = op->dimension().getLimitedValue();
+  auto axis = op->dimension();
   auto type = op->getType().cast<ShapedType>();

   SmallVector<T, 6> values;
@@ -1042,7 +1041,7 @@ OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
   ShapedType type = getResult().getType().cast<ShapedType>();
   if (!type.hasStaticShape()) return {};

-  auto axis = dimension().getLimitedValue();
+  auto axis = dimension();
   if (auto attr = foldConcatenate(this, operands)) {
     return attr;
   }
@@ -1845,7 +1844,7 @@ struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
       return failure();
     }

-    auto dimension = concat.dimension().getSExtValue();
+    auto dimension = concat.dimension();

     auto start = slice.start_indices().getIntValues();
     auto limit = slice.limit_indices().getIntValues();
@@ -1995,7 +1994,7 @@ static LogicalResult Verify(SortOp op) {
     return op.emitOpError("requires all inputs to have the same dimensions");

   int64_t rank = input_shape.size();
-  int64_t cmp_dim = op.dimension().getSExtValue();
+  int64_t cmp_dim = op.dimension();
   if (cmp_dim < -rank || cmp_dim >= rank)
     return op.emitOpError("dimension attribute value must be in range [-")
            << rank << ", " << rank << "), but found " << cmp_dim;
@@ -704,7 +704,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
         [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
             ValueRange args) {
           Value castOp = nestedBuilder.create<IndexCastOp>(
-              nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
+              nestedLoc, ivs[iotaOp.iota_dimension()],
              nestedBuilder.getIntegerType(
                  resultElementType.getIntOrFloatBitWidth()));
           if (resultElementType.template isa<FloatType>()) {
@@ -117,7 +117,7 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
                                 PatternRewriter &rewriter) const override {
     auto output_type = op.getType().cast<ShapedType>();
     auto output_size = output_type.getNumElements();
-    auto dimension = op.iota_dimension().getSExtValue();
+    auto dimension = op.iota_dimension();
     auto max_dim_size = output_type.getDimSize(dimension);

     auto element_type = output_type.getElementType();
@@ -80,7 +80,7 @@ void MatchAndRewrite(WhileOp whileOp) {
   // the external value is captured.
   if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
     if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
-    int index = gte.index().getSExtValue();
+    int index = gte.index();
     return {tupleOp.getOperand(index), index};
   }
   return {nullptr, 0};
@@ -154,7 +154,7 @@ void MatchAndRewrite(WhileOp whileOp) {
       use->erase();
       continue;
     }
-    int index = gte.index().getSExtValue();
+    int index = gte.index();
     // If after the loop induction variable, then decrement as we don't include
     // the loop induction variable in the for iter operands.
     if (index > loopIndVar.second) --index;
@@ -122,7 +122,7 @@ class UnfuseBatchNormInferencePattern
     if (!fp_type) {
       return failure();
     }
-    int64_t feature_dim = bn_op.feature_index().getSExtValue();
+    int64_t feature_dim = bn_op.feature_index();

     // Add epsilon to the variance and sqrt to get stddev:
     // stddev = sqrt(variance + epsilon)
@@ -127,12 +127,12 @@ static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(

 // I32Attr already returns an int as required by flatbuffer builders.
 static int ConvertI32AttrForOptionWriter(
-    llvm::APInt i, flatbuffers::FlatBufferBuilder* builder) {
-  return i.getSExtValue();
+    int i, flatbuffers::FlatBufferBuilder* builder) {
+  return i;
 }

 static int ConvertPositiveI32AttrForOptionWriter(
-    llvm::APInt i, flatbuffers::FlatBufferBuilder* builder) {
+    int i, flatbuffers::FlatBufferBuilder* builder) {
   return ConvertI32AttrForOptionWriter(i, builder);
 }

@@ -569,7 +569,7 @@ namespace {

 int64_t GetConcatenationOpAxis(ConcatenationOp op) {
   auto output_type = op.output().getType().cast<RankedTensorType>();
-  int64_t axis = op.axis().getSExtValue();
+  int32_t axis = op.axis();
   if (axis < 0) axis += output_type.getRank();
   return axis;
 }
@@ -1027,13 +1027,13 @@ static LogicalResult Verify(PackOp op) {

   // Check axis bounds.
   if (input_type.hasRank()) {
-    int64_t axis_value = op.axis().getSExtValue();
+    int32_t axis_value = op.axis();
     if (axis_value < 0) axis_value += input_type.getRank() + 1;
     if (axis_value < 0 || axis_value >= input_type.getRank() + 1)
       return op.emitOpError()
              << "op attribute 'axis' should be in range [-rank - 1, rank + 1), "
              << "got rank = " << input_type.getRank()
-             << ", and axis = " << op.axis().getSExtValue();
+             << ", and axis = " << op.axis();
   }

   // Make sure all inputs have the same shape and element type.
@@ -1545,7 +1545,7 @@ static LogicalResult VerifySplitOpOutputTypes(
 }

 static LogicalResult Verify(SplitOp op) {
-  int64_t num_splits = op.num_splits().getSExtValue();
+  int64_t num_splits = op.num_splits();
   if (op.getNumResults() != num_splits)
     return op.emitOpError("output count should match 'num_splits' attribute");

@@ -1581,7 +1581,7 @@ static LogicalResult Verify(SplitOp op) {
 }

 static LogicalResult Verify(SplitVOp op) {
-  int64_t num_splits = op.num_splits().getSExtValue();
+  int64_t num_splits = op.num_splits();
   if (op.getNumResults() != num_splits)
     return op.emitOpError("output count should match 'num_splits' attribute");

@@ -106,9 +106,9 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
       mins.push_back(FloatAttr::getValueAsDouble(*it++));
       maxs.push_back(FloatAttr::getValueAsDouble(*it));
     }
-    quant_type = quant::fakeQuantAttrsToType(
-        op.getLoc(), num_bits, op.axis()->getSExtValue(), mins, maxs,
-        narrow_range, expressed, is_signed);
+    quant_type =
+        quant::fakeQuantAttrsToType(op.getLoc(), num_bits, *op.axis(), mins,
+                                    maxs, narrow_range, expressed, is_signed);
   } else if (auto stats = op.layerStats().dyn_cast<DenseFPElementsAttr>()) {
     double rmin = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
     double rmax = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
@@ -107,8 +107,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp
     // Use the min/max from the operands and the num_bits and narrow_range
     // attribute to create the quantization parameter for the new quantize op.
     rewriter.setInsertionPointAfter(tf_op);
-    IntegerAttr num_bits =
-        rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
+    IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits());
     BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
     Type res_type = tf_op.getType();
     TypeAttr qtype = quant::GetQuantizedTypeAttr(
@@ -158,9 +158,8 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
                                  tensorflow::random::PhiloxRandom, float>
       Distribution;

-  tensorflow::random::PhiloxRandom generator(
-      random_uniform_op.seed().getSExtValue(),
-      random_uniform_op.seed2().getSExtValue());
+  tensorflow::random::PhiloxRandom generator(random_uniform_op.seed(),
+                                             random_uniform_op.seed2());
   Distribution dist;
   size_t num_elements = 0;
   if (auto output_type =
@@ -284,7 +283,7 @@ LogicalResult ConvertTFPackOp::matchAndRewrite(
   auto output_type = tf_pack_op.output().getType();
   auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
   // Axis can be negative.
-  auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue());
+  auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis());

   rewriter.replaceOpWithNewOp<PackOp>(op, output_type, values, values_count,
                                       axis);
@@ -381,27 +380,22 @@ LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
         op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
         tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
         tf_strided_slice_op.strides(),
-        rewriter.getI32IntegerAttr(
-            tf_strided_slice_op.begin_mask().getSExtValue()),
-        rewriter.getI32IntegerAttr(
-            tf_strided_slice_op.end_mask().getSExtValue()),
-        rewriter.getI32IntegerAttr(
-            tf_strided_slice_op.ellipsis_mask().getSExtValue()),
-        rewriter.getI32IntegerAttr(
-            tf_strided_slice_op.new_axis_mask().getSExtValue()),
-        rewriter.getI32IntegerAttr(
-            tf_strided_slice_op.shrink_axis_mask().getSExtValue()));
+        rewriter.getI32IntegerAttr(tf_strided_slice_op.begin_mask()),
+        rewriter.getI32IntegerAttr(tf_strided_slice_op.end_mask()),
+        rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
+        rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
+        rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
     return success();
   }

   int num_input_dims = ranked_input_type.getRank();
   // Pad `begin` array with zero values and update the `begin_mask`.
   SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0);
-  int begin_mask = tf_strided_slice_op.begin_mask().getSExtValue();
+  int begin_mask = tf_strided_slice_op.begin_mask();
   Value padded_begin = PadStridedSliceAttributeArray(
       op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask);
   // Pad `end` array with `input_shape` and update the `end_mask`.
-  int end_mask = tf_strided_slice_op.end_mask().getSExtValue();
+  int end_mask = tf_strided_slice_op.end_mask();
   auto input_shape = ranked_input_type.getShape();
   SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end());
   Value padded_end = PadStridedSliceAttributeArray(
@@ -415,12 +409,9 @@ LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
       padded_begin, padded_end, padded_strides,
       rewriter.getI32IntegerAttr(begin_mask),
       rewriter.getI32IntegerAttr(end_mask),
-      rewriter.getI32IntegerAttr(
-          tf_strided_slice_op.ellipsis_mask().getSExtValue()),
-      rewriter.getI32IntegerAttr(
-          tf_strided_slice_op.new_axis_mask().getSExtValue()),
-      rewriter.getI32IntegerAttr(
-          tf_strided_slice_op.shrink_axis_mask().getSExtValue()));
+      rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
+      rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
+      rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
   return success();
 }

@@ -431,7 +422,7 @@ LogicalResult ConvertTFUnpackOp::matchAndRewrite(
   auto input = tf_unpack_op.value();
   auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
   // Axis can be negative.
-  auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue());
+  auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis());

   rewriter.replaceOpWithNewOp<UnpackOp>(op, tf_unpack_op.output().getTypes(),
                                         input, num, axis);
@@ -714,7 +714,7 @@ struct ConvertTensorListStack
     RankedTensorType shape_type =
         RankedTensorType::get({-1}, rewriter.getIntegerType(32));
     auto new_shape = rewriter.create<TF::ShapeOp>(loc, shape_type, input);
-    SmallVector<int64_t, 8> output_shape = {op.num_elements().getSExtValue()};
+    SmallVector<int64_t, 8> output_shape(/*Size=*/1, op.num_elements());
     for (const auto &dim : dense_elem_attr.getIntValues())
       output_shape.push_back(dim.getSExtValue());
     RankedTensorType result_type =
@@ -216,8 +216,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
     // Use the min/max from the operands and the num_bits and narrow_range
     // attribute to create the quantization parameter for the new quantize op.
     rewriter.setInsertionPointAfter(tf_op);
-    IntegerAttr num_bits =
-        rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
+    IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits());
     BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
     Type res_type = tf_op.getType();
     TypeAttr qtype = quant::GetQuantizedTypeAttr(
@@ -538,8 +537,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
         loc, new_output_type, original_input, shape);

     // Replace the original strided_slice.
-    llvm::APInt new_begin_mask = strided_slice_op.begin_mask();
-    llvm::APInt new_end_mask = strided_slice_op.end_mask();
+    uint64_t new_begin_mask = strided_slice_op.begin_mask();
+    uint64_t new_end_mask = strided_slice_op.end_mask();
     // Since we expand the dims, we need to apply them to the begin_mask &
     // end_mask.
     new_begin_mask |= strided_slice_op.new_axis_mask();
@@ -602,8 +601,8 @@ struct ConvertTFStridedSlice : public RewritePattern {

     const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;

-    int64_t begin_mask = strided_slice_op.begin_mask().getSExtValue();
-    int64_t end_mask = strided_slice_op.end_mask().getSExtValue();
+    int64_t begin_mask = strided_slice_op.begin_mask();
+    int64_t end_mask = strided_slice_op.end_mask();
     int64_t new_begin_mask = 0;
     int64_t new_end_mask = 0;

@@ -684,16 +683,16 @@ struct ConvertTFStridedSlice : public RewritePattern {

     // TODO(renjieliu): Consider expand the transformation for shrink mask as
     // well.
-    if (strided_slice_op.shrink_axis_mask().getZExtValue()) return failure();
+    if (strided_slice_op.shrink_axis_mask()) return failure();

     // Handle new axis mask.
-    uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
+    uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
     if (new_axis_mask != 0) {
       return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter);
     }

     // Handle ellipsis mask.
-    uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask().getZExtValue();
+    uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
     if (ellipsis_mask != 0) {
       return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter);
     }
@@ -369,7 +369,7 @@ void Print(ReplicateOp op, OpAsmPrinter* p) {
   //   [%a, ...] as %block_arg0: type
   //   packed_input
   //   %b as %block_arg1: type
-  const int32_t n = op.n().getSExtValue();
+  const int32_t n = op.n();
   const int32_t num_replicated_inputs =
       (*op.operand_segment_sizes().int_value_begin()).getSExtValue();
   const int32_t num_replicated_block_args = num_replicated_inputs / n;
@@ -413,7 +413,7 @@ LogicalResult VerifyCompatibleTypes(Type a, Type b) {
 }

 LogicalResult Verify(ReplicateOp op) {
-  int32_t n = op.n().getSExtValue();
+  int32_t n = op.n();

   // Check number of devices, if set, matches `n`.
   if (op.devices().hasValue()) {
@@ -190,7 +190,7 @@ void BatchMatMulV2Op::getCanonicalizationPatterns(

 static LogicalResult Verify(BatchToSpaceOp op) {
   // Op already has a constraint that block_size >= 2.
-  int64_t block_size = op.block_size().getSExtValue();
+  int64_t block_size = op.block_size();

   llvm::SmallVector<int64_t, 4> input_shape(4, ShapedType::kDynamicSize);
   auto input_type = op.input().getType().cast<TensorType>();
@@ -1639,7 +1639,7 @@ static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) {
     return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) +
                           "," + Twine(std::to_string(rmax)) + "]");
   }
-  int64_t num_bits = op.num_bits().getSExtValue();
+  int64_t num_bits = op.num_bits();
   if (num_bits < 2 || num_bits > 16) {
     return op.emitOpError(
         "requires num_bits to be between 2 and 16, inclusive");
@@ -1659,7 +1659,7 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) {
   if (max && !IsOfRankedFloatTensorType(max, 0))
     return op.emitOpError("requires max to be a 0d float tensor");

-  int64_t num_bits = op.num_bits().getSExtValue();
+  int64_t num_bits = op.num_bits();
   if (num_bits < 2 || num_bits > 16) {
     return op.emitOpError(
         "requires num_bits to be between 2 and 16, inclusive");
@@ -1683,7 +1683,7 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) {
   if (!HasRankAtLeast(inputs, 1))
     return op.emitError("requires inputs to be at least 1d float tensor");

-  int64_t num_bits = op.num_bits().getSExtValue();
+  int64_t num_bits = op.num_bits();
   if (num_bits < 2 || num_bits > 16) {
     return op.emitOpError(
         "requires num_bits to be between 2 and 16, inclusive");
@@ -1886,7 +1886,7 @@ StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) {
 //===----------------------------------------------------------------------===//

 static LogicalResult Verify(GatherV2Op op) {
-  int64_t batch_dims = op.batch_dims().getSExtValue();
+  int64_t batch_dims = op.batch_dims();
   if (auto ty = op.indices().getType().dyn_cast<RankedTensorType>()) {
     int64_t rank = ty.getRank();
     if (batch_dims > rank || batch_dims < -rank)
@@ -109,7 +109,7 @@ void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x,
 //===----------------------------------------------------------------------===//

 static LogicalResult Verify(OneHotOp op) {
-  int64_t axis = op.axis().getSExtValue();
+  int64_t axis = op.axis();

   auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
   if (indices_ty &&
@@ -207,7 +207,7 @@ static LogicalResult Verify(PackOp op) {
   // the axis value range is [-(R+1), R+1).
   int64_t range_begin = -inputs_rank - 1;  // Inclusive
   int64_t range_end = inputs_rank + 1;     // Exclusive
-  int64_t axis = op.axis().getSExtValue();
+  int64_t axis = op.axis();
   if (axis < range_begin || axis >= range_end) {
     return op.emitError() << "attribute 'axis' should be within range ["
                           << range_begin << ", " << range_end
@@ -232,7 +232,7 @@ OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
   if (values().size() < 2) return {};

   // Dimensions packed along axis = 0 (pack scalars into vector).
-  if (axis().getSExtValue() != 0) return {};
+  if (axis() != 0) return {};

   // First packed value is defined by a strided slice operation.
   auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
@@ -247,11 +247,9 @@ OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {

   // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
   // scalar value from input vector).
-  if (slice_op.begin_mask().getSExtValue() != 0 ||
-      slice_op.ellipsis_mask().getSExtValue() != 0 ||
-      slice_op.end_mask().getSExtValue() != 0 ||
-      slice_op.new_axis_mask().getSExtValue() != 0 ||
-      slice_op.shrink_axis_mask().getSExtValue() != 1)
+  if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 ||
+      slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 ||
+      slice_op.shrink_axis_mask() != 1)
     return {};

   // Returns a value if the `value` is defined by a ConstOp with a single
@@ -1396,7 +1394,7 @@ static LogicalResult VerifyStridedSliceBase(OpTy op) {

   // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there
   // exists only no more than one ellipsis.
-  uint32_t ellipsis_mask = op.ellipsis_mask().getZExtValue();
+  uint32_t ellipsis_mask = op.ellipsis_mask();
   if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask))
     return op.emitOpError("cannot have multiple ellipses");

@@ -1652,10 +1650,9 @@ bool StridedSliceOp::GetSlicedBoundRanges(
     sparse_strides.push_back(stride.getSExtValue());

   CalculateSlicedShapeFromSparseIndices(
-      input_shape, sparse_begin, sparse_end, sparse_strides,
-      begin_mask().getZExtValue(), end_mask().getZExtValue(),
-      ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(),
-      shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride);
+      input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
+      end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
+      slice_begin, slice_end, slice_stride);
   return true;
 }

@@ -1706,10 +1703,9 @@ bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
     sparse_strides.push_back(stride.getSExtValue());

   CalculateSlicedShapeFromSparseIndices(
-      *input_shape, sparse_begin, sparse_end, sparse_strides,
-      begin_mask().getZExtValue(), end_mask().getZExtValue(),
-      ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(),
-      shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride);
+      *input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
+      end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
+      slice_begin, slice_end, slice_stride);
   return true;
 }

@@ -2090,7 +2086,7 @@ static LogicalResult Verify(UnpackOp op) {
   if (!value_type) return success();

   int64_t value_rank = value_type.getRank();
-  int64_t axis = op.axis().getSExtValue();
+  int64_t axis = op.axis();
   if (axis < -value_rank || axis >= value_rank)
     return op.emitOpError("axis attribute must be in the range of [-")
            << value_rank << ", " << value_rank << ')';
@@ -88,7 +88,7 @@ class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
     const int input_channels =
         conv_op.lhs().getType().cast<ShapedType>().getDimSize(
             input_feature_dimension);
-    int feature_group_count = conv_op.feature_group_count().getSExtValue();
+    int feature_group_count = conv_op.feature_group_count();

     const bool is_depthwise_conv = input_channels == feature_group_count;
     std::string padding;
@@ -303,7 +303,7 @@ class LowerDynamicStitchOp : public OpRewritePattern<TF::DynamicStitchOp> {
         reshaped_data.getType().cast<RankedTensorType>().getShape()[0];
     auto items = rewriter.create<UnpackOp>(
         loc, SmallVector<Type, 4>(num_items, item_ty), reshaped_data,
-        /*axis=*/APInt(64, 0));
+        /*axis=*/0);
     for (auto index_item : llvm::zip(index_attr, items.getResults())) {
       int64_t output_index = std::get<0>(index_item).getSExtValue();
       Value item = std::get<1>(index_item);
@@ -399,7 +399,7 @@ class LowerPackOp : public OpRewritePattern<TF::PackOp> {
         loc,
         DenseElementsAttr::get(
            RankedTensorType::get({}, rewriter.getIntegerType(64)), op.axis()));
-    int64_t axis = op.axis().getSExtValue();
+    int64_t axis = op.axis();

     Type prev_input_ty, inferred_ty;
     SmallVector<Value, 4> expanded_inputs;
@@ -151,7 +151,7 @@ bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) {
 // invariant. Shape ops are rewritten to be invariant when possible, prior to
 // hoisting ops.
 void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) {
-  const int num_replicas = replicate_op.n().getLimitedValue();
+  const int num_replicas = replicate_op.n();
   Block* replicate_block = &replicate_op.GetBody();

   replicate_op.walk([&](TF::ShapeOp shape_op) {
@@ -376,7 +376,7 @@ LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
                                          tf_executor::IslandOp island_op,
                                          tf_device::ReplicateOp replicate_op) {
   OpBuilder builder(island_op);
-  const int num_replicas = replicate_op.n().getLimitedValue();
+  const int num_replicas = replicate_op.n();

   // Create islands per replica.
   llvm::SmallVector<tf_executor::IslandOp, 8> replicas;
@@ -322,8 +322,7 @@ LogicalResult SortTPUReplicatedInputsByIndex(
     llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
   llvm::SmallDenseSet<int64_t, 8> unique_indices;
   for (Operation* input : inputs) {
-    int64_t index =
-        llvm::cast<TF::TPUReplicatedInputOp>(input).index().getSExtValue();
+    int64_t index = llvm::cast<TF::TPUReplicatedInputOp>(input).index();
     if (index < -1)
       return input->emitOpError()
             << "requires index to be at least -1, but got " << index;
@@ -342,10 +341,8 @@ LogicalResult SortTPUReplicatedInputsByIndex(
   std::stable_sort(
       sorted_inputs->begin(), sorted_inputs->end(),
       [](Operation* l, Operation* r) {
-        int64_t l_index =
-            llvm::cast<TF::TPUReplicatedInputOp>(l).index().getSExtValue();
-        int64_t r_index =
-            llvm::cast<TF::TPUReplicatedInputOp>(r).index().getSExtValue();
+        int64_t l_index = llvm::cast<TF::TPUReplicatedInputOp>(l).index();
+        int64_t r_index = llvm::cast<TF::TPUReplicatedInputOp>(r).index();
        if (l_index == -1 && r_index != -1) return false;
        if (r_index == -1 && l_index != -1) return true;
        return l_index < r_index;
@@ -401,8 +398,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
      return input->emitOpError() << "requires " << num_inputs << " operands";

    auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
-    int64_t tpu_replicated_input_index =
-        tpu_replicated_input.index().getSExtValue();
+    int64_t tpu_replicated_input_index = tpu_replicated_input.index();
    if (is_packed) {
      packed_inputs.push_back(input->getOperand(0));
      packed_input_indices.push_back(tpu_replicated_input_index);
@@ -185,7 +185,7 @@ bool HandleReplicatedInputs(
     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
   // We need to know the devices to copy to.
   if (!replicate.devices()) return false;
-  int64_t num_replicas = replicate.n().getZExtValue();
+  int64_t num_replicas = replicate.n();
   auto inputs = replicate.getOperands()
                     .drop_front(replicate_arg_index * num_replicas)
                     .take_front(num_replicas);
@@ -210,8 +210,9 @@ Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info,

 // Creates a WhileRegionOp cond and body regions with yield op and
 // an empty body.
-TF::WhileRegionOp CloneEmptyWhile(bool is_stateless, APInt parallel_iterations,
-                                  Location loc, OpBuilder* builder) {
+TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
+                                  uint64_t parallel_iterations, Location loc,
+                                  OpBuilder* builder) {
   auto host_side_while = builder->create<TF::WhileRegionOp>(
       loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
       is_stateless, parallel_iterations);
@@ -650,7 +650,7 @@ LogicalResult Rewrite(
   int num_replicas = 1;
   tf_device::ReplicateOp replicate =
       cluster_func.getParentOfType<tf_device::ReplicateOp>();
-  if (replicate) num_replicas = replicate.n().getLimitedValue();
+  if (replicate) num_replicas = replicate.n();

   auto num_cores_per_replica_attr = cluster_func.getAttrOfType<IntegerAttr>(
       tensorflow::kNumCoresPerReplicaAttr);
@@ -432,9 +432,8 @@ TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func,
                                   input_shape[3] * block_size * block_size};
   auto transform_result_type =
       RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
-  return builder.create<TF::SpaceToDepthOp>(cluster_func.getLoc(),
-                                            transform_result_type, input,
-                                            APInt(64, block_size));
+  return builder.create<TF::SpaceToDepthOp>(
+      cluster_func.getLoc(), transform_result_type, input, block_size);
 }

 // Performs transformation for a non-replicated input.
@@ -458,7 +457,7 @@ bool HandleHostReplicatedInputs(int64_t index,
   int64_t replicate_arg_index = block_arg.getArgNumber();
   // We need to know the devices to copy to.
   if (!replicate.devices()) return false;
-  int64_t num_replicas = replicate.n().getZExtValue();
+  int64_t num_replicas = replicate.n();
   // Gets inputs at replicate_arg_index for each replica.
   auto inputs = replicate.getOperands()
                     .drop_front(replicate_arg_index * num_replicas)
@@ -669,7 +668,6 @@ void TPUSpaceToDepthPass::runOnOperation() {
   if (!device_func) return;

   TF::Conv2DOp first_conv;
-  Optional<ArrayRef<int64_t>> input_shape;
   // A map maps block argument id to the convolutions consumes them.
   llvm::SmallDenseMap<unsigned, std::vector<Conv2DWithBlockSize>>
       argnum_and_convolutions;
@@ -174,7 +174,7 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
   assert(metadata_str && "Missing compilation metadata");
   tensorflow::tpu::TPUCompileMetadataProto metadata;
   metadata.ParseFromString(std::string(metadata_str.getValue()));
-  int64_t num_replicas = replicate.n().getLimitedValue();
+  int64_t num_replicas = replicate.n();
   // Find the formattable operands of `execute`, which must be mirrored
   // variables (arguments of `replicate`), and must be pass-throughs from while
   // operands.
@@ -264,7 +264,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
     tf_device::ReplicateOp replicate, ArrayRef<Value> new_inputs,
     const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
         devices) {
-  int64_t num_replicas = replicate.n().getLimitedValue();
+  int64_t num_replicas = replicate.n();
   assert(new_inputs.size() == num_replicas);

   // As model parallelism is not yet supported, we assume that all ops are
@@ -423,7 +423,7 @@ void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
 // Performs the transformation for a replicate op inside a while loop.
 void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
                        MLIRContext* context) {
-  int64_t num_replicas = replicate.n().getLimitedValue();
+  int64_t num_replicas = replicate.n();
   if (num_replicas == 1) return;
   tf_device::LaunchOp execute_launch;
   for (auto execute_launch_op :
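The next hunks touch the MLIR-to-XLA export path. Alongside the existing `ConvertAPInt` helper, identity `Convertuint32_t` / `Convertuint64_t` helpers are added; my reading of the diff (not stated in the commit message) is that the auto-generated exporter routes each attribute through a `Convert*` hook, so the new plain-integer accessor types need pass-through overloads to keep that generated code compiling.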
@@ -106,6 +106,9 @@ static mlir::LogicalResult GetXlaOp(
 // TODO(hpucha): This should be consolidated into a general place.
 static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }

+static uint32_t Convertuint32_t(uint32_t i) { return i; }
+static uint64_t Convertuint64_t(uint64_t i) { return i; }
+
 // Convert APFloat to double.
 static double ConvertAPFloat(llvm::APFloat value) {
   const auto& semantics = value.getSemantics();
@@ -783,7 +786,7 @@ LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
 LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
   auto& value_map = *ctx.values;
   value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
-                            op.iota_dimension().getSExtValue());
+                            op.iota_dimension());
   return success();
 }

@@ -909,8 +912,8 @@ LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) {
   auto result = op.getResult();
   auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()];
   auto xla_result = xla::RngBitGenerator(
-      static_cast<xla::RandomAlgorithm>(op.rng_algorithm().getSExtValue()),
-      Unwrap(xla_arg_1), xla::TypeToShape(result.getType()).tuple_shapes(1));
+      static_cast<xla::RandomAlgorithm>(op.rng_algorithm()), Unwrap(xla_arg_1),
+      xla::TypeToShape(result.getType()).tuple_shapes(1));
   value_map[result] = xla_result;
   return mlir::success();
 }
@@ -1007,7 +1010,7 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {

   auto& value_map = *ctx.values;
   value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator,
-                            op.dimension().getSExtValue(), op.is_stable());
+                            op.dimension(), op.is_stable());
   return success();
 }

@@ -1829,7 +1829,7 @@ class ConvertFusedBatchNormGradBase

     auto training_op = rewriter.create<BatchNormGradOp>(
         loc, result_type, act, scale, mean, var, grad, op.epsilon(),
-        feature_dim_attr.getValue());
+        feature_dim);

     x_backprop =
         rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 0);
@@ -1949,7 +1949,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {

     auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
         op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
-        op.epsilon(), feature_dim.getValue());
+        op.epsilon(), feature_dim.getInt());
     // HLO op outputs a tuple of tensors. Extract those results.
     auto bn_train_op_result = bn_train_op.getResult();
     Value y_out = rewriter.create<mhlo::GetTupleElementOp>(
@@ -2036,7 +2036,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
           op.getLoc(),
           /*result_type=*/bn_train_input_type_tensor, bn_train_input,
           op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(),
-          feature_dim.getValue());
+          feature_dim.getInt());

       // Convert back to input type to stay aligned with expected output type
       // for TF op.
@@ -3191,7 +3191,7 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
     // axis. For instance, if there are 4 dims, we can support a
     // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no
     // other.
-    bool shrink_axis_mask_ok = op.shrink_axis_mask().isMask();
+    bool shrink_axis_mask_ok = llvm::isMask_64(op.shrink_axis_mask());
     if (!shrink_axis_mask_ok)
       return rewriter.notifyMatchFailure(
           op,
@@ -3200,27 +3200,27 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {

     // When begin/end values are dynamic, the ellipsis mask, if set, must refer
     // to the last dimension.
-    int ellipsis_mask = op.ellipsis_mask().getZExtValue();
+    int ellipsis_mask = op.ellipsis_mask();
     if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim)))
       return rewriter.notifyMatchFailure(
           op,
           "requires that ellipsis_mask, if set, refer to the last dimension of "
           "input (when begin/end values are dynamic)");

-    APInt begin_mask = op.begin_mask();
-    if (!begin_mask.isNullValue())
+    uint64_t begin_mask = op.begin_mask();
+    if (begin_mask)
       return rewriter.notifyMatchFailure(
           op,
           "requires that begin_mask is either set to 0 or not set when "
           "begin/end values are dynamic");
-    APInt end_mask = op.end_mask();
-    if (!end_mask.isNullValue())
+    uint64_t end_mask = op.end_mask();
+    if (end_mask)
       return rewriter.notifyMatchFailure(
           op,
           "requires that end_mask is either set to 0 or not set when begin/end "
           "values are dynamic");
-    APInt new_axis_mask = op.new_axis_mask();
-    if (!new_axis_mask.isNullValue())
+    uint64_t new_axis_mask = op.new_axis_mask();
+    if (new_axis_mask)
       return rewriter.notifyMatchFailure(
           op,
           "requires that new_axis_mask is either set to 0 or not set when "
@@ -4476,7 +4476,7 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
     }

     int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
-    int64_t axis = op.axis().getSExtValue();
+    int64_t axis = op.axis();
     if (axis == -1) axis = indices_shape.size();

     llvm::SmallVector<int64_t, 4> broadcast_dims(indices_shape.size());
@@ -4752,7 +4752,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
     if (!value_type) return failure();

     int64_t value_rank = value_type.getRank();
-    int64_t axis = op.axis().getSExtValue();
+    int64_t axis = op.axis();
     if (axis < 0) axis += value_rank;

     // Parameters for constructing each slice.
@@ -1418,7 +1418,7 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
   std::vector<std::unique_ptr<Thunk>> thunks;

   Shape keys_shape = operand_shapes[0];
-  int64 dimension_to_sort = sort_op.dimension().getSExtValue();
+  int64 dimension_to_sort = sort_op.dimension();
   for (int64 i = 0; i < operand_count; ++i) {
     // We assume that the layout of all involved operands and outputs is the
     // same.
@@ -722,8 +722,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     )

     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "5ffd940ac02a8e000691c45a6dc4f69d0198e675"
-    LLVM_SHA256 = "fb0e839b6ece41bcd028683ce7e4e063b159cd85cd42502141a4391e02cefe36"
+    LLVM_COMMIT = "202766947edb5407b84484185608aac077085608"
+    LLVM_SHA256 = "7b739119481c4adaf513e41c3221ac3a96e80823bf36951707f593d3827d1f4e"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),