diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
index 60c944837f5..9761e6abb0a 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc
@@ -202,7 +202,7 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
     MLIRContext* context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
-  Type element_type = IntegerType::get(1, context);
+  Type element_type = IntegerType::get(context, 1);
   return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
                                                     attributes, element_type,
                                                     inferedReturnShapes);
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index a03bfa148d8..cec1ad7e6ee 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -621,7 +621,7 @@ OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
 static LogicalResult Verify(TupleOp op) {
   SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
                                        op.operand_type_end()};
-  auto expectedType = TupleType::get(operandTypes, op.getContext());
+  auto expectedType = TupleType::get(op.getContext(), operandTypes);
   if (op.getType() != expectedType) {
     return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
                                         op.getType(), expectedType));
@@ -1967,7 +1967,7 @@ LogicalResult ReplicaIdOp::inferReturnTypes(
     MLIRContext* context, Optional<Location>, ValueRange operands,
     DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
   inferredReturnTypes.push_back(RankedTensorType::get(
-      /*shape=*/{}, IntegerType::get(32, IntegerType::Unsigned, context)));
+      /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
   return success();
 }
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc
index 63bbd44e135..454ec180db5 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc
@@ -145,7 +145,7 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
 
     auto int_shape_type = RankedTensorType::get(
         output_type.getShape(),
-        IntegerType::get(bitwidth, rewriter.getContext()));
+        IntegerType::get(rewriter.getContext(), bitwidth));
     auto loc = op.getLoc();
     auto integer_const = rewriter.create<mlir::ConstantOp>(
         loc, DenseIntElementsAttr::get(int_shape_type, values));
diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.cc b/tensorflow/compiler/mlir/lite/quantization/device_target.cc
index 25bc1d24241..09dae973872 100644
--- a/tensorflow/compiler/mlir/lite/quantization/device_target.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/device_target.cc
@@ -37,10 +37,10 @@ constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
 
 DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
   f32_ = FloatType::getF32(ctx_);
-  i8_ = IntegerType::get(k8Bits, ctx_);
+  i8_ = IntegerType::get(ctx_, k8Bits);
   i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
   i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
-  i32_ = IntegerType::get(k32Bits, ctx_);
+  i32_ = IntegerType::get(ctx_, k32Bits);
   i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
   i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
   any_ = AnyQuantizedType();
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 7318fccc906..802c84a5043 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -212,7 +212,7 @@ LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
   }
 
   *attr_i32 = IntegerAttr::get(
-      IntegerType::get(/*width=*/32, attr.getContext()), value);
+      IntegerType::get(attr.getContext(), /*width=*/32), value);
   return success();
 }
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index e0b93163b3b..43f983482cd 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -547,8 +547,8 @@ struct ConvertTensorListResize
     Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
                                size_diff.getType(), size.getType()};
     Type branch_result_type[] = {result_type};
-    auto func_type = FunctionType::get(branch_args_type, branch_result_type,
-                                       rewriter.getContext());
+    auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
+                                       branch_result_type);
 
     // Constructs `then_branch`, which is executed when `if_cond` evaluates to
     // true.
@@ -775,8 +775,8 @@ LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
     // Change `func`'s argument type to `unranked_argument_types`. If it
     // return types contain a `DT_VARIANT`, change it to the unranked type
     // derived from the corresponding argument.
