Add replica id attribute to TPU Embedding ops during replicate to islands pass.
PiperOrigin-RevId: 313682388 Change-Id: I0e72b06b5db5c4f92b62562de523adaa01c2fa30
This commit is contained in:
parent
2003db55b1
commit
77cb02204f
@ -141,3 +141,38 @@ func @replicate_result(%arg0: tensor<i1>, %arg1: tensor<i1>) {
|
||||
// CHECK: %[[REPLICA_0:.*]]:2, %{{.*}} = tf_executor.island
|
||||
// CHECK: %[[REPLICA_1:.*]]:2, %{{.*}} = tf_executor.island
|
||||
// CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0, %[[REPLICA_0]]#1, %[[REPLICA_1]]#1
|
||||
|
||||
|
||||
// Tests that each op of the two TPU embedding enqueue kinds receives an
// `_xla_replica_id` attribute equal to the index of the replica it was cloned
// into (0 for the first island, 1 for the second), and that other ops in the
// replicated region do not receive the attribute.
|
||||
// CHECK-LABEL: func @replica_id_attr_added
|
||||
func @replica_id_attr_added(%arg0: tensor<!tf.string>, %arg1: tensor<!tf.string>) {
|
||||
tf_executor.graph {
|
||||
tf_executor.island {
|
||||
tf_device.replicate([%arg0, %arg1] as %arg2: tensor<!tf.string>) {n = 2 : i32} {
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg2){table_ids = [1, 2]} : (tensor<!tf.string>) -> ()
|
||||
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg2){table_ids = [1, 2]} : (tensor<!tf.string>) -> ()
|
||||
// Control case: an op that is not an embedding enqueue op; the expectations
// after this function require that it carries no replica id attribute.
"tf.A"(%arg2) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: tf_executor.island
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"
|
||||
// CHECK-SAME: _xla_replica_id = 0
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingRaggedTensorBatch"
|
||||
// CHECK-SAME: _xla_replica_id = 0
|
||||
// CHECK: "tf.A"
|
||||
// CHECK-NOT: _xla_replica_id
|
||||
// CHECK: tf_executor.island
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"
|
||||
// CHECK-SAME: _xla_replica_id = 1
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingRaggedTensorBatch"
|
||||
// CHECK-SAME: _xla_replica_id = 1
|
||||
// CHECK: "tf.A"
|
||||
// CHECK-NOT: _xla_replica_id
|
||||
// CHECK: tf_executor.fetch
|
||||
|
@ -37,18 +37,37 @@ limitations under the License.
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFDevice {
|
||||
namespace {
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
constexpr char kReplicaIdAttr[] = "_xla_replica_id";
|
||||
|
||||
// Function pass declared via PassWrapper; the work is done in `runOnFunction`,
// which is only declared here and defined elsewhere in the file.
// NOTE(review): going by the name and the helpers in this file, this
// presumably lowers `tf_device.replicate` into one island per replica --
// confirm against the `runOnFunction` definition.
struct ReplicateToIslandPass
|
||||
: public PassWrapper<ReplicateToIslandPass, FunctionPass> {
|
||||
// Entry point invoked for the function the pass runs on.
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Returns whether op requires `_xla_replica_id` attribute.
|
||||
bool RequiresReplicaIDAttribute(Operation* op) {
|
||||
return llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp>(op) ||
|
||||
llvm::isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op);
|
||||
}
|
||||
|
||||
// Adds integer attribute that represents replica id for replicated ops that
|
||||
// require replica id attribute.
|
||||
void AddReplicaIdToOpsInReplicatedRegion(OpBuilder* builder, Region* region,
|
||||
const int replica_id) {
|
||||
region->walk([&](Operation* replicated_op) {
|
||||
if (RequiresReplicaIDAttribute(replicated_op))
|
||||
replicated_op->setAttr(kReplicaIdAttr,
|
||||
builder->getI32IntegerAttr(replica_id));
|
||||
});
|
||||
}
|
||||
|
||||
// Creates islands per replica from `tf_device.replicate` region. If for a
|
||||
// `tf_device.launch` op the device is an aliased device of the
|
||||
// `tf_device.replicate`, the device will be remapped to an explicit device
|
||||
@ -90,6 +109,14 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
|
||||
// Copy over replicate region into replica island.
|
||||
replicate_op.body().cloneInto(&replica.body(), mapping);
|
||||
|
||||
// TODO(b/157624749): Replace this with better abstraction to
|
||||
// differentiate ops for different replicas.
|
||||
// Some ops, such as XlaHostCompute op or TPU Embedding ops, require
|
||||
// replica id to be added as an op attribute to be used during
|
||||
// execution. Handle such ops separately and add an integer attribute
|
||||
// that represents replica id.
|
||||
AddReplicaIdToOpsInReplicatedRegion(builder, &replica.body(), i);
|
||||
|
||||
// Map aliased devices to explicit devices based on replica.
|
||||
if (has_devices) {
|
||||
replica.walk([&](tf_device::LaunchOp launch) {
|
||||
|
Loading…
Reference in New Issue
Block a user