Replace using OperationState and checking op name for determining if an op is tf._TPUCompileMlir or tf.TPUCompileSucceededAssert.

More TPU specific ops were added to ODS, so they can be used directly instead.

PiperOrigin-RevId: 299149173
Change-Id: I45c7eace58f4a36043c9b29de6adb9efb4f574d9
This commit is contained in:
Andy Ly 2020-03-05 11:25:31 -08:00 committed by TensorFlower Gardener
parent cbbbe1dc02
commit d5504fbc78
4 changed files with 11 additions and 47 deletions

View File

@ -570,7 +570,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[ARG_2_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_2]])
%0 = "tf_device.launch_func"(%arg0, %arg1, %arg2, %arg3) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32>
// CHECK: "tf._TPUCompileMlir"(%[[ARG_0_SHAPE]], %[[ARG_2_SHAPE]])
// CHECK-SAME: NumDynamicShapes = 2
return %0: tensor<8xi32>
}
@ -594,7 +593,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
@ -640,7 +638,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
@ -712,7 +709,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
@ -762,7 +758,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
@ -808,7 +803,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
@ -862,7 +856,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
@ -910,7 +903,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
@ -925,7 +917,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
@ -968,7 +959,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
@ -983,7 +973,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
@ -1022,7 +1011,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main

View File

@ -244,8 +244,7 @@ void TPUDynamicLayoutPass::runOnFunction() {
if (!compile || !compile->getResult(1).hasOneUse()) return;
auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
if (!compile_launch || !compile_launch.WrapsSingleOp() ||
compile_launch.GetBody().front().getName().getStringRef() !=
"tf._TPUCompileMlir")
!llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
return;
executes_and_compiles.emplace_back(execute, compile_launch);
});

View File

@ -361,9 +361,6 @@ Operation* BuildCompileOp(
int num_cores_per_replica, llvm::StringRef compilation_device,
llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
OpBuilder* builder) {
// TODO(b/139377366): Use tf_tpu.compile build method when it is defined.
OperationState compile_op_state(launch_func.getLoc(), "tf._TPUCompileMlir");
// Set metadata from attributes.
tensorflow::tpu::TPUCompileMetadataProto metadata;
if (failed(SetMetadataProtoFromLaunchFuncOp(
@ -377,9 +374,6 @@ Operation* BuildCompileOp(
else
metadata.SerializeToString(&txt_metadata);
compile_op_state.addAttribute("metadata",
builder->getStringAttr(txt_metadata));
// Build a shape op for each input to launch_func.
// TODO(b/139377366): When shape inference is ready, we can use compile time
// shape inference to get inputs that have static shapes and only use shape
@ -399,36 +393,22 @@ Operation* BuildCompileOp(
operand_and_idx.value());
compile_op_operands.emplace_back(shape_op.getResult());
}
compile_op_state.addOperands(compile_op_operands);
compile_op_state.addAttribute(
"NumDynamicShapes",
builder->getI64IntegerAttr(compile_op_operands.size()));
FlatSymbolRefAttr func_attr =
launch_func.getAttrOfType<FlatSymbolRefAttr>("func");
if (!func_attr) {
launch_func.emitOpError("does not have `func` attribute");
return nullptr;
}
FlatSymbolRefAttr func_attr = launch_func.funcAttr();
FuncOp func = launch_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
func_attr.getValue());
std::string txt_module;
if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;
compile_op_state.addAttribute("mlir_module",
builder->getStringAttr(txt_module));
// Result #0 is a string indicating whether compilation is successful or not.
compile_op_state.addTypes(
RankedTensorType::get({}, builder->getType<TF::StringType>()));
auto result_type =
RankedTensorType::get({}, builder->getType<TF::StringType>());
// Result #1 is key to look up executable binary in compilation cache.
compile_op_state.addTypes(
RankedTensorType::get({}, builder->getType<TF::StringType>()));
auto compile_op = builder->create<TF::_TPUCompileMlirOp>(
launch_func.getLoc(), /*compilation_status=*/result_type,
/*program=*/result_type, compile_op_operands, txt_module, txt_metadata);
Operation* compile_op = builder->createOperation(compile_op_state);
return WrapOpInLaunch(builder, compile_op->getLoc(), compile_op,
return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op,
compilation_device);
}
@ -548,10 +528,8 @@ tf_device::LaunchOp AssignDevicesToReplicatedExecute(
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
llvm::StringRef compilation_device,
OpBuilder* builder) {
OperationState assert_op_state(compile_op->getLoc(),
"tf.TPUCompileSucceededAssert");
assert_op_state.addOperands(compile_op->getResult(0));
Operation* assert_op = builder->createOperation(assert_op_state);
auto assert_op = builder->create<TF::TPUCompileSucceededAssertOp>(
compile_op->getLoc(), compile_op->getResult(0));
WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
}

View File

@ -442,8 +442,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
if (!compile) return;
auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
if (!compile_launch || !compile_launch.WrapsSingleOp() ||
compile_launch.GetBody().front().getName().getStringRef() !=
"tf._TPUCompileMlir")
!llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
return;
auto module = while_op.getParentOfType<ModuleOp>();