-    func.setType(FunctionType::get(updated_argument_types, updated_result_types,
-                                   op.getContext()));
+    func.setType(FunctionType::get(op.getContext(), updated_argument_types,
+                                   updated_result_types));
 
     // Change the argument type for the first block.
     llvm::for_each(func.getArguments(), [&](BlockArgument &arg) {
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index a1198be933e..a053ff0342e 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -243,7 +243,7 @@ DenseElementsAttr GetShape(Value output_val) {
   return mlir::DenseElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int>(shape.size())},
-          mlir::IntegerType::get(32, output_val.getContext())),
+          mlir::IntegerType::get(output_val.getContext(), 32)),
       llvm::makeArrayRef(shape));
 }
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
index a599df9c2f0..ce2ce2a2cce 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
@@ -50,7 +50,7 @@ void UpdateFuncType(FuncOp func) {
   if (llvm::makeArrayRef(return_types) == func_type.getResults()) return;
 
   auto updated_type =
-      FunctionType::get(func_type.getInputs(), return_types, func.getContext());
+      FunctionType::get(func.getContext(), func_type.getInputs(), return_types);
   func.setType(updated_type);
 }
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
index 91d5379fb8e..83d4ac34010 100644
--- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc
@@ -134,12 +134,12 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
                                  bool passthru_extra_args) {
     FunctionType type;
     if (passthru_extra_args) {
-      type = FunctionType::get(types, types, &getContext());
+      type = FunctionType::get(&getContext(), types, types);
     } else {
       SmallVector<Type, 4> result_types;
       auto operands = region.front().getTerminator()->getOperandTypes();
       result_types.append(operands.begin(), operands.end());
-      type = FunctionType::get(types, result_types, &getContext());
+      type = FunctionType::get(&getContext(), types, result_types);
     }
 
     auto outlined_func = builder.create<FuncOp>(while_op.getLoc(), name, type);
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
index 1a503675f45..090551fea53 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
@@ -382,8 +382,8 @@ void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
   auto input_types = fused_func_op_.getType().getInputs();
   auto output_type = mlir::RankedTensorType::get(
       output_shape, input_.getType().cast<RankedTensorType>().getElementType());
-  fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
-                                                 fused_func_op_.getContext()));
+  fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(),
+                                                 input_types, output_type));
 }
 
 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
@@ -820,8 +820,8 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
   }
 
   // Update function signatures.
-  func_op.setType(mlir::FunctionType::get(func_op.getType().getInputs(),
-                                          output_types, func_op.getContext()));
+  func_op.setType(mlir::FunctionType::get(
+      func_op.getContext(), func_op.getType().getInputs(), output_types));
 
   builder->create<mlir::ReturnOp>(func_op.getLoc(), outputs);
   return success();
diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
index 8ea311275d7..7ff50a07cd3 100644
--- a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
+++ b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
@@ -33,7 +33,7 @@ void init_types(py::module& m) {
       .def("getF64", &mlir::FloatType::getF64);
 
   py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
-      .def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
+      .def("get", py::overload_cast<mlir::MLIRContext*, unsigned>(
                       &mlir::IntegerType::get));
 
   py::class_<mlir::UnrankedTensorType, mlir::Type>(m, "UnrankedTensorType")
diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
index cc727edc1ac..e4d3dc600c7 100644
--- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
+++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
@@ -668,7 +668,7 @@ Status MlirFunctionContext::Finalize(OutputList* outputs,
 
   auto arg_types = body.getArgumentTypes();
   auto result_types = body.getTerminator()->getOperandTypes();
-  func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
+  func_.setType(FunctionType::get(func_.getContext(), arg_types, result_types));
   *f = new MlirFunction(std::move(context_), std::move(module_), func_);
   return Status::OK();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 3b428bd4fa7..9b8c1864d46 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -1310,7 +1310,7 @@ LogicalResult ConcatOffsetOp::fold(ArrayRef<Attribute> operands,
   results.reserve(shapes.size());
   SmallVector<int32_t, 4> cumulative_sum(num_dims, 0);
   RankedTensorType offset_type =
-      RankedTensorType::get({num_dims}, IntegerType::get(32, getContext()));
+      RankedTensorType::get({num_dims}, IntegerType::get(getContext(), 32));
   for (DenseIntElementsAttr shape : shapes) {
     results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum));
     cumulative_sum[concat_dim] += shape.getValue<int32_t>(concat_dim);
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index e291e373650..8b717ec5320 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -898,7 +898,7 @@ static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
     dimensions.push_back(APInt(out_width, shape[i]));
 
   auto result_type = RankedTensorType::get(
-      {rank}, IntegerType::get(out_width, input_ty.getContext()));
+      {rank}, IntegerType::get(input_ty.getContext(), out_width));
   return DenseElementsAttr::get(result_type, dimensions);
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
index fd2b18a0492..0b21b86029f 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
@@ -155,19 +155,19 @@ Type TensorFlowRefType::RemoveRef() {
   if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
   if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
   if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
-  if (isa<BoolRefType>()) return mlir::IntegerType::get(1, ctx);
-  if (isa<Int8RefType>()) return mlir::IntegerType::get(8, ctx);
-  if (isa<Int16RefType>()) return mlir::IntegerType::get(16, ctx);
-  if (isa<Int32RefType>()) return mlir::IntegerType::get(32, ctx);
-  if (isa<Int64RefType>()) return mlir::IntegerType::get(64, ctx);
+  if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
+  if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
+  if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
+  if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
+  if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
   if (isa<Uint8RefType>())
-    return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
+    return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
   if (isa<Uint16RefType>())
-    return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
+    return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
   if (isa<Uint32RefType>())
-    return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
+    return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
   if (isa<Uint64RefType>())
-    return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
+    return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
   if (isa<Complex64RefType>())
     return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
   if (isa<Complex128RefType>())
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
index 04134b37295..854058efc6e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
@@ -58,8 +58,8 @@ FuncOp BuildFunction(llvm::ArrayRef<Value> live_ins,
   operand_types.reserve(live_ins.size());
   for (Value v : live_ins) operand_types.emplace_back(v.getType());
 
-  auto func_type = FunctionType::get(operand_types, cluster_op.getResultTypes(),
-                                     builder->getContext());
+  auto func_type =
+      builder->getFunctionType(operand_types, cluster_op.getResultTypes());
 
   // TODO(lyandy): Define better name for outlined function. Potentially some
   // name can be added during cluster formation.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc
index 249cd7a5319..6a786ea3f12 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc
@@ -193,7 +193,7 @@ void CreateFunctions(ModuleOp module_op,
     std::replace(func_name.begin(), func_name.end(), '/', '_');
 
     FunctionType func_type =
-        FunctionType::get(input_types, result_types, context);
+        FunctionType::get(context, input_types, result_types);
     Location loc = metadata.ops.front()->getLoc();
     FuncOp func_op = FuncOp::create(loc, func_name, func_type);
     // Sets the device attribute for every input and every result of the
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc
index 98f7efecb99..1da5e366e53 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc
@@ -53,7 +53,7 @@ Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
   values.reserve(rank);
   for (int i = 0; i < rank; ++i) values.push_back(APInt(bitwidth, r1[i]));
   auto result_type = RankedTensorType::get(
-      {rank}, IntegerType::get(bitwidth, builder.getContext()));
+      {rank}, IntegerType::get(builder.getContext(), bitwidth));
   return builder.create<TF::ConstOp>(
       loc, DenseElementsAttr::get(result_type, values));
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc
index 93173cdaa0c..e47f2fd6037 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc
@@ -100,7 +100,7 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
     for (Value operand : island_op.GetYield().getOperands())
       func_result_types.push_back(operand.getType());
     FunctionType func_type =
-        FunctionType::get(func_operand_types, func_result_types, ctx);
+        FunctionType::get(ctx, func_operand_types, func_result_types);
 
     // Create the outlined function
     SmallString<32> name = kOutlinedFuncPrefix;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc
index ea34696a195..12bc3b3a9d9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc
@@ -213,9 +213,9 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) {
     }
 
     // Update the function type.
-    func.setType(mlir::FunctionType::get(func.getArgumentTypes(),
-                                         func.getType().getResults(),
-                                         module.getContext()));
+    func.setType(mlir::FunctionType::get(module.getContext(),
+                                         func.getArgumentTypes(),
+                                         func.getType().getResults()));
   }
   return success();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
index 9f36c2dd943..0a2391297f4 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
@@ -180,8 +180,8 @@ mlir::LogicalResult PromoteVarHandlesToArguments(
   }
 
   if (!var_handle_shared_names->empty())
-    function.setType(FunctionType::get(func_arg_types, func_type.getResults(),
-                                       function.getContext()));
+    function.setType(FunctionType::get(function.getContext(), func_arg_types,
+                                       func_type.getResults()));
 
   return success();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
index 4d93fcad1a0..0c5aa723707 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
@@ -121,7 +121,7 @@ void ExtractSingleBlockRegion(Region& region, StringRef name,
   if (extern_values_passthrough)
     for (auto input : extern_values) return_types.push_back(input.getType());
 
-  auto type = FunctionType::get(input_types, return_types, region.getContext());
+  auto type = FunctionType::get(region.getContext(), input_types, return_types);
 
   // Create new function and extract region body into the function.
   auto outlined_func = builder.create<FuncOp>(loc, name, type);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index bddec8a8e4a..9dd6671d1b6 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -785,9 +785,9 @@ void RemoveUnusedResourceArgumentsAndForwardedRetvals(
     }
   }
   func_op.eraseArguments(indices_to_erase);
-  func_op.setType(FunctionType::get(
-      new_types, llvm::to_vector<4>(return_op->getOperandTypes()),
-      func_op.getContext()));
+  func_op.setType(
+      FunctionType::get(func_op.getContext(), new_types,
+                        llvm::to_vector<4>(return_op->getOperandTypes())));
 }
 
 // Lifts reads/writes of resource arguments from func_op and changes its
