Integrate LLVM at llvm/llvm-project@1b97cdf885
Updates LLVM usage to match [1b97cdf885d6](https://github.com/llvm/llvm-project/commit/1b97cdf885d6) PiperOrigin-RevId: 348587513 Change-Id: I853d197b33c5df08c00c99ddc8cf8b2681bed708
This commit is contained in:
parent
682d130cf6
commit
94155a3934
@ -202,7 +202,7 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||
Type element_type = IntegerType::get(1, context);
|
||||
Type element_type = IntegerType::get(context, 1);
|
||||
return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
|
||||
attributes, element_type,
|
||||
inferedReturnShapes);
|
||||
|
@ -621,7 +621,7 @@ OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
|
||||
static LogicalResult Verify(TupleOp op) {
|
||||
SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
|
||||
op.operand_type_end()};
|
||||
auto expectedType = TupleType::get(operandTypes, op.getContext());
|
||||
auto expectedType = TupleType::get(op.getContext(), operandTypes);
|
||||
if (op.getType() != expectedType) {
|
||||
return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
|
||||
op.getType(), expectedType));
|
||||
@ -1967,7 +1967,7 @@ LogicalResult ReplicaIdOp::inferReturnTypes(
|
||||
MLIRContext* context, Optional<Location>, ValueRange operands,
|
||||
DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||
inferredReturnTypes.push_back(RankedTensorType::get(
|
||||
/*shape=*/{}, IntegerType::get(32, IntegerType::Unsigned, context)));
|
||||
/*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -145,7 +145,7 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
|
||||
|
||||
auto int_shape_type = RankedTensorType::get(
|
||||
output_type.getShape(),
|
||||
IntegerType::get(bitwidth, rewriter.getContext()));
|
||||
IntegerType::get(rewriter.getContext(), bitwidth));
|
||||
auto loc = op.getLoc();
|
||||
auto integer_const = rewriter.create<mlir::ConstantOp>(
|
||||
loc, DenseIntElementsAttr::get(int_shape_type, values));
|
||||
|
@ -37,10 +37,10 @@ constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
|
||||
|
||||
DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
|
||||
f32_ = FloatType::getF32(ctx_);
|
||||
i8_ = IntegerType::get(k8Bits, ctx_);
|
||||
i8_ = IntegerType::get(ctx_, k8Bits);
|
||||
i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
|
||||
i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
|
||||
i32_ = IntegerType::get(k32Bits, ctx_);
|
||||
i32_ = IntegerType::get(ctx_, k32Bits);
|
||||
i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
|
||||
i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
|
||||
any_ = AnyQuantizedType();
|
||||
|
@ -212,7 +212,7 @@ LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
|
||||
}
|
||||
|
||||
*attr_i32 = IntegerAttr::get(
|
||||
IntegerType::get(/*width=*/32, attr.getContext()), value);
|
||||
IntegerType::get(attr.getContext(), /*width=*/32), value);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -547,8 +547,8 @@ struct ConvertTensorListResize
|
||||
Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
|
||||
size_diff.getType(), size.getType()};
|
||||
Type branch_result_type[] = {result_type};
|
||||
auto func_type = FunctionType::get(branch_args_type, branch_result_type,
|
||||
rewriter.getContext());
|
||||
auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
|
||||
branch_result_type);
|
||||
|
||||
// Constructs `then_branch`, which is executed when `if_cond` evaluates to
|
||||
// true.
|
||||
@ -775,8 +775,8 @@ LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
|
||||
// Change `func`'s argument type to `unranked_argument_types`. If it
|
||||
// return types contain a `DT_VARIANT`, change it to the unranked type
|
||||
// derived from the corresponding argument.
|
||||
func.setType(FunctionType::get(updated_argument_types, updated_result_types,
|
||||
op.getContext()));
|
||||
func.setType(FunctionType::get(op.getContext(), updated_argument_types,
|
||||
updated_result_types));
|
||||
|
||||
// Change the argument type for the first block.
|
||||
llvm::for_each(func.getArguments(), [&](BlockArgument &arg) {
|
||||
|
@ -243,7 +243,7 @@ DenseElementsAttr GetShape(Value output_val) {
|
||||
return mlir::DenseElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int>(shape.size())},
|
||||
mlir::IntegerType::get(32, output_val.getContext())),
|
||||
mlir::IntegerType::get(output_val.getContext(), 32)),
|
||||
llvm::makeArrayRef(shape));
|
||||
}
|
||||
|
||||
|
@ -50,7 +50,7 @@ void UpdateFuncType(FuncOp func) {
|
||||
if (llvm::makeArrayRef(return_types) == func_type.getResults()) return;
|
||||
|
||||
auto updated_type =
|
||||
FunctionType::get(func_type.getInputs(), return_types, func.getContext());
|
||||
FunctionType::get(func.getContext(), func_type.getInputs(), return_types);
|
||||
func.setType(updated_type);
|
||||
}
|
||||
|
||||
|
@ -134,12 +134,12 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
bool passthru_extra_args) {
|
||||
FunctionType type;
|
||||
if (passthru_extra_args) {
|
||||
type = FunctionType::get(types, types, &getContext());
|
||||
type = FunctionType::get(&getContext(), types, types);
|
||||
} else {
|
||||
SmallVector<Type, 4> result_types;
|
||||
auto operands = region.front().getTerminator()->getOperandTypes();
|
||||
result_types.append(operands.begin(), operands.end());
|
||||
type = FunctionType::get(types, result_types, &getContext());
|
||||
type = FunctionType::get(&getContext(), types, result_types);
|
||||
}
|
||||
|
||||
auto outlined_func = builder.create<FuncOp>(while_op.getLoc(), name, type);
|
||||
|
@ -382,8 +382,8 @@ void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
|
||||
auto input_types = fused_func_op_.getType().getInputs();
|
||||
auto output_type = mlir::RankedTensorType::get(
|
||||
output_shape, input_.getType().cast<RankedTensorType>().getElementType());
|
||||
fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
|
||||
fused_func_op_.getContext()));
|
||||
fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(),
|
||||
input_types, output_type));
|
||||
}
|
||||
|
||||
LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
|
||||
@ -820,8 +820,8 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
|
||||
}
|
||||
|
||||
// Update function signatures.
|
||||
func_op.setType(mlir::FunctionType::get(func_op.getType().getInputs(),
|
||||
output_types, func_op.getContext()));
|
||||
func_op.setType(mlir::FunctionType::get(
|
||||
func_op.getContext(), func_op.getType().getInputs(), output_types));
|
||||
|
||||
builder->create<mlir::ReturnOp>(func_op.getLoc(), outputs);
|
||||
return success();
|
||||
|
@ -33,7 +33,7 @@ void init_types(py::module& m) {
|
||||
.def("getF64", &mlir::FloatType::getF64);
|
||||
|
||||
py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
|
||||
.def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
|
||||
.def("get", py::overload_cast<mlir::MLIRContext*, unsigned>(
|
||||
&mlir::IntegerType::get));
|
||||
|
||||
py::class_<mlir::UnrankedTensorType, mlir::Type>(m, "UnrankedTensorType")
|
||||
|
@ -668,7 +668,7 @@ Status MlirFunctionContext::Finalize(OutputList* outputs,
|
||||
|
||||
auto arg_types = body.getArgumentTypes();
|
||||
auto result_types = body.getTerminator()->getOperandTypes();
|
||||
func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
|
||||
func_.setType(FunctionType::get(func_.getContext(), arg_types, result_types));
|
||||
*f = new MlirFunction(std::move(context_), std::move(module_), func_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -1310,7 +1310,7 @@ LogicalResult ConcatOffsetOp::fold(ArrayRef<Attribute> operands,
|
||||
results.reserve(shapes.size());
|
||||
SmallVector<int32_t, 4> cumulative_sum(num_dims, 0);
|
||||
RankedTensorType offset_type =
|
||||
RankedTensorType::get({num_dims}, IntegerType::get(32, getContext()));
|
||||
RankedTensorType::get({num_dims}, IntegerType::get(getContext(), 32));
|
||||
for (DenseIntElementsAttr shape : shapes) {
|
||||
results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum));
|
||||
cumulative_sum[concat_dim] += shape.getValue<int32_t>(concat_dim);
|
||||
|
@ -898,7 +898,7 @@ static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
|
||||
dimensions.push_back(APInt(out_width, shape[i]));
|
||||
|
||||
auto result_type = RankedTensorType::get(
|
||||
{rank}, IntegerType::get(out_width, input_ty.getContext()));
|
||||
{rank}, IntegerType::get(input_ty.getContext(), out_width));
|
||||
return DenseElementsAttr::get(result_type, dimensions);
|
||||
}
|
||||
|
||||
|
@ -155,19 +155,19 @@ Type TensorFlowRefType::RemoveRef() {
|
||||
if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
|
||||
if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
|
||||
if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
|
||||
if (isa<BoolRefType>()) return mlir::IntegerType::get(1, ctx);
|
||||
if (isa<Int8RefType>()) return mlir::IntegerType::get(8, ctx);
|
||||
if (isa<Int16RefType>()) return mlir::IntegerType::get(16, ctx);
|
||||
if (isa<Int32RefType>()) return mlir::IntegerType::get(32, ctx);
|
||||
if (isa<Int64RefType>()) return mlir::IntegerType::get(64, ctx);
|
||||
if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
|
||||
if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
|
||||
if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
|
||||
if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
|
||||
if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
|
||||
if (isa<Uint8RefType>())
|
||||
return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
|
||||
return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
|
||||
if (isa<Uint16RefType>())
|
||||
return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
|
||||
return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
|
||||
if (isa<Uint32RefType>())
|
||||
return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
|
||||
return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
|
||||
if (isa<Uint64RefType>())
|
||||
return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
|
||||
return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
|
||||
if (isa<Complex64RefType>())
|
||||
return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
|
||||
if (isa<Complex128RefType>())
|
||||
|
@ -58,8 +58,8 @@ FuncOp BuildFunction(llvm::ArrayRef<Value> live_ins,
|
||||
operand_types.reserve(live_ins.size());
|
||||
for (Value v : live_ins) operand_types.emplace_back(v.getType());
|
||||
|
||||
auto func_type = FunctionType::get(operand_types, cluster_op.getResultTypes(),
|
||||
builder->getContext());
|
||||
auto func_type =
|
||||
builder->getFunctionType(operand_types, cluster_op.getResultTypes());
|
||||
|
||||
// TODO(lyandy): Define better name for outlined function. Potentially some
|
||||
// name can be added during cluster formation.
|
||||
|
@ -193,7 +193,7 @@ void CreateFunctions(ModuleOp module_op,
|
||||
std::replace(func_name.begin(), func_name.end(), '/', '_');
|
||||
|
||||
FunctionType func_type =
|
||||
FunctionType::get(input_types, result_types, context);
|
||||
FunctionType::get(context, input_types, result_types);
|
||||
Location loc = metadata.ops.front()->getLoc();
|
||||
FuncOp func_op = FuncOp::create(loc, func_name, func_type);
|
||||
// Sets the device attribute for every input and every result of the
|
||||
|
@ -53,7 +53,7 @@ Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
|
||||
values.reserve(rank);
|
||||
for (int i = 0; i < rank; ++i) values.push_back(APInt(bitwidth, r1[i]));
|
||||
auto result_type = RankedTensorType::get(
|
||||
{rank}, IntegerType::get(bitwidth, builder.getContext()));
|
||||
{rank}, IntegerType::get(builder.getContext(), bitwidth));
|
||||
return builder.create<TF::ConstOp>(
|
||||
loc, DenseElementsAttr::get(result_type, values));
|
||||
}
|
||||
|
@ -100,7 +100,7 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
||||
for (Value operand : island_op.GetYield().getOperands())
|
||||
func_result_types.push_back(operand.getType());
|
||||
FunctionType func_type =
|
||||
FunctionType::get(func_operand_types, func_result_types, ctx);
|
||||
FunctionType::get(ctx, func_operand_types, func_result_types);
|
||||
|
||||
// Create the outlined function
|
||||
SmallString<32> name = kOutlinedFuncPrefix;
|
||||
|
@ -213,9 +213,9 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) {
|
||||
}
|
||||
|
||||
// Update the function type.
|
||||
func.setType(mlir::FunctionType::get(func.getArgumentTypes(),
|
||||
func.getType().getResults(),
|
||||
module.getContext()));
|
||||
func.setType(mlir::FunctionType::get(module.getContext(),
|
||||
func.getArgumentTypes(),
|
||||
func.getType().getResults()));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -180,8 +180,8 @@ mlir::LogicalResult PromoteVarHandlesToArguments(
|
||||
}
|
||||
|
||||
if (!var_handle_shared_names->empty())
|
||||
function.setType(FunctionType::get(func_arg_types, func_type.getResults(),
|
||||
function.getContext()));
|
||||
function.setType(FunctionType::get(function.getContext(), func_arg_types,
|
||||
func_type.getResults()));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ void ExtractSingleBlockRegion(Region& region, StringRef name,
|
||||
if (extern_values_passthrough)
|
||||
for (auto input : extern_values) return_types.push_back(input.getType());
|
||||
|
||||
auto type = FunctionType::get(input_types, return_types, region.getContext());
|
||||
auto type = FunctionType::get(region.getContext(), input_types, return_types);
|
||||
|
||||
// Create new function and extract region body into the function.
|
||||
auto outlined_func = builder.create<FuncOp>(loc, name, type);
|
||||
|
@ -785,9 +785,9 @@ void RemoveUnusedResourceArgumentsAndForwardedRetvals(
|
||||
}
|
||||
}
|
||||
func_op.eraseArguments(indices_to_erase);
|
||||
func_op.setType(FunctionType::get(
|
||||
new_types, llvm::to_vector<4>(return_op->getOperandTypes()),
|
||||
func_op.getContext()));
|
||||
func_op.setType(
|
||||
FunctionType::get(func_op.getContext(), new_types,
|
||||
llvm::to_vector<4>(return_op->getOperandTypes())));
|
||||
}
|
||||
|
||||
// Lifts reads/writes of resource arguments from func_op and changes its
|
||||
@ -841,10 +841,9 @@ LogicalResult LiftArgRetResourcesForFunction(
|
||||
assign_variable_op.erase();
|
||||
}
|
||||
|
||||
func_op.setType(
|
||||
FunctionType::get(func_op.front().getArgumentTypes(),
|
||||
func_op.front().getTerminator()->getOperandTypes(),
|
||||
func_op.getContext()));
|
||||
func_op.setType(FunctionType::get(
|
||||
func_op.getContext(), func_op.front().getArgumentTypes(),
|
||||
func_op.front().getTerminator()->getOperandTypes()));
|
||||
|
||||
return success();
|
||||
}
|
||||
@ -1153,9 +1152,9 @@ LogicalResult HandlePartitionedCallOpCallee(
|
||||
auto new_return =
|
||||
builder.create<ReturnOp>(old_return->getLoc(), old_and_new_retvals);
|
||||
old_return->erase();
|
||||
callee.setType(FunctionType::get(
|
||||
callee.getType().getInputs(),
|
||||
llvm::to_vector<4>(new_return.getOperandTypes()), callee.getContext()));
|
||||
callee.setType(
|
||||
FunctionType::get(callee.getContext(), callee.getType().getInputs(),
|
||||
llvm::to_vector<4>(new_return.getOperandTypes())));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -184,9 +184,9 @@ void EliminateUnusedResultsForIfCase(Operation *op, ArrayRef<FuncOp> branches) {
|
||||
// Patch up function types (with less number of return values and potentially
|
||||
// less number of arguments)
|
||||
for (FuncOp func : cloned_branches) {
|
||||
func.setType(FunctionType::get(
|
||||
func.front().getArgumentTypes(),
|
||||
func.front().getTerminator()->getOperandTypes(), func.getContext()));
|
||||
func.setType(
|
||||
FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
|
||||
func.front().getTerminator()->getOperandTypes()));
|
||||
}
|
||||
|
||||
EliminateUnusedResults(op);
|
||||
@ -232,9 +232,9 @@ void EliminateUnusedResultsForWhile(TF::WhileOp op) {
|
||||
|
||||
// Patch up branch function types.
|
||||
for (FuncOp func : {cloned_cond, cloned_body}) {
|
||||
func.setType(FunctionType::get(
|
||||
func.front().getArgumentTypes(),
|
||||
func.front().getTerminator()->getOperandTypes(), func.getContext()));
|
||||
func.setType(
|
||||
FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
|
||||
func.front().getTerminator()->getOperandTypes()));
|
||||
}
|
||||
EliminateUnusedResults(op, &can_eliminate);
|
||||
}
|
||||
|
@ -1150,8 +1150,8 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
|
||||
}
|
||||
|
||||
FunctionType func_type = func.getType();
|
||||
func.setType(FunctionType::get(input_types, func_type.getResults(),
|
||||
func.getContext()));
|
||||
func.setType(FunctionType::get(func.getContext(), input_types,
|
||||
func_type.getResults()));
|
||||
|
||||
auto res =
|
||||
PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration);
|
||||
@ -1493,8 +1493,8 @@ void ShapeInference::InferShapeForFunctionReturnType(FuncOp func) {
|
||||
}
|
||||
|
||||
DCOMMENT("Updating function type");
|
||||
func.setType(FunctionType::get(
|
||||
func.getArgumentTypes(), return_op.getOperandTypes(), func.getContext()));
|
||||
func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
|
||||
return_op.getOperandTypes()));
|
||||
|
||||
if (changed) EnqueueCallers(func);
|
||||
}
|
||||
@ -1611,8 +1611,8 @@ LogicalResult InferShapeForFunction(FuncOp func,
|
||||
return failure();
|
||||
|
||||
context.InferShapeForFunctionReturnType(func);
|
||||
func.setType(FunctionType::get(new_arg_types, func.getType().getResults(),
|
||||
func.getContext()));
|
||||
func.setType(FunctionType::get(func.getContext(), new_arg_types,
|
||||
func.getType().getResults()));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -137,9 +137,9 @@ void ModifyFunctionSignature(
|
||||
if (handle_new_size_vars) {
|
||||
handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
|
||||
}
|
||||
func.setType(FunctionType::get(
|
||||
new_input_types, func.front().getTerminator()->getOperandTypes(),
|
||||
func.getContext()));
|
||||
func.setType(
|
||||
FunctionType::get(func.getContext(), new_input_types,
|
||||
func.front().getTerminator()->getOperandTypes()));
|
||||
}
|
||||
|
||||
// Contains cached information for decomposed callee functions for (stateful)
|
||||
|
@ -460,10 +460,9 @@ LogicalResult HandleTensorArrayScatterV3Op(
|
||||
void UpdateFuncType(FuncOp func) {
|
||||
llvm::SmallVector<Type, 8> arg_types;
|
||||
for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
|
||||
func.setType(FunctionType::get(
|
||||
arg_types,
|
||||
llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
|
||||
func.getContext()));
|
||||
func.setType(
|
||||
FunctionType::get(func.getContext(), arg_types,
|
||||
func.front().getTerminator()->getOperandTypes()));
|
||||
}
|
||||
|
||||
// Finds the accessed gradient sources for each tensor array argument.
|
||||
|
@ -71,9 +71,9 @@ struct TensorListOpsDecompositionPass
|
||||
void UpdateFuncType(FuncOp func) {
|
||||
llvm::SmallVector<Type, 8> arg_types;
|
||||
for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
|
||||
func.setType(FunctionType::get(
|
||||
arg_types, func.front().getTerminator()->getOperandTypes(),
|
||||
func.getContext()));
|
||||
func.setType(
|
||||
FunctionType::get(func.getContext(), arg_types,
|
||||
func.front().getTerminator()->getOperandTypes()));
|
||||
}
|
||||
|
||||
// Holds the size value of a tensor list and whether the size is statically
|
||||
|
@ -118,8 +118,8 @@ void TPUResourceReadForWrite::runOnOperation() {
|
||||
for (Value read_operand : read_operands)
|
||||
block.addArgument(read_operand.getType());
|
||||
|
||||
func.setType(FunctionType::get(block.getArgumentTypes(),
|
||||
func.getCallableResults(), &getContext()));
|
||||
func.setType(FunctionType::get(&getContext(), block.getArgumentTypes(),
|
||||
func.getCallableResults()));
|
||||
cluster_func.erase();
|
||||
}
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ struct TPUSpaceToDepthPass
|
||||
void UpdateFuncType(FuncOp func) {
|
||||
auto arg_types = func.front().getArgumentTypes();
|
||||
auto result_types = func.front().getTerminator()->getOperandTypes();
|
||||
func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
|
||||
func.setType(FunctionType::get(func.getContext(), arg_types, result_types));
|
||||
}
|
||||
|
||||
void HandleFuncOp(Operation* op) {
|
||||
@ -196,7 +196,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) {
|
||||
MLIRContext* context = conv2d.getContext();
|
||||
SmallVector<int64_t, 4> values = {1, 1, 1, 1};
|
||||
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
|
||||
return IntegerAttr::get(IntegerType::get(64, context), v);
|
||||
return IntegerAttr::get(IntegerType::get(context, 64), v);
|
||||
});
|
||||
// TODO(b/157276506): change type of strides to DenseElementsAttr
|
||||
auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
|
||||
@ -351,7 +351,7 @@ void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
|
||||
MLIRContext* context = backprop.getContext();
|
||||
SmallVector<int64_t, 4> values = {1, 1, 1, 1};
|
||||
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
|
||||
return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
|
||||
return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
|
||||
});
|
||||
auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
|
||||
|
||||
|
@ -1454,9 +1454,9 @@ Status ImporterBase::Convert(
|
||||
all_equal = false;
|
||||
}
|
||||
if (!all_equal) {
|
||||
function.setType(mlir::FunctionType::get(func_type.getInputs(),
|
||||
graph.getResultTypes(),
|
||||
function.getContext()));
|
||||
function.setType(mlir::FunctionType::get(function.getContext(),
|
||||
func_type.getInputs(),
|
||||
graph.getResultTypes()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -2906,8 +2906,8 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
|
||||
}
|
||||
new_input_types.push_back(arg.getType());
|
||||
}
|
||||
func.setType(mlir::FunctionType::get(
|
||||
new_input_types, func.getType().getResults(), module.getContext()));
|
||||
func.setType(mlir::FunctionType::get(module.getContext(), new_input_types,
|
||||
func.getType().getResults()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -569,9 +569,9 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
|
||||
for (mlir::BlockArgument& arg : main_fn.getArguments())
|
||||
updated_argument_types.push_back(arg.getType());
|
||||
|
||||
main_fn.setType(mlir::FunctionType::get(updated_argument_types,
|
||||
main_fn.getType().getResults(),
|
||||
main_fn.getContext()));
|
||||
main_fn.setType(mlir::FunctionType::get(main_fn.getContext(),
|
||||
updated_argument_types,
|
||||
main_fn.getType().getResults()));
|
||||
}
|
||||
|
||||
for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);
|
||||
|
@ -136,30 +136,30 @@ TEST_F(ConvertTensorTest, Simple) {
|
||||
{1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context)));
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int8>(
|
||||
{1, -1}, DT_INT8, mlir::IntegerType::get(8, &context)));
|
||||
{1, -1}, DT_INT8, mlir::IntegerType::get(&context, 8)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int16>(
|
||||
{1, -1}, DT_INT16, mlir::IntegerType::get(16, &context)));
|
||||
{1, -1}, DT_INT16, mlir::IntegerType::get(&context, 16)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int32>(
|
||||
{1, -1}, DT_INT32, mlir::IntegerType::get(32, &context)));
|
||||
{1, -1}, DT_INT32, mlir::IntegerType::get(&context, 32)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int64>(
|
||||
{1, -1}, DT_INT64, mlir::IntegerType::get(64, &context)));
|
||||
{1, -1}, DT_INT64, mlir::IntegerType::get(&context, 64)));
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint8>(
|
||||
{1, 2}, DT_UINT8,
|
||||
mlir::IntegerType::get(
|
||||
8, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
&context, 8, mlir::IntegerType::SignednessSemantics::Unsigned)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint16>(
|
||||
{1, 2}, DT_UINT16,
|
||||
mlir::IntegerType::get(
|
||||
16, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
&context, 16, mlir::IntegerType::SignednessSemantics::Unsigned)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint32>(
|
||||
{1, 2}, DT_UINT32,
|
||||
mlir::IntegerType::get(
|
||||
32, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
&context, 32, mlir::IntegerType::SignednessSemantics::Unsigned)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint64>(
|
||||
{1, 2}, DT_UINT64,
|
||||
mlir::IntegerType::get(
|
||||
64, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
&context, 64, mlir::IntegerType::SignednessSemantics::Unsigned)));
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<float>>(
|
||||
{{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64,
|
||||
|
@ -150,7 +150,7 @@ StatusOr<FunctionDef> TFRDecomposeContext::ExpandNode(const NodeDef& node_def,
|
||||
mlir::Location loc = mlir::UnknownLoc::get(context);
|
||||
mlir::ModuleOp module = mlir::ModuleOp::create(loc);
|
||||
mlir::FunctionType func_type =
|
||||
mlir::FunctionType::get(input_tys, output_tys, context);
|
||||
mlir::FunctionType::get(context, input_tys, output_tys);
|
||||
llvm::StringRef func_name_str(func_name.data(), func_name.size());
|
||||
auto func = mlir::FuncOp::create(loc, func_name_str, func_type, {});
|
||||
module.push_back(func);
|
||||
|
@ -48,8 +48,8 @@ Operation* emitCallToPrint(Location loc, StringRef func_name, Value arg,
|
||||
|
||||
auto module = caller_func->getParentOfType<ModuleOp>();
|
||||
b->setInsertionPointToStart(module.getBody());
|
||||
auto func_type = FunctionType::get(arg.getType(), /*results=*/llvm::None,
|
||||
b->getContext());
|
||||
auto func_type = FunctionType::get(b->getContext(), arg.getType(),
|
||||
/*results=*/llvm::None);
|
||||
callee_func = b->create<FuncOp>(module.getLoc(), func_name, func_type);
|
||||
callee_func.setPrivate();
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
|
||||
llvm::SmallVector<Type, 4> args, rets;
|
||||
TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
|
||||
TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
|
||||
auto func_type = mlir::FunctionType::get(args, rets, context_);
|
||||
auto func_type = mlir::FunctionType::get(context_, args, rets);
|
||||
|
||||
string computation_name =
|
||||
computation.parent()->entry_computation() == &computation
|
||||
|
@ -144,7 +144,7 @@ static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
|
||||
static DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) {
|
||||
RankedTensorType ty =
|
||||
RankedTensorType::get(static_cast<int64_t>(attr.size()),
|
||||
IntegerType::get(64, attr.getContext()));
|
||||
IntegerType::get(attr.getContext(), 64));
|
||||
return DenseIntElementsAttr::get(ty, attr.getValue());
|
||||
}
|
||||
|
||||
@ -184,7 +184,7 @@ Type GetSumAccumulationType(Type input_type) {
|
||||
MLIRContext *ctx = input_type.getContext();
|
||||
if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
|
||||
if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
|
||||
return IntegerType::get(32, ctx);
|
||||
return IntegerType::get(ctx, 32);
|
||||
return input_type;
|
||||
}
|
||||
|
||||
@ -828,7 +828,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
|
||||
}
|
||||
}
|
||||
|
||||
auto element_type = IntegerType::get(64, input.getContext());
|
||||
auto element_type = IntegerType::get(input.getContext(), 64);
|
||||
return DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({shape[0]}, element_type), values);
|
||||
}
|
||||
@ -837,7 +837,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
|
||||
// in TensorFlow PadV2 op.
|
||||
static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
|
||||
auto length = tf_padding.getType().getShape()[0];
|
||||
auto element_type = IntegerType::get(64, tf_padding.getContext());
|
||||
auto element_type = IntegerType::get(tf_padding.getContext(), 64);
|
||||
return DenseIntElementsAttr::get<int64_t>(
|
||||
RankedTensorType::get({length}, element_type), 0);
|
||||
}
|
||||
@ -1837,7 +1837,7 @@ class ConvertFusedBatchNormGradBase
|
||||
Type feature_type = RankedTensorType::get(
|
||||
{GetDimSize(act_type, feature_dim)}, kernel_type);
|
||||
Type result_type = TupleType::get(
|
||||
{act.getType(), feature_type, feature_type}, rewriter.getContext());
|
||||
rewriter.getContext(), {act.getType(), feature_type, feature_type});
|
||||
|
||||
auto training_op = rewriter.create<BatchNormGradOp>(
|
||||
loc, result_type, act, scale, mean, var, grad, op.epsilon(),
|
||||
@ -1973,7 +1973,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
|
||||
// batch_mean, and batch_var.
|
||||
SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
|
||||
mean_var_type, mean_var_type};
|
||||
Type result_type = TupleType::get(operand_types, rewriter.getContext());
|
||||
Type result_type = TupleType::get(rewriter.getContext(), operand_types);
|
||||
|
||||
auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
|
||||
@ -4618,9 +4618,9 @@ class ConvertInfeedDequeueTupleOp
|
||||
// Emit infeed op.
|
||||
// The result type of infeed is a tuple(tuple(result types), token type).
|
||||
auto data_tuple_type =
|
||||
mlir::TupleType::get(result_types, rewriter.getContext());
|
||||
mlir::TupleType::get(rewriter.getContext(), result_types);
|
||||
auto data_and_token_type = mlir::TupleType::get(
|
||||
{data_tuple_type, token.getType()}, rewriter.getContext());
|
||||
rewriter.getContext(), {data_tuple_type, token.getType()});
|
||||
|
||||
auto data_and_token =
|
||||
rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
|
||||
|
@ -281,7 +281,7 @@ Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
|
||||
/*type=*/builder.getI64IntegerAttr(3), builder.getContext());
|
||||
auto result_type = result.getType();
|
||||
auto recv_result_type =
|
||||
TupleType::get({result_type, token.getType()}, builder.getContext());
|
||||
TupleType::get(builder.getContext(), {result_type, token.getType()});
|
||||
auto recv =
|
||||
builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
|
||||
/*is_host_transfer=*/builder.getBoolAttr(true));
|
||||
@ -712,8 +712,8 @@ void UpdateFunctionType(OpBuilder& builder, FuncOp func, Block& func_body) {
|
||||
auto new_argument_types = llvm::to_vector<4>(func_body.getArgumentTypes());
|
||||
auto new_result_types =
|
||||
llvm::to_vector<4>(func_body.getTerminator()->getOperandTypes());
|
||||
func.setType(FunctionType::get(new_argument_types, new_result_types,
|
||||
builder.getContext()));
|
||||
func.setType(FunctionType::get(builder.getContext(), new_argument_types,
|
||||
new_result_types));
|
||||
}
|
||||
|
||||
// Replaces a function terminator `return` with another `return` that has an
|
||||
|
@ -108,7 +108,7 @@ Status EmitMlirFuncAndCall(
|
||||
// Create the function an call the emission callback.
|
||||
mlir::Location loc = mlir::UnknownLoc::get(context);
|
||||
auto function = mlir::FuncOp::create(
|
||||
loc, func_name, mlir::FunctionType::get(operand_types, {}, context));
|
||||
loc, func_name, mlir::FunctionType::get(context, operand_types, {}));
|
||||
function.addEntryBlock();
|
||||
mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc);
|
||||
mlir_module->push_back(function);
|
||||
|
@ -100,14 +100,14 @@ class ProcessType(ast.NodeVisitor):
|
||||
attr = getattr(value, node.attr)
|
||||
|
||||
if attr == core.Tensor:
|
||||
return tfp.UnrankedTensorType.get(tfp.IntegerType.get(32, self.prog.ctx))
|
||||
return tfp.UnrankedTensorType.get(tfp.IntegerType.get(self.prog.ctx, 32))
|
||||
return attr
|
||||
|
||||
def visit_Name(self, node):
|
||||
if node.id == 'int':
|
||||
return tfp.IntegerType.get(32, self.prog.ctx)
|
||||
return tfp.IntegerType.get(self.prog.ctx, 32)
|
||||
if node.id == 'bool':
|
||||
return tfp.IntegerType.get(1, self.prog.ctx)
|
||||
return tfp.IntegerType.get(self.prog.ctx, 1)
|
||||
if node.id in self.ctx.info.namespace:
|
||||
return self.ctx.info.namespace[node.id]
|
||||
|
||||
@ -203,7 +203,7 @@ class MLIRGen(ast.NodeVisitor):
|
||||
value = tfp.Tf_ConstOp.create(
|
||||
opb, opb.getUnknownLoc(),
|
||||
tfp.IntegerAttr.get(
|
||||
tfp.IntegerType.get(32, self.prog.ctx), node.value)).getResult(0)
|
||||
tfp.IntegerType.get(self.prog.ctx, 32), node.value)).getResult(0)
|
||||
return value
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
|
@ -85,7 +85,7 @@ class OrOp(object):
|
||||
def create(cls, opb, loc, values):
|
||||
state = mlir.OperationState(loc, "tfp.Or")
|
||||
state.addTypes(
|
||||
[UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))])
|
||||
[UnrankedTensorType.get(IntegerType.get(opb.getContext(), 1))])
|
||||
state.addOperands(values)
|
||||
return opb.createOperation(state)
|
||||
|
||||
@ -103,7 +103,7 @@ class AndOp(object):
|
||||
def create(cls, opb, loc, values):
|
||||
state = mlir.OperationState(loc, "tfp.And")
|
||||
state.addTypes(
|
||||
[UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))])
|
||||
[UnrankedTensorType.get(IntegerType.get(opb.getContext(), 1))])
|
||||
state.addOperands(values)
|
||||
return opb.createOperation(state)
|
||||
|
||||
|
@ -685,8 +685,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
||||
)
|
||||
|
||||
# Check out LLVM and MLIR from llvm-project.
|
||||
LLVM_COMMIT = "511cfe9441955f20a8b93573fb9b62433b053550"
|
||||
LLVM_SHA256 = "57626cf2eb850c717b712e43774cad308f19cd9161db9ed286844ba8f42abd70"
|
||||
LLVM_COMMIT = "1b97cdf885d6455841280b8da858835e641ee941"
|
||||
LLVM_SHA256 = "80d5036ba734fcb700a5699e2f99e5a0de5808dde01a1df3c4fae04510bc8e23"
|
||||
LLVM_URLS = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
||||
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
||||
|
Loading…
Reference in New Issue
Block a user