Handle ReplicateOp op with packed inputs in following passes

* TPUSpaceToDepthPass
* TPUDynamicLayoutPass

Packed inputs were added to ReplicateOp after this pass was written so it only handled replicated inputs.

Also, added helper method GetOperandsForBlockArgument to ReplicateOp and using it for other helper methods.

Updated TPUSpaceToDepathPass existing test by adding devices attribute to trigger this codepath in the test.

For TPUSpaceToDepathPass test, I had to copy the replicate the existing complex test case instead of having a minimal test. I tried some things but the pass seems to be having many assumptions and will require closer look to trim the simplify it.

PiperOrigin-RevId: 353307553
Change-Id: I8c1ed8ca396da9bb28ffa4e479fb81813b01b17d
This commit is contained in:
Smit Hinsu 2021-01-22 13:28:26 -08:00 committed by TensorFlower Gardener
parent 98a4bb7a39
commit 445cdeb9f0
6 changed files with 210 additions and 37 deletions

View File

@ -624,15 +624,10 @@ bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) {
// block argument (of the replicate op) and a valid replica is provided.
unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument(
BlockArgument block_arg, unsigned replica) {
const int32_t num_replicas = nAttr().getInt();
assert(replica < num_replicas && block_arg.getOwner() == &GetBody());
MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
if (operands.size() == 1) return operands.front().getOperandNumber();
const unsigned num_replicated_args = GetNumReplicatedBlockArguments();
if (block_arg.getArgNumber() < num_replicated_args)
return block_arg.getArgNumber() * num_replicas + replica;
return block_arg.getArgNumber() - num_replicated_args +
replicated_inputs().size();
return operands[replica].getOperandNumber();
}
// Returns the operand being forwarded as a replicated/packed block argument for
@ -640,9 +635,34 @@ unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument(
// and a valid replica is provided.
Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg,
unsigned replica) {
const unsigned operand_index =
GetReplicaOperandIndexForBlockArgument(block_arg, replica);
return getOperand(operand_index);
MutableArrayRef<OpOperand> operands = GetOperandsForBlockArgument(block_arg);
if (operands.size() == 1) return operands.front().get();
return operands[replica].get();
}
// Returns the list of replica op operands that maps to the given block
// argument. Returns list with num_replicas elements for replicated operands
// and list with a single element for packed operands.
//
// Requires that block argument is of this replicate op.
MutableArrayRef<OpOperand> ReplicateOp::GetOperandsForBlockArgument(
BlockArgument block_arg) {
assert(block_arg.getOwner() == &GetBody());
unsigned arg_number = block_arg.getArgNumber();
unsigned num_replicated_args = GetNumReplicatedBlockArguments();
int32_t num_replicas = nAttr().getInt();
MutableArrayRef<OpOperand> operands = getOperation()->getOpOperands();
// All replicated arguments are before packed arguments so return replicated
// operands if the given argument is one of the replicated arguments.
if (arg_number < num_replicated_args)
return operands.slice(arg_number * num_replicas, num_replicas);
operands = operands.drop_front(num_replicated_args * num_replicas);
arg_number -= num_replicated_args;
return operands.slice(arg_number, 1);
}
// Checks if a tf_device.replicate wraps a single operation and the single

View File

@ -289,6 +289,7 @@ For example:
bool IsPackedBlockArgument(BlockArgument block_arg);
unsigned GetReplicaOperandIndexForBlockArgument(BlockArgument block_arg, unsigned replica);
Value GetReplicaOperandForBlockArgument(BlockArgument block_arg, unsigned replica);
MutableArrayRef<OpOperand> GetOperandsForBlockArgument(BlockArgument block_arg);
bool WrapsSingleOp();
}];

View File

