Updates LLVM usage to match
[1b97cdf885d6](https://github.com/llvm/llvm-project/commit/1b97cdf885d6)

PiperOrigin-RevId: 348587513
Change-Id: I853d197b33c5df08c00c99ddc8cf8b2681bed708
Authored by A. Unique TensorFlower on 2020-12-21 23:48:24 -08:00; committed by TensorFlower Gardener
Parent: 682d130cf6
Commit: 94155a3934
42 changed files with 118 additions and 120 deletions
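
The hunks below all apply the same mechanical migration: with this LLVM revision, MLIR's type factories (`IntegerType::get`, `FunctionType::get`, `TupleType::get`, and the matching builder helpers) take the `MLIRContext*` as their first parameter instead of their last. A minimal sketch of the new call shapes, assuming present-day MLIR header paths; the function and variable names here are hypothetical and are not part of this change:

```cpp
// Context-first type construction, as used throughout the hunks below.
// Header paths and names are assumptions for illustration only.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static FunctionType makeExampleTypes(MLIRContext *ctx) {
  // Before 1b97cdf885d6: IntegerType::get(32, ctx); now the context leads.
  Type i32 = IntegerType::get(ctx, 32);
  Type u8 = IntegerType::get(ctx, 8, IntegerType::Unsigned);
  // TupleType and FunctionType follow the same context-first convention.
  Type pair = TupleType::get(ctx, {i32, u8});
  return FunctionType::get(ctx, /*inputs=*/{i32}, /*results=*/{pair});
}
```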

@@ -202,7 +202,7 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
 MLIRContext* context, Optional<Location> location, ValueRange operands,
 DictionaryAttr attributes, RegionRange regions,
 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
-Type element_type = IntegerType::get(1, context);
+Type element_type = IntegerType::get(context, 1);
 return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
 attributes, element_type,
 inferedReturnShapes);

@@ -621,7 +621,7 @@ OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
 static LogicalResult Verify(TupleOp op) {
 SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
 op.operand_type_end()};
-auto expectedType = TupleType::get(operandTypes, op.getContext());
+auto expectedType = TupleType::get(op.getContext(), operandTypes);
 if (op.getType() != expectedType) {
 return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
 op.getType(), expectedType));
@@ -1967,7 +1967,7 @@ LogicalResult ReplicaIdOp::inferReturnTypes(
 MLIRContext* context, Optional<Location>, ValueRange operands,
 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
 inferredReturnTypes.push_back(RankedTensorType::get(
-/*shape=*/{}, IntegerType::get(32, IntegerType::Unsigned, context)));
+/*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
 return success();
 }

@@ -145,7 +145,7 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
 auto int_shape_type = RankedTensorType::get(
 output_type.getShape(),
-IntegerType::get(bitwidth, rewriter.getContext()));
+IntegerType::get(rewriter.getContext(), bitwidth));
 auto loc = op.getLoc();
 auto integer_const = rewriter.create<mlir::ConstantOp>(
 loc, DenseIntElementsAttr::get(int_shape_type, values));

@@ -37,10 +37,10 @@ constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
 DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
 f32_ = FloatType::getF32(ctx_);
-i8_ = IntegerType::get(k8Bits, ctx_);
+i8_ = IntegerType::get(ctx_, k8Bits);
 i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
 i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
-i32_ = IntegerType::get(k32Bits, ctx_);
+i32_ = IntegerType::get(ctx_, k32Bits);
 i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
 i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
 any_ = AnyQuantizedType();

@@ -212,7 +212,7 @@ LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
 }
 *attr_i32 = IntegerAttr::get(
-IntegerType::get(/*width=*/32, attr.getContext()), value);
+IntegerType::get(attr.getContext(), /*width=*/32), value);
 return success();
 }

@@ -547,8 +547,8 @@ struct ConvertTensorListResize
 Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
 size_diff.getType(), size.getType()};
 Type branch_result_type[] = {result_type};
-auto func_type = FunctionType::get(branch_args_type, branch_result_type,
-rewriter.getContext());
+auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
+branch_result_type);
 // Constructs `then_branch`, which is executed when `if_cond` evaluates to
 // true.
@@ -775,8 +775,8 @@ LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
 // Change `func`'s argument type to `unranked_argument_types`. If it
 // return types contain a `DT_VARIANT`, change it to the unranked type
 // derived from the corresponding argument.
-func.setType(FunctionType::get(updated_argument_types, updated_result_types,
-op.getContext()));
+func.setType(FunctionType::get(op.getContext(), updated_argument_types,
+updated_result_types));
 // Change the argument type for the first block.
 llvm::for_each(func.getArguments(), [&](BlockArgument &arg) {

@@ -243,7 +243,7 @@ DenseElementsAttr GetShape(Value output_val) {
 return mlir::DenseElementsAttr::get(
 RankedTensorType::get(
 {static_cast<int>(shape.size())},
-mlir::IntegerType::get(32, output_val.getContext())),
+mlir::IntegerType::get(output_val.getContext(), 32)),
 llvm::makeArrayRef(shape));
 }

@@ -50,7 +50,7 @@ void UpdateFuncType(FuncOp func) {
 if (llvm::makeArrayRef(return_types) == func_type.getResults()) return;
 auto updated_type =
-FunctionType::get(func_type.getInputs(), return_types, func.getContext());
+FunctionType::get(func.getContext(), func_type.getInputs(), return_types);
 func.setType(updated_type);
 }

@@ -134,12 +134,12 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
 bool passthru_extra_args) {
 FunctionType type;
 if (passthru_extra_args) {
-type = FunctionType::get(types, types, &getContext());
+type = FunctionType::get(&getContext(), types, types);
 } else {
 SmallVector<Type, 4> result_types;
 auto operands = region.front().getTerminator()->getOperandTypes();
 result_types.append(operands.begin(), operands.end());
-type = FunctionType::get(types, result_types, &getContext());
+type = FunctionType::get(&getContext(), types, result_types);
 }
 auto outlined_func = builder.create<FuncOp>(while_op.getLoc(), name, type);

@@ -382,8 +382,8 @@ void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
 auto input_types = fused_func_op_.getType().getInputs();
 auto output_type = mlir::RankedTensorType::get(
 output_shape, input_.getType().cast<RankedTensorType>().getElementType());
-fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
-fused_func_op_.getContext()));
+fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(),
+input_types, output_type));
 }
 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
@@ -820,8 +820,8 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
 }
 // Update function signatures.
-func_op.setType(mlir::FunctionType::get(func_op.getType().getInputs(),
-output_types, func_op.getContext()));
+func_op.setType(mlir::FunctionType::get(
+func_op.getContext(), func_op.getType().getInputs(), output_types));
 builder->create<mlir::ReturnOp>(func_op.getLoc(), outputs);
 return success();

@@ -33,7 +33,7 @@ void init_types(py::module& m) {
 .def("getF64", &mlir::FloatType::getF64);
 py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
-.def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
+.def("get", py::overload_cast<mlir::MLIRContext*, unsigned>(
 &mlir::IntegerType::get));
 py::class_<mlir::UnrankedTensorType, mlir::Type>(m, "UnrankedTensorType")

@@ -668,7 +668,7 @@ Status MlirFunctionContext::Finalize(OutputList* outputs,
 auto arg_types = body.getArgumentTypes();
 auto result_types = body.getTerminator()->getOperandTypes();
-func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
+func_.setType(FunctionType::get(func_.getContext(), arg_types, result_types));
 *f = new MlirFunction(std::move(context_), std::move(module_), func_);
 return Status::OK();
 }

@@ -1310,7 +1310,7 @@ LogicalResult ConcatOffsetOp::fold(ArrayRef<Attribute> operands,
 results.reserve(shapes.size());
 SmallVector<int32_t, 4> cumulative_sum(num_dims, 0);
 RankedTensorType offset_type =
-RankedTensorType::get({num_dims}, IntegerType::get(32, getContext()));
+RankedTensorType::get({num_dims}, IntegerType::get(getContext(), 32));
 for (DenseIntElementsAttr shape : shapes) {
 results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum));
 cumulative_sum[concat_dim] += shape.getValue<int32_t>(concat_dim);

@@ -898,7 +898,7 @@ static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
 dimensions.push_back(APInt(out_width, shape[i]));
 auto result_type = RankedTensorType::get(
-{rank}, IntegerType::get(out_width, input_ty.getContext()));
+{rank}, IntegerType::get(input_ty.getContext(), out_width));
 return DenseElementsAttr::get(result_type, dimensions);
 }

@@ -155,19 +155,19 @@ Type TensorFlowRefType::RemoveRef() {
 if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
 if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
 if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
-if (isa<BoolRefType>()) return mlir::IntegerType::get(1, ctx);
-if (isa<Int8RefType>()) return mlir::IntegerType::get(8, ctx);
-if (isa<Int16RefType>()) return mlir::IntegerType::get(16, ctx);
-if (isa<Int32RefType>()) return mlir::IntegerType::get(32, ctx);
-if (isa<Int64RefType>()) return mlir::IntegerType::get(64, ctx);
+if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
+if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
+if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
+if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
+if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
 if (isa<Uint8RefType>())
-return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
+return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
 if (isa<Uint16RefType>())
-return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
+return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
 if (isa<Uint32RefType>())
-return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
+return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
 if (isa<Uint64RefType>())
-return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
+return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
 if (isa<Complex64RefType>())
 return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
 if (isa<Complex128RefType>())

@@ -58,8 +58,8 @@ FuncOp BuildFunction(llvm::ArrayRef<Value> live_ins,
 operand_types.reserve(live_ins.size());
 for (Value v : live_ins) operand_types.emplace_back(v.getType());
-auto func_type = FunctionType::get(operand_types, cluster_op.getResultTypes(),
-builder->getContext());
+auto func_type =
+builder->getFunctionType(operand_types, cluster_op.getResultTypes());
 // TODO(lyandy): Define better name for outlined function. Potentially some
 // name can be added during cluster formation.

@@ -193,7 +193,7 @@ void CreateFunctions(ModuleOp module_op,
 std::replace(func_name.begin(), func_name.end(), '/', '_');
 FunctionType func_type =
-FunctionType::get(input_types, result_types, context);
+FunctionType::get(context, input_types, result_types);
 Location loc = metadata.ops.front()->getLoc();
 FuncOp func_op = FuncOp::create(loc, func_name, func_type);
 // Sets the device attribute for every input and every result of the

@@ -53,7 +53,7 @@ Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
 values.reserve(rank);
 for (int i = 0; i < rank; ++i) values.push_back(APInt(bitwidth, r1[i]));
 auto result_type = RankedTensorType::get(
-{rank}, IntegerType::get(bitwidth, builder.getContext()));
+{rank}, IntegerType::get(builder.getContext(), bitwidth));
 return builder.create<TF::ConstOp>(
 loc, DenseElementsAttr::get(result_type, values));
 }

@@ -100,7 +100,7 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
 for (Value operand : island_op.GetYield().getOperands())
 func_result_types.push_back(operand.getType());
 FunctionType func_type =
-FunctionType::get(func_operand_types, func_result_types, ctx);
+FunctionType::get(ctx, func_operand_types, func_result_types);
 // Create the outlined function
 SmallString<32> name = kOutlinedFuncPrefix;

@@ -213,9 +213,9 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) {
 }
 // Update the function type.
-func.setType(mlir::FunctionType::get(func.getArgumentTypes(),
-func.getType().getResults(),
-module.getContext()));
+func.setType(mlir::FunctionType::get(module.getContext(),
+func.getArgumentTypes(),
+func.getType().getResults()));
 }
 return success();
 }

@@ -180,8 +180,8 @@ mlir::LogicalResult PromoteVarHandlesToArguments(
 }
 if (!var_handle_shared_names->empty())
-function.setType(FunctionType::get(func_arg_types, func_type.getResults(),
-function.getContext()));
+function.setType(FunctionType::get(function.getContext(), func_arg_types,
+func_type.getResults()));
 return success();
 }

@@ -121,7 +121,7 @@ void ExtractSingleBlockRegion(Region& region, StringRef name,
 if (extern_values_passthrough)
 for (auto input : extern_values) return_types.push_back(input.getType());
-auto type = FunctionType::get(input_types, return_types, region.getContext());
+auto type = FunctionType::get(region.getContext(), input_types, return_types);
 // Create new function and extract region body into the function.
 auto outlined_func = builder.create<FuncOp>(loc, name, type);

@@ -785,9 +785,9 @@ void RemoveUnusedResourceArgumentsAndForwardedRetvals(
 }
 }
 func_op.eraseArguments(indices_to_erase);
-func_op.setType(FunctionType::get(
-new_types, llvm::to_vector<4>(return_op->getOperandTypes()),
-func_op.getContext()));
+func_op.setType(
+FunctionType::get(func_op.getContext(), new_types,
+llvm::to_vector<4>(return_op->getOperandTypes())));
 }
 // Lifts reads/writes of resource arguments from func_op and changes its
@@ -841,10 +841,9 @@ LogicalResult LiftArgRetResourcesForFunction(
 assign_variable_op.erase();
 }
-func_op.setType(
-FunctionType::get(func_op.front().getArgumentTypes(),
-func_op.front().getTerminator()->getOperandTypes(),
-func_op.getContext()));
+func_op.setType(FunctionType::get(
+func_op.getContext(), func_op.front().getArgumentTypes(),
+func_op.front().getTerminator()->getOperandTypes()));
 return success();
 }
@@ -1153,9 +1152,9 @@ LogicalResult HandlePartitionedCallOpCallee(
 auto new_return =
 builder.create<ReturnOp>(old_return->getLoc(), old_and_new_retvals);
 old_return->erase();
-callee.setType(FunctionType::get(
-callee.getType().getInputs(),
-llvm::to_vector<4>(new_return.getOperandTypes()), callee.getContext()));
+callee.setType(
+FunctionType::get(callee.getContext(), callee.getType().getInputs(),
+llvm::to_vector<4>(new_return.getOperandTypes())));
 return success();
 }

@@ -184,9 +184,9 @@ void EliminateUnusedResultsForIfCase(Operation *op, ArrayRef<FuncOp> branches) {
 // Patch up function types (with less number of return values and potentially
 // less number of arguments)
 for (FuncOp func : cloned_branches) {
-func.setType(FunctionType::get(
-func.front().getArgumentTypes(),
-func.front().getTerminator()->getOperandTypes(), func.getContext()));
+func.setType(
+FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
+func.front().getTerminator()->getOperandTypes()));
 }
 EliminateUnusedResults(op);
@@ -232,9 +232,9 @@ void EliminateUnusedResultsForWhile(TF::WhileOp op) {
 // Patch up branch function types.
 for (FuncOp func : {cloned_cond, cloned_body}) {
-func.setType(FunctionType::get(
-func.front().getArgumentTypes(),
-func.front().getTerminator()->getOperandTypes(), func.getContext()));
+func.setType(
+FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
+func.front().getTerminator()->getOperandTypes()));
 }
 EliminateUnusedResults(op, &can_eliminate);
 }

@@ -1150,8 +1150,8 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
 }
 FunctionType func_type = func.getType();
-func.setType(FunctionType::get(input_types, func_type.getResults(),
-func.getContext()));
+func.setType(FunctionType::get(func.getContext(), input_types,
+func_type.getResults()));
 auto res =
 PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration);
@@ -1493,8 +1493,8 @@ void ShapeInference::InferShapeForFunctionReturnType(FuncOp func) {
 }
 DCOMMENT("Updating function type");
-func.setType(FunctionType::get(
-func.getArgumentTypes(), return_op.getOperandTypes(), func.getContext()));
+func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
+return_op.getOperandTypes()));
 if (changed) EnqueueCallers(func);
 }
@@ -1611,8 +1611,8 @@ LogicalResult InferShapeForFunction(FuncOp func,
 return failure();
 context.InferShapeForFunctionReturnType(func);
-func.setType(FunctionType::get(new_arg_types, func.getType().getResults(),
-func.getContext()));
+func.setType(FunctionType::get(func.getContext(), new_arg_types,
+func.getType().getResults()));
 return success();
 }

@@ -137,9 +137,9 @@ void ModifyFunctionSignature(
 if (handle_new_size_vars) {
 handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
 }
-func.setType(FunctionType::get(
-new_input_types, func.front().getTerminator()->getOperandTypes(),
-func.getContext()));
+func.setType(
+FunctionType::get(func.getContext(), new_input_types,
+func.front().getTerminator()->getOperandTypes()));
 }
 // Contains cached information for decomposed callee functions for (stateful)

@@ -460,10 +460,9 @@ LogicalResult HandleTensorArrayScatterV3Op(
 void UpdateFuncType(FuncOp func) {
 llvm::SmallVector<Type, 8> arg_types;
 for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
-func.setType(FunctionType::get(
-arg_types,
-llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
-func.getContext()));
+func.setType(
+FunctionType::get(func.getContext(), arg_types,
+func.front().getTerminator()->getOperandTypes()));
 }
 // Finds the accessed gradient sources for each tensor array argument.

@@ -71,9 +71,9 @@ struct TensorListOpsDecompositionPass
 void UpdateFuncType(FuncOp func) {
 llvm::SmallVector<Type, 8> arg_types;
 for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
-func.setType(FunctionType::get(
-arg_types, func.front().getTerminator()->getOperandTypes(),
-func.getContext()));
+func.setType(
+FunctionType::get(func.getContext(), arg_types,
+func.front().getTerminator()->getOperandTypes()));
 }
 // Holds the size value of a tensor list and whether the size is statically

@@ -118,8 +118,8 @@ void TPUResourceReadForWrite::runOnOperation() {
 for (Value read_operand : read_operands)
 block.addArgument(read_operand.getType());
-func.setType(FunctionType::get(block.getArgumentTypes(),
-func.getCallableResults(), &getContext()));
+func.setType(FunctionType::get(&getContext(), block.getArgumentTypes(),
+func.getCallableResults()));
 cluster_func.erase();
 }
 }

@@ -117,7 +117,7 @@ struct TPUSpaceToDepthPass
 void UpdateFuncType(FuncOp func) {
 auto arg_types = func.front().getArgumentTypes();
 auto result_types = func.front().getTerminator()->getOperandTypes();
-func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
+func.setType(FunctionType::get(func.getContext(), arg_types, result_types));
 }
 void HandleFuncOp(Operation* op) {
@@ -196,7 +196,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) {
 MLIRContext* context = conv2d.getContext();
 SmallVector<int64_t, 4> values = {1, 1, 1, 1};
 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
-return IntegerAttr::get(IntegerType::get(64, context), v);
+return IntegerAttr::get(IntegerType::get(context, 64), v);
 });
 // TODO(b/157276506): change type of strides to DenseElementsAttr
 auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
@@ -351,7 +351,7 @@ void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
 MLIRContext* context = backprop.getContext();
 SmallVector<int64_t, 4> values = {1, 1, 1, 1};
 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
-return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
+return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
 });
 auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);

@@ -1454,9 +1454,9 @@ Status ImporterBase::Convert(
 all_equal = false;
 }
 if (!all_equal) {
-function.setType(mlir::FunctionType::get(func_type.getInputs(),
-graph.getResultTypes(),
-function.getContext()));
+function.setType(mlir::FunctionType::get(function.getContext(),
+func_type.getInputs(),
+graph.getResultTypes()));
 }
 }
@@ -2906,8 +2906,8 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
 }
 new_input_types.push_back(arg.getType());
 }
-func.setType(mlir::FunctionType::get(
-new_input_types, func.getType().getResults(), module.getContext()));
+func.setType(mlir::FunctionType::get(module.getContext(), new_input_types,
+func.getType().getResults()));
 }
 }

@@ -569,9 +569,9 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
 for (mlir::BlockArgument& arg : main_fn.getArguments())
 updated_argument_types.push_back(arg.getType());
-main_fn.setType(mlir::FunctionType::get(updated_argument_types,
-main_fn.getType().getResults(),
-main_fn.getContext()));
+main_fn.setType(mlir::FunctionType::get(main_fn.getContext(),
+updated_argument_types,
+main_fn.getType().getResults()));
 }
 for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);

@@ -136,30 +136,30 @@ TEST_F(ConvertTensorTest, Simple) {
 {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context)));
 ASSERT_NO_FATAL_FAILURE(VerifyConversion<int8>(
-{1, -1}, DT_INT8, mlir::IntegerType::get(8, &context)));
+{1, -1}, DT_INT8, mlir::IntegerType::get(&context, 8)));
 ASSERT_NO_FATAL_FAILURE(VerifyConversion<int16>(
-{1, -1}, DT_INT16, mlir::IntegerType::get(16, &context)));
+{1, -1}, DT_INT16, mlir::IntegerType::get(&context, 16)));
 ASSERT_NO_FATAL_FAILURE(VerifyConversion<int32>(
-{1, -1}, DT_INT32, mlir::IntegerType::get(32, &context)));
+{1, -1}, DT_INT32, mlir::IntegerType::get(&context, 32)));
 ASSERT_NO_FATAL_FAILURE(VerifyConversion<int64>(
-{1, -1}, DT_INT64, mlir::IntegerType::get(64, &context)));
+{1, -1}, DT_INT64, mlir::IntegerType::get(&context, 64)));
 ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint8>(
 {1, 2}, DT_UINT8,
 mlir::IntegerType::get(
-8, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+&context, 8, mlir::IntegerType::SignednessSemantics::Unsigned)));
 ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint16>(
 {1, 2}, DT_UINT16,
 mlir::IntegerType::get(
-16, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+&context, 16, mlir::IntegerType::SignednessSemantics::Unsigned)));
 ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint32>(
 {1, 2}, DT_UINT32,
 mlir::IntegerType::get(
-32, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+&context, 32, mlir::IntegerType::SignednessSemantics::Unsigned)));
 ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint64>(
 {1, 2}, DT_UINT64,
 mlir::IntegerType::get(
-64, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+&context, 64, mlir::IntegerType::SignednessSemantics::Unsigned)));
 ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<float>>(
 {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64,

@@ -150,7 +150,7 @@ StatusOr<FunctionDef> TFRDecomposeContext::ExpandNode(const NodeDef& node_def,
 mlir::Location loc = mlir::UnknownLoc::get(context);
 mlir::ModuleOp module = mlir::ModuleOp::create(loc);
 mlir::FunctionType func_type =
-mlir::FunctionType::get(input_tys, output_tys, context);
+mlir::FunctionType::get(context, input_tys, output_tys);
 llvm::StringRef func_name_str(func_name.data(), func_name.size());
 auto func = mlir::FuncOp::create(loc, func_name_str, func_type, {});
 module.push_back(func);

@@ -48,8 +48,8 @@ Operation* emitCallToPrint(Location loc, StringRef func_name, Value arg,
 auto module = caller_func->getParentOfType<ModuleOp>();
 b->setInsertionPointToStart(module.getBody());
-auto func_type = FunctionType::get(arg.getType(), /*results=*/llvm::None,
-b->getContext());
+auto func_type = FunctionType::get(b->getContext(), arg.getType(),
+/*results=*/llvm::None);
 callee_func = b->create<FuncOp>(module.getLoc(), func_name, func_type);
 callee_func.setPrivate();
 }

@@ -127,7 +127,7 @@ StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
 llvm::SmallVector<Type, 4> args, rets;
 TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
 TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
-auto func_type = mlir::FunctionType::get(args, rets, context_);
+auto func_type = mlir::FunctionType::get(context_, args, rets);
 string computation_name =
 computation.parent()->entry_computation() == &computation

@@ -144,7 +144,7 @@ static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
 static DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) {
 RankedTensorType ty =
 RankedTensorType::get(static_cast<int64_t>(attr.size()),
-IntegerType::get(64, attr.getContext()));
+IntegerType::get(attr.getContext(), 64));
 return DenseIntElementsAttr::get(ty, attr.getValue());
 }
@@ -184,7 +184,7 @@ Type GetSumAccumulationType(Type input_type) {
 MLIRContext *ctx = input_type.getContext();
 if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
 if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
-return IntegerType::get(32, ctx);
+return IntegerType::get(ctx, 32);
 return input_type;
 }
@@ -828,7 +828,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
 }
 }
-auto element_type = IntegerType::get(64, input.getContext());
+auto element_type = IntegerType::get(input.getContext(), 64);
 return DenseIntElementsAttr::get(
 RankedTensorType::get({shape[0]}, element_type), values);
 }
@@ -837,7 +837,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
 // in TensorFlow PadV2 op.
 static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
 auto length = tf_padding.getType().getShape()[0];
-auto element_type = IntegerType::get(64, tf_padding.getContext());
+auto element_type = IntegerType::get(tf_padding.getContext(), 64);
 return DenseIntElementsAttr::get<int64_t>(
 RankedTensorType::get({length}, element_type), 0);
 }
@@ -1837,7 +1837,7 @@ class ConvertFusedBatchNormGradBase
 Type feature_type = RankedTensorType::get(
 {GetDimSize(act_type, feature_dim)}, kernel_type);
 Type result_type = TupleType::get(
-{act.getType(), feature_type, feature_type}, rewriter.getContext());
+rewriter.getContext(), {act.getType(), feature_type, feature_type});
 auto training_op = rewriter.create<BatchNormGradOp>(
 loc, result_type, act, scale, mean, var, grad, op.epsilon(),
@@ -1973,7 +1973,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
 // batch_mean, and batch_var.
 SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
 mean_var_type, mean_var_type};
-Type result_type = TupleType::get(operand_types, rewriter.getContext());
+Type result_type = TupleType::get(rewriter.getContext(), operand_types);
 auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
 op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
@@ -4618,9 +4618,9 @@ class ConvertInfeedDequeueTupleOp
 // Emit infeed op.
 // The result type of infeed is a tuple(tuple(result types), token type).
 auto data_tuple_type =
-mlir::TupleType::get(result_types, rewriter.getContext());
+mlir::TupleType::get(rewriter.getContext(), result_types);
 auto data_and_token_type = mlir::TupleType::get(
-{data_tuple_type, token.getType()}, rewriter.getContext());
+rewriter.getContext(), {data_tuple_type, token.getType()});
 auto data_and_token =
 rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,

@@ -281,7 +281,7 @@ Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
 /*type=*/builder.getI64IntegerAttr(3), builder.getContext());
 auto result_type = result.getType();
 auto recv_result_type =
-TupleType::get({result_type, token.getType()}, builder.getContext());
+TupleType::get(builder.getContext(), {result_type, token.getType()});
 auto recv =
 builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
 /*is_host_transfer=*/builder.getBoolAttr(true));
@@ -712,8 +712,8 @@ void UpdateFunctionType(OpBuilder& builder, FuncOp func, Block& func_body) {
 auto new_argument_types = llvm::to_vector<4>(func_body.getArgumentTypes());
 auto new_result_types =
 llvm::to_vector<4>(func_body.getTerminator()->getOperandTypes());
-func.setType(FunctionType::get(new_argument_types, new_result_types,
-builder.getContext()));
+func.setType(FunctionType::get(builder.getContext(), new_argument_types,
+new_result_types));
 }
 // Replaces a function terminator `return` with another `return` that has an

@@ -108,7 +108,7 @@ Status EmitMlirFuncAndCall(
 // Create the function an call the emission callback.
 mlir::Location loc = mlir::UnknownLoc::get(context);
 auto function = mlir::FuncOp::create(
-loc, func_name, mlir::FunctionType::get(operand_types, {}, context));
+loc, func_name, mlir::FunctionType::get(context, operand_types, {}));
 function.addEntryBlock();
 mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc);
 mlir_module->push_back(function);

@@ -100,14 +100,14 @@ class ProcessType(ast.NodeVisitor):
 attr = getattr(value, node.attr)
 if attr == core.Tensor:
-return tfp.UnrankedTensorType.get(tfp.IntegerType.get(32, self.prog.ctx))
+return tfp.UnrankedTensorType.get(tfp.IntegerType.get(self.prog.ctx, 32))
 return attr
 def visit_Name(self, node):
 if node.id == 'int':
-return tfp.IntegerType.get(32, self.prog.ctx)
+return tfp.IntegerType.get(self.prog.ctx, 32)
 if node.id == 'bool':
-return tfp.IntegerType.get(1, self.prog.ctx)
+return tfp.IntegerType.get(self.prog.ctx, 1)
 if node.id in self.ctx.info.namespace:
 return self.ctx.info.namespace[node.id]
@@ -203,7 +203,7 @@ class MLIRGen(ast.NodeVisitor):
 value = tfp.Tf_ConstOp.create(
 opb, opb.getUnknownLoc(),
 tfp.IntegerAttr.get(
-tfp.IntegerType.get(32, self.prog.ctx), node.value)).getResult(0)
+tfp.IntegerType.get(self.prog.ctx, 32), node.value)).getResult(0)
 return value
 def visit_FunctionDef(self, node):

@@ -85,7 +85,7 @@ class OrOp(object):
 def create(cls, opb, loc, values):
 state = mlir.OperationState(loc, "tfp.Or")
 state.addTypes(
-[UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))])
+[UnrankedTensorType.get(IntegerType.get(opb.getContext(), 1))])
 state.addOperands(values)
 return opb.createOperation(state)
@@ -103,7 +103,7 @@ class AndOp(object):
 def create(cls, opb, loc, values):
 state = mlir.OperationState(loc, "tfp.And")
 state.addTypes(
-[UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))])
+[UnrankedTensorType.get(IntegerType.get(opb.getContext(), 1))])
 state.addOperands(values)
 return opb.createOperation(state)

@@ -685,8 +685,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
 )
 # Check out LLVM and MLIR from llvm-project.
-LLVM_COMMIT = "511cfe9441955f20a8b93573fb9b62433b053550"
-LLVM_SHA256 = "57626cf2eb850c717b712e43774cad308f19cd9161db9ed286844ba8f42abd70"
+LLVM_COMMIT = "1b97cdf885d6455841280b8da858835e641ee941"
+LLVM_SHA256 = "80d5036ba734fcb700a5699e2f99e5a0de5808dde01a1df3c4fae04510bc8e23"
 LLVM_URLS = [
 "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
 "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),