Copy over tf_device::LaunchOp attributes to tf_device::LaunchFuncOp in cluster outlining pass.

PiperOrigin-RevId: 266443023
This commit is contained in:
Andy Ly 2019-08-30 13:21:03 -07:00 committed by TensorFlower Gardener
parent 222cdccfa6
commit 5bd2d2e766
2 changed files with 26 additions and 4 deletions

View File

@ -110,3 +110,20 @@ module {
// CHECK: %[[TPU0_FUNC_A_OUTPUT:[0-9]*]] = "tf.A"()
// CHECK: return %[[TPU0_FUNC_A_OUTPUT]]
}
// -----
// Tests launch attributes are copied over to launch_func.
module {
// CHECK-LABEL: func @launch_attrs
func @launch_attrs() -> tensor<?xi32> {
%0 = "tf_device.launch"() ( {
%1 = "tf.A"() : () -> tensor<?xi32>
"tf_device.return"(%1) : (tensor<?xi32>) -> ()
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// CHECK: launch_attr = "launch_attr"
}

View File

@ -35,6 +35,9 @@ namespace TFDevice {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kFuncAttr[] = "func";
struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> {
void runOnModule() override;
};
@ -101,17 +104,19 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,
llvm::SetVector<Value*> live_ins;
getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
StringRef device = launch_op.getAttrOfType<StringAttr>("device").getValue();
StringRef device =
launch_op.getAttrOfType<StringAttr>(kDeviceAttr).getValue();
FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(),
launch_op, module_manager, builder);
launch_op.setAttr(builder->getIdentifier(kFuncAttr),
builder->getSymbolRefAttr(outlined_func.getName()));
builder->setInsertionPoint(launch_op);
tf_device::LaunchFuncOp launch_func_op =
builder->create<tf_device::LaunchFuncOp>(
launch_op.getLoc(), outlined_func.getType().getResults(),
builder->getStringAttr(device),
builder->getSymbolRefAttr(outlined_func.getName()),
live_ins.getArrayRef());
live_ins.getArrayRef(), launch_op.getAttrs());
launch_op.replaceAllUsesWith(launch_func_op);
launch_op.erase();