From c442bc92a1110f597b40b50129cd21b5c4c3287b Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Wed, 19 Aug 2020 10:09:48 -0700 Subject: [PATCH] [tf2xla] Convert the result of xla::ReplicaId to S32 PiperOrigin-RevId: 327455603 Change-Id: Ic5efe1f80fc7a92debbc4f08853f824f8cdfb937 --- tensorflow/compiler/tf2xla/kernels/replica_id_op.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc b/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc index 46585a26769..71920372cde 100644 --- a/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc @@ -30,7 +30,8 @@ class XlaReplicaIdOp : public XlaOpKernel { }; void XlaReplicaIdOp::Compile(XlaOpKernelContext* ctx) { - ctx->SetOutput(0, xla::ReplicaId(ctx->builder())); + ctx->SetOutput( + 0, xla::ConvertElementType(xla::ReplicaId(ctx->builder()), xla::S32)); } REGISTER_XLA_OP(Name("XlaReplicaId"), XlaReplicaIdOp);