Replace uses of OperationState, and of op-name string comparisons for determining whether an op is tf._TPUCompileMlir or tf.TPUCompileSucceededAssert, with the ODS-generated op classes.

More TPU-specific ops were added to ODS, so they can be used directly instead.

PiperOrigin-RevId: 299149173
Change-Id: I45c7eace58f4a36043c9b29de6adb9efb4f574d9
commit d5504fbc78
parent cbbbe1dc02
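
In short, the change swaps generic OperationState construction and string-based op-name checks for the ODS-generated op classes. A minimal sketch of the before/after pattern, condensed from the hunks below (loc, op, builder, and the operand/attribute values stand in for what is available at the real call sites):

// Before: build the op generically and detect it by comparing its name.
OperationState compile_op_state(loc, "tf._TPUCompileMlir");
// ... operands, attributes, and result types are then added by hand ...
Operation* compile_op = builder->createOperation(compile_op_state);
bool is_compile = op->getName().getStringRef() == "tf._TPUCompileMlir";

// After: the ODS-generated class provides a typed builder and isa<> support.
auto compile = builder->create<TF::_TPUCompileMlirOp>(
    loc, /*compilation_status=*/result_type, /*program=*/result_type,
    compile_op_operands, txt_module, txt_metadata);
bool is_compile_now = llvm::isa<TF::_TPUCompileMlirOp>(op);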
@@ -570,7 +570,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[ARG_2_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_2]])
 %0 = "tf_device.launch_func"(%arg0, %arg1, %arg2, %arg3) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32>
 // CHECK: "tf._TPUCompileMlir"(%[[ARG_0_SHAPE]], %[[ARG_2_SHAPE]])
-// CHECK-SAME: NumDynamicShapes = 2
 
 return %0: tensor<8xi32>
 }
@@ -594,7 +593,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
 // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -640,7 +638,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
 // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -712,7 +709,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
 // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -762,7 +758,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
 // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -808,7 +803,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
 // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -862,7 +856,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
 // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -910,7 +903,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
 // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -925,7 +917,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
 // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -968,7 +959,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
 // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -983,7 +973,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
 // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -1022,7 +1011,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
 // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
 // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
-// CHECK-SAME: NumDynamicShapes = 1
 // CHECK-SAME: metadata
 // CHECK-SAME: mlir_module
 // CHECK-SAME: func @main
@@ -244,8 +244,7 @@ void TPUDynamicLayoutPass::runOnFunction() {
     if (!compile || !compile->getResult(1).hasOneUse()) return;
     auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
     if (!compile_launch || !compile_launch.WrapsSingleOp() ||
-        compile_launch.GetBody().front().getName().getStringRef() !=
-            "tf._TPUCompileMlir")
+        !llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
       return;
     executes_and_compiles.emplace_back(execute, compile_launch);
   });
@@ -361,9 +361,6 @@ Operation* BuildCompileOp(
     int num_cores_per_replica, llvm::StringRef compilation_device,
     llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
     OpBuilder* builder) {
-  // TODO(b/139377366): Use tf_tpu.compile build method when it is defined.
-  OperationState compile_op_state(launch_func.getLoc(), "tf._TPUCompileMlir");
-
   // Set metadata from attributes.
   tensorflow::tpu::TPUCompileMetadataProto metadata;
   if (failed(SetMetadataProtoFromLaunchFuncOp(
@@ -377,9 +374,6 @@ Operation* BuildCompileOp(
   else
     metadata.SerializeToString(&txt_metadata);
 
-  compile_op_state.addAttribute("metadata",
-                                builder->getStringAttr(txt_metadata));
-
   // Build a shape op for each input to launch_func.
   // TODO(b/139377366): When shape inference is ready, we can use compile time
   // shape inference to get inputs that have static shapes and only use shape
@@ -399,36 +393,22 @@ Operation* BuildCompileOp(
                                          operand_and_idx.value());
     compile_op_operands.emplace_back(shape_op.getResult());
   }
-  compile_op_state.addOperands(compile_op_operands);
-  compile_op_state.addAttribute(
-      "NumDynamicShapes",
-      builder->getI64IntegerAttr(compile_op_operands.size()));
-
-  FlatSymbolRefAttr func_attr =
-      launch_func.getAttrOfType<FlatSymbolRefAttr>("func");
-  if (!func_attr) {
-    launch_func.emitOpError("does not have `func` attribute");
-    return nullptr;
-  }
+  FlatSymbolRefAttr func_attr = launch_func.funcAttr();
   FuncOp func = launch_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
       func_attr.getValue());
 
   std::string txt_module;
   if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;
-  compile_op_state.addAttribute("mlir_module",
-                                builder->getStringAttr(txt_module));
 
   // Result #0 is a string indicating whether compilation is successful or not.
-  compile_op_state.addTypes(
-      RankedTensorType::get({}, builder->getType<TF::StringType>()));
+  auto result_type =
+      RankedTensorType::get({}, builder->getType<TF::StringType>());
 
   // Result #1 is key to look up executable binary in compilation cache.
-  compile_op_state.addTypes(
-      RankedTensorType::get({}, builder->getType<TF::StringType>()));
+  auto compile_op = builder->create<TF::_TPUCompileMlirOp>(
+      launch_func.getLoc(), /*compilation_status=*/result_type,
+      /*program=*/result_type, compile_op_operands, txt_module, txt_metadata);
 
-  Operation* compile_op = builder->createOperation(compile_op_state);
-
-  return WrapOpInLaunch(builder, compile_op->getLoc(), compile_op,
+  return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op,
                         compilation_device);
 }
@@ -548,10 +528,8 @@ tf_device::LaunchOp AssignDevicesToReplicatedExecute(
 void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
                                       llvm::StringRef compilation_device,
                                       OpBuilder* builder) {
-  OperationState assert_op_state(compile_op->getLoc(),
-                                 "tf.TPUCompileSucceededAssert");
-  assert_op_state.addOperands(compile_op->getResult(0));
-  Operation* assert_op = builder->createOperation(assert_op_state);
+  auto assert_op = builder->create<TF::TPUCompileSucceededAssertOp>(
+      compile_op->getLoc(), compile_op->getResult(0));
   WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
 }
@@ -442,8 +442,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
   if (!compile) return;
   auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
   if (!compile_launch || !compile_launch.WrapsSingleOp() ||
-      compile_launch.GetBody().front().getName().getStringRef() !=
-          "tf._TPUCompileMlir")
+      !llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
     return;
 
   auto module = while_op.getParentOfType<ModuleOp>();