diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir index f8797678231..30272b443a1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir @@ -110,3 +110,20 @@ module { // CHECK: %[[TPU0_FUNC_A_OUTPUT:[0-9]*]] = "tf.A"() // CHECK: return %[[TPU0_FUNC_A_OUTPUT]] } + +// ----- + +// Tests launch attributes are copied over to launch_func. + +module { + // CHECK-LABEL: func @launch_attrs + func @launch_attrs() -> tensor<?xi32> { + %0 = "tf_device.launch"() ( { + %1 = "tf.A"() : () -> tensor<?xi32> + "tf_device.return"(%1) : (tensor<?xi32>) -> () + }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<?xi32> + return %0 : tensor<?xi32> + } + +// CHECK: launch_attr = "launch_attr" +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index 496e07e473c..535dbbeec32 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -35,6 +35,9 @@ namespace TFDevice { namespace { +constexpr char kDeviceAttr[] = "device"; +constexpr char kFuncAttr[] = "func"; + struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> { void runOnModule() override; }; @@ -101,17 +104,19 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager, llvm::SetVector<Value*> live_ins; getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins); - StringRef device = launch_op.getAttrOfType<StringAttr>("device").getValue(); + StringRef device = + launch_op.getAttrOfType<StringAttr>(kDeviceAttr).getValue(); FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(), launch_op, module_manager, builder); + launch_op.setAttr(builder->getIdentifier(kFuncAttr), + builder->getSymbolRefAttr(outlined_func.getName())); + builder->setInsertionPoint(launch_op); tf_device::LaunchFuncOp launch_func_op = builder->create<tf_device::LaunchFuncOp>( launch_op.getLoc(), outlined_func.getType().getResults(), - builder->getStringAttr(device), - builder->getSymbolRefAttr(outlined_func.getName()), - live_ins.getArrayRef()); + live_ins.getArrayRef(), launch_op.getAttrs()); launch_op.replaceAllUsesWith(launch_func_op); launch_op.erase();