From cb01a295da6787668ce8ccdaeac7fb439afe92e9 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 9 Aug 2019 19:09:39 -0700
Subject: [PATCH] Fix tpu_ops.all_to_all op output shape.

PiperOrigin-RevId: 262676072
---
 tensorflow/core/ops/tpu_cross_replica_ops.cc | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/tensorflow/core/ops/tpu_cross_replica_ops.cc b/tensorflow/core/ops/tpu_cross_replica_ops.cc
index c26b49eb34b..adce0b51a05 100644
--- a/tensorflow/core/ops/tpu_cross_replica_ops.cc
+++ b/tensorflow/core/ops/tpu_cross_replica_ops.cc
@@ -40,6 +40,9 @@ REGISTER_OP("AllToAll")
       }
       int concat_dimension;
       int split_dimension;
+      int split_count;
+
+      TF_RETURN_IF_ERROR(c->GetAttr("split_count", &split_count));
 
       TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));
 
@@ -58,14 +61,13 @@ REGISTER_OP("AllToAll")
       dims.resize(rank);
 
       for (int32 i = 0; i < rank; ++i) {
-        int64 in_idx = i;
+        dims[i] = c->Dim(input, i);
         if (i == concat_dimension) {
-          in_idx = split_dimension;
-        } else if (i == split_dimension) {
-          in_idx = concat_dimension;
+          dims[i] = c->MakeDim(c->Value(dims[i]) * split_count);
+        }
+        if (i == split_dimension) {
+          dims[i] = c->MakeDim(c->Value(dims[i]) / split_count);
         }
-
-        dims[i] = c->Dim(input, in_idx);
       }
 
       c->set_output(0, c->MakeShape(dims));