Ignore embedding ops when checking for side effects in head tail extraction outside compilation.
As the IR is in a topological sort order, relying on a sequential set of operations for all side effects is too restrictive, resulting in certain ops not being able to be head or tail extracted. For now embedding ops can be excluded from such checks. PiperOrigin-RevId: 329847234 Change-Id: Ic41e97456780cf898f0193afc43898b431b597e2
This commit is contained in:
parent
03b2e0b612
commit
d65352ddd3
@ -467,4 +467,71 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// Test embedding ops can be head extracted and side effect analysis
|
||||
// predecessors are ignored.
|
||||
|
||||
// CHECK-LABEL: func @embedding_head_extraction
|
||||
func @embedding_head_extraction(%arg0: tensor<!tf.string>) {
|
||||
// CHECK: "tf_device.launch"()
|
||||
// CHECK-NEXT: "tf.EnqueueTPUEmbeddingRaggedTensorBatch"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.UnknownOp"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.UnknownOp"() : () -> ()
|
||||
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {_xla_outside_compilation = "cluster1", table_ids = [1, 2]} : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// Test side effecting op after embedding op can be head extracted.
|
||||
|
||||
// CHECK-LABEL: func @op_after_embedding_head_extraction
|
||||
func @op_after_embedding_head_extraction() {
|
||||
// CHECK: "tf_device.launch"()
|
||||
// CHECK-NEXT: "tf.A"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.RecvTPUEmbeddingActivations"
|
||||
// CHECK-NEXT: "tf.SendTPUEmbeddingGradients"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.RecvTPUEmbeddingActivations"() {config = "test_config_recv_embedding"} : () -> tensor<512x256xf32>
|
||||
"tf.SendTPUEmbeddingGradients"(%0) {N = 1 : i64, NN = 0 : i64, config = "test_config_send_embedding", operand_segment_sizes = dense<[1, 0]> : vector<2xi32>} : (tensor<512x256xf32>) -> ()
|
||||
"tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// Test side effecting op before embedding op can be tail extracted.
|
||||
|
||||
// CHECK-LABEL: func @op_before_embedding_tail_extraction
|
||||
func @op_before_embedding_tail_extraction() {
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.UnknownOp"
|
||||
// CHECK-NEXT: "tf.RecvTPUEmbeddingActivations"
|
||||
// CHECK-NEXT: "tf.SendTPUEmbeddingGradients"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
|
||||
// CHECK: "tf_device.launch"()
|
||||
// CHECK-NEXT: "tf.A"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.UnknownOp"() : () -> ()
|
||||
"tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
%0 = "tf.RecvTPUEmbeddingActivations"() {config = "test_config_recv_embedding"} : () -> tensor<512x256xf32>
|
||||
"tf.SendTPUEmbeddingGradients"(%0) {N = 1 : i64, NN = 0 : i64, config = "test_config_send_embedding", operand_segment_sizes = dense<[1, 0]> : vector<2xi32>} : (tensor<512x256xf32>) -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
@ -37,6 +38,7 @@ limitations under the License.
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
@ -115,6 +117,14 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
|
||||
return launch;
|
||||
}
|
||||
|
||||
// Checks if an operation is a supported TPU embedding op.
|
||||
bool IsEmbeddingOp(Operation* op) {
|
||||
return isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
|
||||
TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
|
||||
TF::RecvTPUEmbeddingActivationsOp,
|
||||
TF::SendTPUEmbeddingGradientsOp>(op);
|
||||
}
|
||||
|
||||
// Returns a set of ops that are outside compiled and can be extracted to before
|
||||
// the TPU computation. These ops are either connected to the inputs of the TPU
|
||||
// computation or other ops that can be extracted, and have no operands from
|
||||
@ -136,10 +146,19 @@ llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
|
||||
// Check if the side effecting op right before this side effecting op, if
|
||||
// it is side effecting, can be head extracted. Because of op ordering due
|
||||
// to side effects, if this is not true, this op cannot be head extracted.
|
||||
// TODO(lyandy): Remove special handling of embedding ops. Currently the IR
|
||||
// is in a topological sort order and depending on that ordering, embedding
|
||||
// ops may prevent other ops from being head extracted.
|
||||
auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
|
||||
if (!predecessors.empty() &&
|
||||
!head_outside_compiled_ops.contains(predecessors.back()))
|
||||
continue;
|
||||
if (!predecessors.empty() && !IsEmbeddingOp(&cluster_op)) {
|
||||
bool skip = false;
|
||||
for (Operation* predecessor : llvm::reverse(predecessors)) {
|
||||
if (IsEmbeddingOp(predecessor)) continue;
|
||||
skip = !head_outside_compiled_ops.contains(predecessor);
|
||||
break;
|
||||
}
|
||||
if (skip) continue;
|
||||
}
|
||||
|
||||
auto walk_result = cluster_op.walk([&](Operation* op) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
@ -225,11 +244,20 @@ void FindOutsideCompiledOpsAtTailAndClusterResults(
|
||||
// Check if the side effecting op right after this side effecting op, if
|
||||
// it is side effecting, can be tail extracted. Because of op ordering due
|
||||
// to side effects, if this is not true, this op cannot be tail extracted.
|
||||
// TODO(lyandy): Remove special handling of embedding ops. Currently the IR
|
||||
// is in a topological sort order and depending on that ordering, embedding
|
||||
// ops may prevent other ops from being tail extracted.
|
||||
auto successors = analysis.DirectControlSuccessors(
|
||||
&cluster_op, [&terminator](Operation* op) { return op != terminator; });
|
||||
if (!successors.empty() &&
|
||||
!tail_outside_compiled_ops_set.contains(successors.front()))
|
||||
continue;
|
||||
if (!successors.empty() && !IsEmbeddingOp(&cluster_op)) {
|
||||
bool skip = false;
|
||||
for (Operation* successor : successors) {
|
||||
if (IsEmbeddingOp(successor)) continue;
|
||||
skip = !tail_outside_compiled_ops_set.contains(successor);
|
||||
break;
|
||||
}
|
||||
if (skip) continue;
|
||||
}
|
||||
|
||||
llvm::SmallVector<int, 4> results_to_forward;
|
||||
bool can_be_extracted =
|
||||
|
Loading…
x
Reference in New Issue
Block a user