Copy over tf_device::LaunchOp attributes to tf_device::LaunchFuncOp in cluster outlining pass.
PiperOrigin-RevId: 266443023
This commit is contained in:
parent
222cdccfa6
commit
5bd2d2e766
@ -110,3 +110,20 @@ module {
|
||||
// CHECK: %[[TPU0_FUNC_A_OUTPUT:[0-9]*]] = "tf.A"()
|
||||
// CHECK: return %[[TPU0_FUNC_A_OUTPUT]]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests launch attributes are copied over to launch_func.
|
||||
|
||||
module {
|
||||
// CHECK-LABEL: func @launch_attrs
|
||||
func @launch_attrs() -> tensor<?xi32> {
|
||||
%0 = "tf_device.launch"() ( {
|
||||
%1 = "tf.A"() : () -> tensor<?xi32>
|
||||
"tf_device.return"(%1) : (tensor<?xi32>) -> ()
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK: launch_attr = "launch_attr"
|
||||
}
|
||||
|
@ -35,6 +35,9 @@ namespace TFDevice {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
constexpr char kFuncAttr[] = "func";
|
||||
|
||||
struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> {
|
||||
void runOnModule() override;
|
||||
};
|
||||
@ -101,17 +104,19 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,
|
||||
llvm::SetVector<Value*> live_ins;
|
||||
getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
|
||||
|
||||
StringRef device = launch_op.getAttrOfType<StringAttr>("device").getValue();
|
||||
StringRef device =
|
||||
launch_op.getAttrOfType<StringAttr>(kDeviceAttr).getValue();
|
||||
|
||||
FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(),
|
||||
launch_op, module_manager, builder);
|
||||
launch_op.setAttr(builder->getIdentifier(kFuncAttr),
|
||||
builder->getSymbolRefAttr(outlined_func.getName()));
|
||||
|
||||
builder->setInsertionPoint(launch_op);
|
||||
tf_device::LaunchFuncOp launch_func_op =
|
||||
builder->create<tf_device::LaunchFuncOp>(
|
||||
launch_op.getLoc(), outlined_func.getType().getResults(),
|
||||
builder->getStringAttr(device),
|
||||
builder->getSymbolRefAttr(outlined_func.getName()),
|
||||
live_ins.getArrayRef());
|
||||
live_ins.getArrayRef(), launch_op.getAttrs());
|
||||
|
||||
launch_op.replaceAllUsesWith(launch_func_op);
|
||||
launch_op.erase();
|
||||
|
Loading…
Reference in New Issue
Block a user