Copy over tf_device::LaunchOp attributes to tf_device::LaunchFuncOp in cluster outlining pass.
PiperOrigin-RevId: 266443023
This commit is contained in:
parent
222cdccfa6
commit
5bd2d2e766
@ -110,3 +110,20 @@ module {
|
|||||||
// CHECK: %[[TPU0_FUNC_A_OUTPUT:[0-9]*]] = "tf.A"()
|
// CHECK: %[[TPU0_FUNC_A_OUTPUT:[0-9]*]] = "tf.A"()
|
||||||
// CHECK: return %[[TPU0_FUNC_A_OUTPUT]]
|
// CHECK: return %[[TPU0_FUNC_A_OUTPUT]]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Tests launch attributes are copied over to launch_func.
|
||||||
|
|
||||||
|
module {
|
||||||
|
// CHECK-LABEL: func @launch_attrs
|
||||||
|
func @launch_attrs() -> tensor<?xi32> {
|
||||||
|
%0 = "tf_device.launch"() ( {
|
||||||
|
%1 = "tf.A"() : () -> tensor<?xi32>
|
||||||
|
"tf_device.return"(%1) : (tensor<?xi32>) -> ()
|
||||||
|
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<?xi32>
|
||||||
|
return %0 : tensor<?xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: launch_attr = "launch_attr"
|
||||||
|
}
|
||||||
|
@ -35,6 +35,9 @@ namespace TFDevice {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kDeviceAttr[] = "device";
|
||||||
|
constexpr char kFuncAttr[] = "func";
|
||||||
|
|
||||||
struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> {
|
struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> {
|
||||||
void runOnModule() override;
|
void runOnModule() override;
|
||||||
};
|
};
|
||||||
@ -101,17 +104,19 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,
|
|||||||
llvm::SetVector<Value*> live_ins;
|
llvm::SetVector<Value*> live_ins;
|
||||||
getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
|
getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
|
||||||
|
|
||||||
StringRef device = launch_op.getAttrOfType<StringAttr>("device").getValue();
|
StringRef device =
|
||||||
|
launch_op.getAttrOfType<StringAttr>(kDeviceAttr).getValue();
|
||||||
|
|
||||||
FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(),
|
FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(),
|
||||||
launch_op, module_manager, builder);
|
launch_op, module_manager, builder);
|
||||||
|
launch_op.setAttr(builder->getIdentifier(kFuncAttr),
|
||||||
|
builder->getSymbolRefAttr(outlined_func.getName()));
|
||||||
|
|
||||||
builder->setInsertionPoint(launch_op);
|
builder->setInsertionPoint(launch_op);
|
||||||
tf_device::LaunchFuncOp launch_func_op =
|
tf_device::LaunchFuncOp launch_func_op =
|
||||||
builder->create<tf_device::LaunchFuncOp>(
|
builder->create<tf_device::LaunchFuncOp>(
|
||||||
launch_op.getLoc(), outlined_func.getType().getResults(),
|
launch_op.getLoc(), outlined_func.getType().getResults(),
|
||||||
builder->getStringAttr(device),
|
live_ins.getArrayRef(), launch_op.getAttrs());
|
||||||
builder->getSymbolRefAttr(outlined_func.getName()),
|
|
||||||
live_ins.getArrayRef());
|
|
||||||
|
|
||||||
launch_op.replaceAllUsesWith(launch_func_op);
|
launch_op.replaceAllUsesWith(launch_func_op);
|
||||||
launch_op.erase();
|
launch_op.erase();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user