diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir
index f8797678231..30272b443a1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir
@@ -110,3 +110,20 @@ module {
 // CHECK: %[[TPU0_FUNC_A_OUTPUT:[0-9]*]] = "tf.A"()
 // CHECK: return %[[TPU0_FUNC_A_OUTPUT]]
 }
+
+// -----
+
+// Tests launch attributes are copied over to launch_func.
+
+module {
+  // CHECK-LABEL: func @launch_attrs
+  func @launch_attrs() -> tensor<?xi32> {
+    %0 = "tf_device.launch"() ( {
+      %1 = "tf.A"() : () -> tensor<?xi32>
+      "tf_device.return"(%1) : (tensor<?xi32>) -> ()
+    }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+
+// CHECK: launch_attr = "launch_attr"
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
index 496e07e473c..535dbbeec32 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
@@ -35,6 +35,9 @@ namespace TFDevice {
 
 namespace {
 
+constexpr char kDeviceAttr[] = "device";
+constexpr char kFuncAttr[] = "func";
+
 struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> {
   void runOnModule() override;
 };
@@ -101,17 +104,19 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,
   llvm::SetVector<Value*> live_ins;
   getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
 
-  StringRef device = launch_op.getAttrOfType<StringAttr>("device").getValue();
+  StringRef device =
+      launch_op.getAttrOfType<StringAttr>(kDeviceAttr).getValue();
 
   FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(),
                                        launch_op, module_manager, builder);
+  launch_op.setAttr(builder->getIdentifier(kFuncAttr),
+                    builder->getSymbolRefAttr(outlined_func.getName()));
+
   builder->setInsertionPoint(launch_op);
   tf_device::LaunchFuncOp launch_func_op =
       builder->create<tf_device::LaunchFuncOp>(
           launch_op.getLoc(), outlined_func.getType().getResults(),
-          builder->getStringAttr(device),
-          builder->getSymbolRefAttr(outlined_func.getName()),
-          live_ins.getArrayRef());
+          live_ins.getArrayRef(), launch_op.getAttrs());
 
   launch_op.replaceAllUsesWith(launch_func_op);
   launch_op.erase();