Fix tpu_ops.all_to_all op output shape.
PiperOrigin-RevId: 262676072
This commit is contained in:
parent
e6dc56ced2
commit
cb01a295da
@ -40,6 +40,9 @@ REGISTER_OP("AllToAll")
|
|||||||
}
|
}
|
||||||
int concat_dimension;
|
int concat_dimension;
|
||||||
int split_dimension;
|
int split_dimension;
|
||||||
|
int split_count;
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("split_count", &split_count));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));
|
TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));
|
||||||
|
|
||||||
@ -58,14 +61,13 @@ REGISTER_OP("AllToAll")
|
|||||||
dims.resize(rank);
|
dims.resize(rank);
|
||||||
|
|
||||||
for (int32 i = 0; i < rank; ++i) {
|
for (int32 i = 0; i < rank; ++i) {
|
||||||
int64 in_idx = i;
|
dims[i] = c->Dim(input, i);
|
||||||
if (i == concat_dimension) {
|
if (i == concat_dimension) {
|
||||||
in_idx = split_dimension;
|
dims[i] = c->MakeDim(c->Value(dims[i]) * split_count);
|
||||||
} else if (i == split_dimension) {
|
}
|
||||||
in_idx = concat_dimension;
|
if (i == split_dimension) {
|
||||||
|
dims[i] = c->MakeDim(c->Value(dims[i]) / split_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
dims[i] = c->Dim(input, in_idx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c->set_output(0, c->MakeShape(dims));
|
c->set_output(0, c->MakeShape(dims));
|
||||||
|
Loading…
Reference in New Issue
Block a user