@@ -841,10 +841,9 @@ LogicalResult LiftArgRetResourcesForFunction(
     assign_variable_op.erase();
   }
 
-  func_op.setType(
-      FunctionType::get(func_op.front().getArgumentTypes(),
-                        func_op.front().getTerminator()->getOperandTypes(),
-                        func_op.getContext()));
+  func_op.setType(FunctionType::get(
+      func_op.getContext(), func_op.front().getArgumentTypes(),
+      func_op.front().getTerminator()->getOperandTypes()));
 
   return success();
 }
@@ -1153,9 +1152,9 @@ LogicalResult HandlePartitionedCallOpCallee(
   auto new_return =
       builder.create<ReturnOp>(old_return->getLoc(), old_and_new_retvals);
   old_return->erase();
-  callee.setType(FunctionType::get(
-      callee.getType().getInputs(),
-      llvm::to_vector<4>(new_return.getOperandTypes()), callee.getContext()));
+  callee.setType(
+      FunctionType::get(callee.getContext(), callee.getType().getInputs(),
+                        llvm::to_vector<4>(new_return.getOperandTypes())));
   return success();
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
index f774ccea855..6a149d3a85a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
@@ -184,9 +184,9 @@ void EliminateUnusedResultsForIfCase(Operation *op, ArrayRef<FuncOp> branches) {
   // Patch up function types (with less number of return values and potentially
   // less number of arguments)
   for (FuncOp func : cloned_branches) {
-    func.setType(FunctionType::get(
-        func.front().getArgumentTypes(),
-        func.front().getTerminator()->getOperandTypes(), func.getContext()));
+    func.setType(
+        FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
+                          func.front().getTerminator()->getOperandTypes()));
   }
 
   EliminateUnusedResults(op);
@@ -232,9 +232,9 @@ void EliminateUnusedResultsForWhile(TF::WhileOp op) {
 
   // Patch up branch function types.
   for (FuncOp func : {cloned_cond, cloned_body}) {
-    func.setType(FunctionType::get(
-        func.front().getArgumentTypes(),
-        func.front().getTerminator()->getOperandTypes(), func.getContext()));
+    func.setType(
+        FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
+                          func.front().getTerminator()->getOperandTypes()));
   }
   EliminateUnusedResults(op, &can_eliminate);
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 26b1f4cccb4..f198ff32b0e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -1150,8 +1150,8 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
     }
 
     FunctionType func_type = func.getType();
-    func.setType(FunctionType::get(input_types, func_type.getResults(),
-                                   func.getContext()));
+    func.setType(FunctionType::get(func.getContext(), input_types,
+                                   func_type.getResults()));
 
     auto res =
         PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration);
@@ -1493,8 +1493,8 @@ void ShapeInference::InferShapeForFunctionReturnType(FuncOp func) {
   }
 
   DCOMMENT("Updating function type");
-  func.setType(FunctionType::get(
-      func.getArgumentTypes(), return_op.getOperandTypes(), func.getContext()));
+  func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
+                                 return_op.getOperandTypes()));
 
   if (changed) EnqueueCallers(func);
 }
