[tf2xla] Convert the result of xla::ReplicaId to S32

PiperOrigin-RevId: 327455603
Change-Id: Ic5efe1f80fc7a92debbc4f08853f824f8cdfb937
This commit is contained in:
David Majnemer 2020-08-19 10:09:48 -07:00 committed by TensorFlower Gardener
parent caaaa55b93
commit c442bc92a1

View File

@ -30,7 +30,8 @@ class XlaReplicaIdOp : public XlaOpKernel {
};
void XlaReplicaIdOp::Compile(XlaOpKernelContext* ctx) {
ctx->SetOutput(0, xla::ReplicaId(ctx->builder()));
ctx->SetOutput(
0, xla::ConvertElementType(xla::ReplicaId(ctx->builder()), xla::S32));
}
REGISTER_XLA_OP(Name("XlaReplicaId"), XlaReplicaIdOp);