Add replica id attribute to TPU Embedding ops during replicate to islands pass.

PiperOrigin-RevId: 313682388
Change-Id: I0e72b06b5db5c4f92b62562de523adaa01c2fa30
This commit is contained in:
A. Unique TensorFlower 2020-05-28 16:33:42 -07:00 committed by TensorFlower Gardener
parent 2003db55b1
commit 77cb02204f
2 changed files with 62 additions and 0 deletions

View File

@ -141,3 +141,38 @@ func @replicate_result(%arg0: tensor<i1>, %arg1: tensor<i1>) {
// CHECK: %[[REPLICA_0:.*]]:2, %{{.*}} = tf_executor.island
// CHECK: %[[REPLICA_1:.*]]:2, %{{.*}} = tf_executor.island
// CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0, %[[REPLICA_0]]#1, %[[REPLICA_1]]#1
// Tests that the `_xla_replica_id` attribute is added to the TPU embedding
// enqueue ops in each replica island, with a value matching the island's
// position (0 for the first island, 1 for the second), and that other ops
// (here `tf.A`) do not receive the attribute.
// CHECK-LABEL: func @replica_id_attr_added
func @replica_id_attr_added(%arg0: tensor<!tf.string>, %arg1: tensor<!tf.string>) {
tf_executor.graph {
tf_executor.island {
// Replicated over two replicas (n = 2).
tf_device.replicate([%arg0, %arg1] as %arg2: tensor<!tf.string>) {n = 2 : i32} {
// The two embedding enqueue ops below are expected to be tagged.
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg2){table_ids = [1, 2]} : (tensor<!tf.string>) -> ()
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg2){table_ids = [1, 2]} : (tensor<!tf.string>) -> ()
// An unrelated op, used to verify that the attribute is not added.
"tf.A"(%arg2) : (tensor<!tf.string>) -> ()
tf_device.return
}
tf_executor.yield
}
tf_executor.fetch
}
return
}
// CHECK: tf_executor.island
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"
// CHECK-SAME: _xla_replica_id = 0
// CHECK: "tf.EnqueueTPUEmbeddingRaggedTensorBatch"
// CHECK-SAME: _xla_replica_id = 0
// CHECK: "tf.A"
// CHECK-NOT: _xla_replica_id
// CHECK: tf_executor.island
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"
// CHECK-SAME: _xla_replica_id = 1
// CHECK: "tf.EnqueueTPUEmbeddingRaggedTensorBatch"
// CHECK-SAME: _xla_replica_id = 1
// CHECK: "tf.A"
// CHECK-NOT: _xla_replica_id
// CHECK: tf_executor.fetch

View File

@ -37,18 +37,37 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"
namespace mlir {
namespace TFDevice {
namespace {
// Attribute name "device".
// NOTE(review): its use is not visible in this excerpt — presumably the
// device attribute read/written on ops by this pass; confirm.
constexpr char kDeviceAttr[] = "device";
// Name of the integer attribute carrying the replica id; set on selected ops
// by AddReplicaIdToOpsInReplicatedRegion.
constexpr char kReplicaIdAttr[] = "_xla_replica_id";
// Function pass for the replicate-to-island transformation. Only the
// `runOnFunction` entry point is declared here; its definition lies outside
// this excerpt.
struct ReplicateToIslandPass
: public PassWrapper<ReplicateToIslandPass, FunctionPass> {
// Pass entry point, invoked on the function being processed.
void runOnFunction() override;
};
// Reports whether `op` is one of the TPU embedding enqueue ops (sparse-tensor
// or ragged-tensor batch) that must carry the `_xla_replica_id` attribute.
bool RequiresReplicaIDAttribute(Operation* op) {
  if (llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp>(op)) return true;
  return llvm::isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op);
}
// Walks every operation nested in `region` and, for each one that
// RequiresReplicaIDAttribute accepts, sets its `_xla_replica_id` attribute to
// an i32 integer attribute holding `replica_id`. Other ops are left untouched.
//
// `builder` is used only to construct the integer attribute.
void AddReplicaIdToOpsInReplicatedRegion(OpBuilder* builder, Region* region,
                                         const int replica_id) {
  region->walk([&](Operation* op) {
    // Skip ops that do not need a replica id.
    if (!RequiresReplicaIDAttribute(op)) return;
    op->setAttr(kReplicaIdAttr, builder->getI32IntegerAttr(replica_id));
  });
}
// Creates islands per replica from `tf_device.replicate` region. If for a
// `tf_device.launch` op the device is an aliased device of the
// `tf_device.replicate`, the device will be remapped to an explicit device
@ -90,6 +109,14 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
// Copy over replicate region into replica island.
replicate_op.body().cloneInto(&replica.body(), mapping);
// TODO(b/157624749): Replace this with better abstraction to
// differentiate ops for different replicas.
// Some ops, such as XlaHostCompute op or TPU Embedding ops, require
// replica id to be added as an op attribute to be used during
// execution. Handle such ops separately and add an integer attribute
// that represents replica id.
AddReplicaIdToOpsInReplicatedRegion(builder, &replica.body(), i);
// Map aliased devices to explicit devices based on replica.
if (has_devices) {
replica.walk([&](tf_device::LaunchOp launch) {