@@ -1611,8 +1611,8 @@ LogicalResult InferShapeForFunction(FuncOp func,
     return failure();
 
   context.InferShapeForFunctionReturnType(func);
-  func.setType(FunctionType::get(new_arg_types, func.getType().getResults(),
-                                 func.getContext()));
+  func.setType(FunctionType::get(func.getContext(), new_arg_types,
+                                 func.getType().getResults()));
 
   return success();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
index 9e68f9a5411..33f9d344f0f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc
@@ -137,9 +137,9 @@ void ModifyFunctionSignature(
   if (handle_new_size_vars) {
     handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
   }
-  func.setType(FunctionType::get(
-      new_input_types, func.front().getTerminator()->getOperandTypes(),
-      func.getContext()));
+  func.setType(
+      FunctionType::get(func.getContext(), new_input_types,
+                        func.front().getTerminator()->getOperandTypes()));
 }
 
 // Contains cached information for decomposed callee functions for (stateful)
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc
index ee21bb84537..ab8c86d7876 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc
@@ -460,10 +460,9 @@ LogicalResult HandleTensorArrayScatterV3Op(
 void UpdateFuncType(FuncOp func) {
   llvm::SmallVector<Type, 8> arg_types;
   for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
-  func.setType(FunctionType::get(
-      arg_types,
-      llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
-      func.getContext()));
+  func.setType(
+      FunctionType::get(func.getContext(), arg_types,
+                        func.front().getTerminator()->getOperandTypes()));
 }
 
 // Finds the accessed gradient sources for each tensor array argument.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc
index 5a2af8dae7f..d2a465926d1 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc
@@ -71,9 +71,9 @@ struct TensorListOpsDecompositionPass
 void UpdateFuncType(FuncOp func) {
   llvm::SmallVector<Type, 8> arg_types;
   for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
-  func.setType(FunctionType::get(
-      arg_types, func.front().getTerminator()->getOperandTypes(),
-      func.getContext()));
+  func.setType(
+      FunctionType::get(func.getContext(), arg_types,
+                        func.front().getTerminator()->getOperandTypes()));
 }
 
 // Holds the size value of a tensor list and whether the size is statically
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc
index 812cafabcd3..cb69d777631 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc
@@ -118,8 +118,8 @@ void TPUResourceReadForWrite::runOnOperation() {
     for (Value read_operand : read_operands)
       block.addArgument(read_operand.getType());
 
-    func.setType(FunctionType::get(block.getArgumentTypes(),
-                                   func.getCallableResults(), &getContext()));
+    func.setType(FunctionType::get(&getContext(), block.getArgumentTypes(),
+                                   func.getCallableResults()));
     cluster_func.erase();
   }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc
index 80b5e2c6f54..ccc31d2b530 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc
@@ -117,7 +117,7 @@ struct TPUSpaceToDepthPass
 void UpdateFuncType(FuncOp func) {
   auto arg_types = func.front().getArgumentTypes();
   auto result_types = func.front().getTerminator()->getOperandTypes();
-  func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
+  func.setType(FunctionType::get(func.getContext(), arg_types, result_types));
 }
 
 void HandleFuncOp(Operation* op) {
@@ -196,7 +196,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) {
   MLIRContext* context = conv2d.getContext();
   SmallVector<int64_t, 4> values = {1, 1, 1, 1};
   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
-    return IntegerAttr::get(IntegerType::get(64, context), v);
+    return IntegerAttr::get(IntegerType::get(context, 64), v);
   });
   // TODO(b/157276506): change type of strides to DenseElementsAttr
   auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
@@ -351,7 +351,7 @@ void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
   MLIRContext* context = backprop.getContext();
   SmallVector<int64_t, 4> values = {1, 1, 1, 1};
   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
-    return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
+    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
   });
   auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
 
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index f9b4e0cd56b..56c2c04e8da 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -1454,9 +1454,9 @@ Status ImporterBase::Convert(
       all_equal = false;
     }
     if (!all_equal) {
-      function.setType(mlir::FunctionType::get(func_type.getInputs(),
-                                               graph.getResultTypes(),
-                                               function.getContext()));
+      function.setType(mlir::FunctionType::get(function.getContext(),
+                                               func_type.getInputs(),
+                                               graph.getResultTypes()));
     }
   }
 
