[tf2xla] Convert the result of xla::ReplicaId to S32
PiperOrigin-RevId: 327455603 Change-Id: Ic5efe1f80fc7a92debbc4f08853f824f8cdfb937
This commit is contained in:
parent
caaaa55b93
commit
c442bc92a1
@ -30,7 +30,8 @@ class XlaReplicaIdOp : public XlaOpKernel {
|
||||
};
|
||||
|
||||
void XlaReplicaIdOp::Compile(XlaOpKernelContext* ctx) {
|
||||
ctx->SetOutput(0, xla::ReplicaId(ctx->builder()));
|
||||
ctx->SetOutput(
|
||||
0, xla::ConvertElementType(xla::ReplicaId(ctx->builder()), xla::S32));
|
||||
}
|
||||
|
||||
REGISTER_XLA_OP(Name("XlaReplicaId"), XlaReplicaIdOp);
|
||||
|
Loading…
x
Reference in New Issue
Block a user