[tf2xla] Convert the result of xla::ReplicaId to S32
PiperOrigin-RevId: 327455603 Change-Id: Ic5efe1f80fc7a92debbc4f08853f824f8cdfb937
This commit is contained in:
parent
caaaa55b93
commit
c442bc92a1
@ -30,7 +30,8 @@ class XlaReplicaIdOp : public XlaOpKernel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
void XlaReplicaIdOp::Compile(XlaOpKernelContext* ctx) {
|
void XlaReplicaIdOp::Compile(XlaOpKernelContext* ctx) {
|
||||||
ctx->SetOutput(0, xla::ReplicaId(ctx->builder()));
|
ctx->SetOutput(
|
||||||
|
0, xla::ConvertElementType(xla::ReplicaId(ctx->builder()), xla::S32));
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_XLA_OP(Name("XlaReplicaId"), XlaReplicaIdOp);
|
REGISTER_XLA_OP(Name("XlaReplicaId"), XlaReplicaIdOp);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user