@@ -2906,8 +2906,8 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
       }
       new_input_types.push_back(arg.getType());
     }
-    func.setType(mlir::FunctionType::get(
-        new_input_types, func.getType().getResults(), module.getContext()));
+    func.setType(mlir::FunctionType::get(module.getContext(), new_input_types,
+                                         func.getType().getResults()));
   }
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 61388d3f7f9..6b09455d9c0 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -569,9 +569,9 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
     for (mlir::BlockArgument& arg : main_fn.getArguments())
       updated_argument_types.push_back(arg.getType());
 
-    main_fn.setType(mlir::FunctionType::get(updated_argument_types,
-                                            main_fn.getType().getResults(),
-                                            main_fn.getContext()));
+    main_fn.setType(mlir::FunctionType::get(main_fn.getContext(),
+                                            updated_argument_types,
+                                            main_fn.getType().getResults()));
   }
 
   for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
index 0c7afb53850..578bbab64f2 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
@@ -136,30 +136,30 @@ TEST_F(ConvertTensorTest, Simple) {
       {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context)));
 
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<int8>(
-      {1, -1}, DT_INT8, mlir::IntegerType::get(8, &context)));
+      {1, -1}, DT_INT8, mlir::IntegerType::get(&context, 8)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<int16>(
-      {1, -1}, DT_INT16, mlir::IntegerType::get(16, &context)));
+      {1, -1}, DT_INT16, mlir::IntegerType::get(&context, 16)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<int32>(
-      {1, -1}, DT_INT32, mlir::IntegerType::get(32, &context)));
+      {1, -1}, DT_INT32, mlir::IntegerType::get(&context, 32)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<int64>(
-      {1, -1}, DT_INT64, mlir::IntegerType::get(64, &context)));
+      {1, -1}, DT_INT64, mlir::IntegerType::get(&context, 64)));
 
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint8>(
       {1, 2}, DT_UINT8,
       mlir::IntegerType::get(
-          8, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+          &context, 8, mlir::IntegerType::SignednessSemantics::Unsigned)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint16>(
       {1, 2}, DT_UINT16,
       mlir::IntegerType::get(
-          16, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+          &context, 16, mlir::IntegerType::SignednessSemantics::Unsigned)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint32>(
       {1, 2}, DT_UINT32,
       mlir::IntegerType::get(
-          32, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+          &context, 32, mlir::IntegerType::SignednessSemantics::Unsigned)));
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint64>(
       {1, 2}, DT_UINT64,
       mlir::IntegerType::get(
-          64, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+          &context, 64, mlir::IntegerType::SignednessSemantics::Unsigned)));
 
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<float>>(
       {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64,
diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
index 36cb8c06196..a5edc9fb87d 100644
--- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
+++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc
@@ -150,7 +150,7 @@ StatusOr<FunctionDef> TFRDecomposeContext::ExpandNode(const NodeDef& node_def,
   mlir::Location loc = mlir::UnknownLoc::get(context);
   mlir::ModuleOp module = mlir::ModuleOp::create(loc);
   mlir::FunctionType func_type =
-      mlir::FunctionType::get(input_tys, output_tys, context);
+      mlir::FunctionType::get(context, input_tys, output_tys);
   llvm::StringRef func_name_str(func_name.data(), func_name.size());
   auto func = mlir::FuncOp::create(loc, func_name_str, func_type, {});
   module.push_back(func);
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc
index 3b29d83e2ff..4b859628ced 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc
@@ -48,8 +48,8 @@ Operation* emitCallToPrint(Location loc, StringRef func_name, Value arg,
 
     auto module = caller_func->getParentOfType<ModuleOp>();
     b->setInsertionPointToStart(module.getBody());
-    auto func_type = FunctionType::get(arg.getType(), /*results=*/llvm::None,
-                                       b->getContext());
+    auto func_type = FunctionType::get(b->getContext(), arg.getType(),
+                                       /*results=*/llvm::None);
     callee_func = b->create<FuncOp>(module.getLoc(), func_name, func_type);
     callee_func.setPrivate();
   }
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 341132d96f5..501f2a4b859 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -127,7 +127,7 @@ StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
   llvm::SmallVector<Type, 4> args, rets;
   TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
   TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
-  auto func_type = mlir::FunctionType::get(args, rets, context_);
+  auto func_type = mlir::FunctionType::get(context_, args, rets);
 
   string computation_name =
       computation.parent()->entry_computation() == &computation
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 6eb844cbc89..0e754cea3e0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -144,7 +144,7 @@ static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
 static DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) {
   RankedTensorType ty =
       RankedTensorType::get(static_cast<int64_t>(attr.size()),
-                            IntegerType::get(64, attr.getContext()));
+                            IntegerType::get(attr.getContext(), 64));
   return DenseIntElementsAttr::get(ty, attr.getValue());
 }
 
@@ -184,7 +184,7 @@ Type GetSumAccumulationType(Type input_type) {
   MLIRContext *ctx = input_type.getContext();
   if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
   if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
-    return IntegerType::get(32, ctx);
+    return IntegerType::get(ctx, 32);
   return input_type;
 }
 
@@ -828,7 +828,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
     }
   }
 
