diff --git a/tensorflow/core/ops/tpu_cross_replica_ops.cc b/tensorflow/core/ops/tpu_cross_replica_ops.cc index c26b49eb34b..adce0b51a05 100644 --- a/tensorflow/core/ops/tpu_cross_replica_ops.cc +++ b/tensorflow/core/ops/tpu_cross_replica_ops.cc @@ -40,6 +40,9 @@ REGISTER_OP("AllToAll") } int concat_dimension; int split_dimension; + int split_count; + + TF_RETURN_IF_ERROR(c->GetAttr("split_count", &split_count)); TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension)); @@ -58,14 +61,13 @@ REGISTER_OP("AllToAll") dims.resize(rank); for (int32 i = 0; i < rank; ++i) { - int64 in_idx = i; + dims[i] = c->Dim(input, i); if (i == concat_dimension) { - in_idx = split_dimension; - } else if (i == split_dimension) { - in_idx = concat_dimension; + dims[i] = c->MakeDim(c->Value(dims[i]) * split_count); + } + if (i == split_dimension) { + dims[i] = c->MakeDim(c->Value(dims[i]) / split_count); } - - dims[i] = c->Dim(input, in_idx); } c->set_output(0, c->MakeShape(dims));