@ -277,6 +277,88 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) ->
// -----
// Tests that the pass can transform replicated execution with packed inputs.
// CHECK: func @replicated_packed(%[[ARG0:.*]]: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32>
func @replicated_packed(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
// CHECK: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
%compile:2 = "tf_device.launch"() ( {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameter and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
// CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
// CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
// CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
// CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"}
// CHECK: tf_device.replicate(%[[COPY0]] as %[[R0:.*]]: tensor<3x3x1x32xf32>, %[[COPY1]] as %[[R1:.*]]: tensor<3x3x1x32xf32>)
%5:2 = tf_device.replicate(%2#0 as %r0: tensor<3x3x1x32xf32>, %2#1 as %r1: tensor<3x3x1x32xf32>)
{n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}} {
// CHECK: "tf.TPUExecute"(%[[R0]], %[[R1]], %[[COMPILE]]#1)
%execute = "tf_device.launch"() ( {
%4 = "tf.TPUExecute"(%r0, %r1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
tf_device.return %4 : tensor<i32>
}) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i32>
tf_device.return %execute : tensor<i32>
}
return %5#0 : tensor<i32>
}
// -----
// Tests that the pass can transform replicated execution with both replicated
// and packed operands.
// CHECK: func @replicated(%[[ARG0:.*]]: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32>
func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}, %arg1: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
// CHECK: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
%compile:2 = "tf_device.launch"() ( {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameter and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
// CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
// CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
// CHECK: %[[ITER1:.*]] = "tf.IteratorGetNext"
%3 = "tf.IteratorGetNext"(%arg1) {device = "/device:CPU:0"}
: (tensor<*x!tf.resource>) -> tensor<3x3x1x32xf32>
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
// CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
// CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"}
// CHECK-DAG: %[[COPY2:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]], %[[LAYOUT0]]) {device = "/device:TPU:1"}
// CHECK: tf_device.replicate([%[[COPY0]], %[[COPY2]]] as %[[R0:.*]]: tensor<3x3x1x32xf32>, %[[COPY1]] as %[[R1:.*]]: tensor<3x3x1x32xf32>)
%5:2 = tf_device.replicate([%2#0, %3] as %r0: tensor<3x3x1x32xf32>, %2#1 as %r1: tensor<3x3x1x32xf32>)
{n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}} {
// CHECK: "tf.TPUExecute"(%[[R0]], %[[R1]], %[[COMPILE]]#1)
%execute = "tf_device.launch"() ( {
%4 = "tf.TPUExecute"(%r0, %r1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
tf_device.return %4 : tensor<i32>
}) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i32>
tf_device.return %execute : tensor<i32>
}
return %5#0 : tensor<i32>
}
// -----
// Tests that the pass does not change inputs inside replicate.
// CHECK-LABEL: func @inside_replicated

View File

@ -95,7 +95,81 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSI
// CHECK: %[[INPUT01:.*]] = "tf.IteratorGetNext"
// CHECK-DAG: %[[SPACETODEPTH01:.*]] = "tf.SpaceToDepth"([[INPUT01:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
%1:2 = "tf.IteratorGetNext"(%arg4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
tf_device.replicate([%0#0, %1#0] as %arg13: tensor<2x224x224x3xf32>, [%0#1, %1#1] as %arg14: tensor<2x1xf32>, %arg6 as %arg15: tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>, %arg8 as %arg16: tensor<*x!tf.resource<tensor<1001xf32>>>, %arg7 as %arg17: tensor<*x!tf.resource<tensor<64x1001xf32>>>, %arg9 as %arg18: tensor<*x!tf.resource<tensor<f32>>>, %arg10 as %arg19: tensor<*x!tf.resource<tensor<f32>>>, %arg11 as %arg20: tensor<*x!tf.resource<tensor<f32>>>, %arg12 as %arg21: tensor<*x!tf.resource<tensor<f32>>>) {_mirrored_variable_indices = [2, 3, 4, 5, 6, 7, 8], _replicated_input_indices = [1, 2, -1, -1, -1, -1, -1, -1, -1], devices = {}, n = 2 : i32} {
tf_device.replicate([%0#0, %1#0] as %arg13: tensor<2x224x224x3xf32>, [%0#1, %1#1] as %arg14: tensor<2x1xf32>, %arg6 as %arg15: tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>, %arg8 as %arg16: tensor<*x!tf.resource<tensor<1001xf32>>>, %arg7 as %arg17: tensor<*x!tf.resource<tensor<64x1001xf32>>>, %arg9 as %arg18: tensor<*x!tf.resource<tensor<f32>>>, %arg10 as %arg19: tensor<*x!tf.resource<tensor<f32>>>, %arg11 as %arg20: tensor<*x!tf.resource<tensor<f32>>>, %arg12 as %arg21: tensor<*x!tf.resource<tensor<f32>>>) {_mirrored_variable_indices = [2, 3, 4, 5, 6, 7, 8], _replicated_input_indices = [1, 2, -1, -1, -1, -1, -1, -1, -1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
%2 = "tf.ReadVariableOp"(%arg15) : (tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>) -> tensor<7x7x3x64xf32>
%3 = "tf.ReadVariableOp"(%arg16) : (tensor<*x!tf.resource<tensor<1001xf32>>>) -> tensor<1001xf32>
%4 = "tf.ReadVariableOp"(%arg17) : (tensor<*x!tf.resource<tensor<64x1001xf32>>>) -> tensor<64x1001xf32>
%5 = "tf.ReadVariableOp"(%arg18) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%6 = "tf.ReadVariableOp"(%arg19) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%7 = "tf.ReadVariableOp"(%arg20) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%8 = "tf.ReadVariableOp"(%arg21) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%9:4 = "tf_device.cluster_func"(%arg13, %arg14, %2, %4, %3, %5, %6, %7, %8) {_tpu_replicate = "cluster_eval_step", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], func = @_func, host_compute_core = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], num_cores_per_replica = 1 : i64, output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : (tensor<2x224x224x3xf32>, tensor<2x1xf32>, tensor<7x7x3x64xf32>, tensor<64x1001xf32>, tensor<1001xf32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
"tf.AssignVariableOp"(%arg18, %9#0) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%arg19, %9#1) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%arg20, %9#2) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%arg21, %9#3) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
tf_device.return
}
return
}
// CHECK-LABEL: func private @_func
// CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
func private @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%3 = "tf.Const"() {value = dense<[[0, 1]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%4 = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
%5 = "tf.Const"() {value = dense<2.500000e-01> : tensor<f32>} : () -> tensor<f32>
%6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%7 = "tf.Const"() {value = dense<[-1, 1001]> : tensor<2xi32>} : () -> tensor<2xi32>
%8 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%9 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%10 = "tf.Const"() {value = dense<[[0, 0], [3, 3], [3, 3], [0, 0]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
%11 = "tf.Pad"(%arg0, %10) : (tensor<2x224x224x3xf32>, tensor<4x2xi32>) -> tensor<2x230x230x3xf32>
%12 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2x1xf32>) -> tensor<2x1xi64>
%13 = "tf.Reshape"(%12, %9) : (tensor<2x1xi64>, tensor<1xi32>) -> tensor<2xi64>
%14 = "tf.Squeeze"(%arg1) {squeeze_dims = [-1]} : (tensor<2x1xf32>) -> tensor<2xf32>
// CHECK: "tf.Conv2D"
// CHECK-SAME: strides = [1, 1, 1, 1]
// CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4x4x12x64xf32>) -> tensor<2x112x112x64xf32>
%15 = "tf.Conv2D"(%11, %arg2) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> tensor<2x112x112x64xf32>
%16 = "tf.Mean"(%15, %8) {keep_dims = false} : (tensor<2x112x112x64xf32>, tensor<2xi32>) -> tensor<2x64xf32>
%17 = "tf.MatMul"(%16, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x64xf32>, tensor<64x1001xf32>) -> tensor<2x1001xf32>
%18 = "tf.BiasAdd"(%17, %arg4) {data_format = "NHWC"} : (tensor<2x1001xf32>, tensor<1001xf32>) -> tensor<2x1001xf32>
%19 = "tf.Reshape"(%18, %7) : (tensor<2x1001xf32>, tensor<2xi32>) -> tensor<2x1001xf32>
%loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%19, %13) : (tensor<2x1001xf32>, tensor<2xi64>) -> (tensor<2xf32>, tensor<2x1001xf32>)
%20 = "tf.Sum"(%loss, %6) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
%21 = "tf.Mul"(%20, %5) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%22 = "tf.Sum"(%21, %4) {keep_dims = false} : (tensor<f32>, tensor<0xi32>) -> tensor<f32>
%23 = "tf.CrossReplicaSum"(%22, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
%24 = "tf.Softmax"(%18) : (tensor<2x1001xf32>) -> tensor<2x1001xf32>
%25 = "tf.ArgMax"(%24, %2) : (tensor<2x1001xf32>, tensor<i32>) -> tensor<2xi64>
%26 = "tf.Cast"(%25) {Truncate = false} : (tensor<2xi64>) -> tensor<2xf32>
%27 = "tf.Equal"(%14, %26) {incompatible_shape_error = true} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
%28 = "tf.Cast"(%27) {Truncate = false} : (tensor<2xi1>) -> tensor<2xf32>
%29 = "tf.Sum"(%28, %6) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
%30 = "tf.CrossReplicaSum"(%29, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
%31 = "tf.AddV2"(%arg5, %23) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%32 = "tf.CrossReplicaSum"(%1, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
%33 = "tf.AddV2"(%arg6, %32) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%34 = "tf.AddV2"(%arg7, %30) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%35 = "tf.CrossReplicaSum"(%0, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
%36 = "tf.AddV2"(%arg8, %35) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %31, %33, %34, %36 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
}
}
// -----
// Tests for space to depth host and device transform with replicate packed inputs.
module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSITE:0" = {}, "/job:localhost/replica:0/task:0/device:CPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:1" = {}, "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 458 : i32}} {
func @main(%arg0: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor<!tf.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor<!tf.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg5: tensor<!tf.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf.resource<tensor<7x7x3x64xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg7: tensor<*x!tf.resource<tensor<64x1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg8: tensor<*x!tf.resource<tensor<1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg9: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg10: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg11: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg12: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}) attributes {tf.entry_function = {control_outputs = "IteratorGetNext,IteratorGetNext_1,CrossReplicaSum,AssignAddVariableOp,CrossReplicaSum_1,AssignAddVariableOp_1,CrossReplicaSum_2,AssignAddVariableOp_2,CrossReplicaSum_3,AssignAddVariableOp_3", inputs = "iterator,iterator_1,iterator_2,iterator_3,iterator_4,iterator_5,resnet50_conv1_conv2d_conv1_kernel_140365606309224_handle_inputs_0,resnet50_fc1000_matmul_fc1000_kernel_140365944145960_handle_inputs_0,resnet50_fc1000_biasadd_fc1000_bias_140365944146240_handle_inputs_0,total_140366323758976_handle_inputs_0,count_140366323759312_handle_inputs_0,total_140366323760264_handle_inputs_0,count_140366323760600_handle_inputs_0", outputs = ""}} {
// CHECK: %[[INPUT00:.*]] = "tf.IteratorGetNext"
// CHECK-DAG: %[[SPACETODEPTH00:.*]] = "tf.SpaceToDepth"([[INPUT00:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
%0:2 = "tf.IteratorGetNext"(%arg2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
tf_device.replicate(%0#0 as %arg13: tensor<2x224x224x3xf32>, %0#1 as %arg14: tensor<2x1xf32>, %arg6 as %arg15: tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>, %arg8 as %arg16: tensor<*x!tf.resource<tensor<1001xf32>>>, %arg7 as %arg17: tensor<*x!tf.resource<tensor<64x1001xf32>>>, %arg9 as %arg18: tensor<*x!tf.resource<tensor<f32>>>, %arg10 as %arg19: tensor<*x!tf.resource<tensor<f32>>>, %arg11 as %arg20: tensor<*x!tf.resource<tensor<f32>>>, %arg12 as %arg21: tensor<*x!tf.resource<tensor<f32>>>) {_mirrored_variable_indices = [2, 3, 4, 5, 6, 7, 8], _replicated_input_indices = [1, 2, -1, -1, -1, -1, -1, -1, -1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
%2 = "tf.ReadVariableOp"(%arg15) : (tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>) -> tensor<7x7x3x64xf32>
%3 = "tf.ReadVariableOp"(%arg16) : (tensor<*x!tf.resource<tensor<1001xf32>>>) -> tensor<1001xf32>
%4 = "tf.ReadVariableOp"(%arg17) : (tensor<*x!tf.resource<tensor<64x1001xf32>>>) -> tensor<64x1001xf32>

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
@ -181,16 +182,15 @@ void HandleInput(Value input, const int64_t execute_arg_index,
bool HandleReplicatedInputs(
const int64_t execute_arg_index, Value compilation_key,
tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch,
const int64_t replicate_arg_index, tf_device::ReplicateOp replicate,
mlir::BlockArgument replicate_arg, tf_device::ReplicateOp replicate,
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
// We need to know the devices to copy to.
if (!replicate.devices()) return false;
int64_t num_replicas = replicate.n();
auto inputs = replicate.getOperands()
.drop_front(replicate_arg_index * num_replicas)
.take_front(num_replicas);
MutableArrayRef<OpOperand> inputs =
replicate.GetOperandsForBlockArgument(replicate_arg);
for (auto entry : llvm::enumerate(inputs)) {
auto input_op = entry.value().getDefiningOp();
auto input_op = entry.value().get().getDefiningOp();
if (!input_op || !IsSupportedInputOp(input_op, resource_alias_analysis))
return false;
}
@ -199,8 +199,9 @@ bool HandleReplicatedInputs(
compile_launch, &builder);
builder.setInsertionPoint(replicate);
for (auto entry : llvm::enumerate(inputs)) {
auto copy_with_layout = BuildCopyWithLayout(
execute_launch, compile_launch, get_layout, entry.value(), &builder);
auto copy_with_layout =
BuildCopyWithLayout(execute_launch, compile_launch, get_layout,
entry.value().get(), &builder);
auto device_list = replicate.devices()
.getValue()
@ -209,8 +210,7 @@ bool HandleReplicatedInputs(
copy_with_layout->setAttr(kDeviceAttr,
device_list.getValue()[entry.index()]);
replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
copy_with_layout);
entry.value().set(copy_with_layout);
}
return true;
}
@ -247,9 +247,8 @@ void HandleCompileAndExecutes(
// replicated input (defining ops will be outside the replicate node).
if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
!HandleReplicatedInputs(execute_arg_index, execute.key(),
execute_launch, compile_launch,
block_arg.getArgNumber(), maybe_replicate,
resource_alias_analysis)) {
execute_launch, compile_launch, block_arg,
maybe_replicate, resource_alias_analysis)) {
continue;
}
} else {

View File

@ -453,26 +453,23 @@ bool HandleHostReplicatedInputs(int64_t index,
BlockArgument block_arg,
tf_device::ReplicateOp replicate,
int32_t block_size) {
int64_t replicate_arg_index = block_arg.getArgNumber();
// We need to know the devices to copy to.
if (!replicate.devices()) return false;
int64_t num_replicas = replicate.n();
// Gets inputs at replicate_arg_index for each replica.
auto inputs = replicate.getOperands()
.drop_front(replicate_arg_index * num_replicas)
.take_front(num_replicas);
for (auto input : inputs) {
auto input_op = input.getDefiningOp();
MutableArrayRef<OpOperand> inputs =
replicate.GetOperandsForBlockArgument(block_arg);
for (auto& input : inputs) {
auto input_op = input.get().getDefiningOp();
if (!input_op || !IsSupportedHostInputOp(input_op)) return false;
}
for (auto entry : llvm::enumerate(inputs)) {
auto ranked_type = entry.value().getType().dyn_cast<RankedTensorType>();
Value input = entry.value().get();
auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
if (!ranked_type) return false;
auto input_shape = ranked_type.getShape();
auto space_to_depth =
BuildSpaceToDepth(cluster_func, entry.value(), block_size, input_shape);
replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
space_to_depth);
BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
entry.value().set(space_to_depth);
block_arg.setType(space_to_depth.getType());
}
return true;