-  auto element_type = IntegerType::get(64, input.getContext());
+  auto element_type = IntegerType::get(input.getContext(), 64);
   return DenseIntElementsAttr::get(
       RankedTensorType::get({shape[0]}, element_type), values);
 }
@@ -837,7 +837,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
 // in TensorFlow PadV2 op.
 static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
   auto length = tf_padding.getType().getShape()[0];
-  auto element_type = IntegerType::get(64, tf_padding.getContext());
+  auto element_type = IntegerType::get(tf_padding.getContext(), 64);
   return DenseIntElementsAttr::get<int64_t>(
       RankedTensorType::get({length}, element_type), 0);
 }
@@ -1837,7 +1837,7 @@ class ConvertFusedBatchNormGradBase
       Type feature_type = RankedTensorType::get(
           {GetDimSize(act_type, feature_dim)}, kernel_type);
       Type result_type = TupleType::get(
-          {act.getType(), feature_type, feature_type}, rewriter.getContext());
+          rewriter.getContext(), {act.getType(), feature_type, feature_type});
 
       auto training_op = rewriter.create<BatchNormGradOp>(
           loc, result_type, act, scale, mean, var, grad, op.epsilon(),
@@ -1973,7 +1973,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
       // batch_mean, and batch_var.
       SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
                                             mean_var_type, mean_var_type};
-      Type result_type = TupleType::get(operand_types, rewriter.getContext());
+      Type result_type = TupleType::get(rewriter.getContext(), operand_types);
 
       auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
           op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
@@ -4618,9 +4618,9 @@ class ConvertInfeedDequeueTupleOp
     // Emit infeed op.
     // The result type of infeed is a tuple(tuple(result types), token type).
     auto data_tuple_type =
-        mlir::TupleType::get(result_types, rewriter.getContext());
+        mlir::TupleType::get(rewriter.getContext(), result_types);
     auto data_and_token_type = mlir::TupleType::get(
-        {data_tuple_type, token.getType()}, rewriter.getContext());
+        rewriter.getContext(), {data_tuple_type, token.getType()});
 
     auto data_and_token =
         rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc
index 890378bc3e0..ef14408ef7a 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc
@@ -281,7 +281,7 @@ Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
       /*type=*/builder.getI64IntegerAttr(3), builder.getContext());
   auto result_type = result.getType();
   auto recv_result_type =
-      TupleType::get({result_type, token.getType()}, builder.getContext());
+      TupleType::get(builder.getContext(), {result_type, token.getType()});
   auto recv =
       builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
                              /*is_host_transfer=*/builder.getBoolAttr(true));
@@ -712,8 +712,8 @@ void UpdateFunctionType(OpBuilder& builder, FuncOp func, Block& func_body) {
   auto new_argument_types = llvm::to_vector<4>(func_body.getArgumentTypes());
   auto new_result_types =
       llvm::to_vector<4>(func_body.getTerminator()->getOperandTypes());
-  func.setType(FunctionType::get(new_argument_types, new_result_types,
-                                 builder.getContext()));
+  func.setType(FunctionType::get(builder.getContext(), new_argument_types,
+                                 new_result_types));
 }
 
 // Replaces a function terminator `return` with another `return` that has an
diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
index 69849386884..ee0ffd6fbd1 100644
--- a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
@@ -108,7 +108,7 @@ Status EmitMlirFuncAndCall(
   // Create the function an call the emission callback.
   mlir::Location loc = mlir::UnknownLoc::get(context);
   auto function = mlir::FuncOp::create(
-      loc, func_name, mlir::FunctionType::get(operand_types, {}, context));
+      loc, func_name, mlir::FunctionType::get(context, operand_types, {}));
   function.addEntryBlock();
   mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc);
   mlir_module->push_back(function);
diff --git a/tensorflow/python/tf_program/mlir_gen.py b/tensorflow/python/tf_program/mlir_gen.py
index 8395848a53a..3e41084c6c5 100644
--- a/tensorflow/python/tf_program/mlir_gen.py
+++ b/tensorflow/python/tf_program/mlir_gen.py
@@ -100,14 +100,14 @@ class ProcessType(ast.NodeVisitor):
     attr = getattr(value, node.attr)
 
     if attr == core.Tensor:
-      return tfp.UnrankedTensorType.get(tfp.IntegerType.get(32, self.prog.ctx))
+      return tfp.UnrankedTensorType.get(tfp.IntegerType.get(self.prog.ctx, 32))
     return attr
 
   def visit_Name(self, node):
     if node.id == 'int':
-      return tfp.IntegerType.get(32, self.prog.ctx)
+      return tfp.IntegerType.get(self.prog.ctx, 32)
     if node.id == 'bool':
-      return tfp.IntegerType.get(1, self.prog.ctx)
+      return tfp.IntegerType.get(self.prog.ctx, 1)
     if node.id in self.ctx.info.namespace:
       return self.ctx.info.namespace[node.id]
 
@@ -203,7 +203,7 @@ class MLIRGen(ast.NodeVisitor):
       value = tfp.Tf_ConstOp.create(
           opb, opb.getUnknownLoc(),
           tfp.IntegerAttr.get(
-              tfp.IntegerType.get(32, self.prog.ctx), node.value)).getResult(0)
+              tfp.IntegerType.get(self.prog.ctx, 32), node.value)).getResult(0)
     return value
 
   def visit_FunctionDef(self, node):
diff --git a/tensorflow/python/tf_program/pywrap_tfd.py b/tensorflow/python/tf_program/pywrap_tfd.py
index a7a30b71f4e..af198f64180 100644
--- a/tensorflow/python/tf_program/pywrap_tfd.py
+++ b/tensorflow/python/tf_program/pywrap_tfd.py
@@ -85,7 +85,7 @@ class OrOp(object):
   def create(cls, opb, loc, values):
     state = mlir.OperationState(loc, "tfp.Or")
     state.addTypes(
-        [UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))])
+        [UnrankedTensorType.get(IntegerType.get(opb.getContext(), 1))])
     state.addOperands(values)
     return opb.createOperation(state)
 
@@ -103,7 +103,7 @@ class AndOp(object):
   def create(cls, opb, loc, values):
     state = mlir.OperationState(loc, "tfp.And")
     state.addTypes(
-        [UnrankedTensorType.get(IntegerType.get(1, opb.getContext()))])
+        [UnrankedTensorType.get(IntegerType.get(opb.getContext(), 1))])
     state.addOperands(values)
     return opb.createOperation(state)
 
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 909a8d1f9c6..a16b8bf3181 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -685,8 +685,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     )
 
     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "511cfe9441955f20a8b93573fb9b62433b053550"
-    LLVM_SHA256 = "57626cf2eb850c717b712e43774cad308f19cd9161db9ed286844ba8f42abd70"
+    LLVM_COMMIT = "1b97cdf885d6455841280b8da858835e641ee941"
+    LLVM_SHA256 = "80d5036ba734fcb700a5699e2f99e5a0de5808dde01a1df3c4fae04510bc8e23"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),