Fix NCHWToNCHW_VECT_C data formation conversion in tf2xla bridge

Newly created dimension should be inserted after the old dim so that it can be moved at the end with the transform permutation. Results after this fix matches the inferred shape.

PiperOrigin-RevId: 327283205
Change-Id: Iead28dbd32b9c78652cd1348ce683c38c8199d83
This commit is contained in:
Smit Hinsu 2020-08-18 12:30:14 -07:00 committed by TensorFlower Gardener
parent 7f01242aa1
commit 27f7de6d83
2 changed files with 10 additions and 10 deletions

View File

@ -1118,10 +1118,10 @@ class UnaryOpsTest(xla_test.XLATestCase):
self._assertOpOutputMatchesExpected(
make_op("NCHW_VECT_C"),
np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)),
expected=np.array([[[[[0, 1], [8, 9]], [[16, 17], [24, 25]]],
[[[2, 3], [10, 11]], [[18, 19], [26, 27]]],
[[[4, 5], [12, 13]], [[20, 21], [28, 29]]],
[[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]],
expected=np.array([[[[[0, 1, 2, 3], [8, 9, 10, 11]],
[[16, 17, 18, 19], [24, 25, 26, 27]]],
[[[4, 5, 6, 7], [12, 13, 14, 15]],
[[20, 21, 22, 23], [28, 29, 30, 31]]]]],
dtype=dtype))
@test_util.disable_mlir_bridge(
@ -1172,11 +1172,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
self._assertOpOutputMatchesExpected(
make_op("NCHW_VECT_C"),
np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)),
expected=np.array([[[[[0, 1, 2, 3, 16, 17, 18, 19]]],
[[[4, 5, 6, 7, 20, 21, 22, 23]]],
[[[8, 9, 10, 11, 24, 25, 26, 27]]],
[[[12, 13, 14, 15, 28, 29, 30, 31]]]]],
dtype=dtype))
expected=np.array(
[[[[[0, 1, 2, 3]]], [[[16, 17, 18, 19]]], [[[4, 5, 6, 7]]],
[[[20, 21, 22, 23]]], [[[8, 9, 10, 11]]], [[[24, 25, 26, 27]]],
[[[12, 13, 14, 15]]], [[[28, 29, 30, 31]]]]],
dtype=dtype))
def _assertSoftplusMatchesExpected(self,
features,

View File

@ -62,7 +62,7 @@ xla::StatusOr<xla::XlaOp> Expand(xla::XlaOp input, int64 dim) {
std::vector<int64> expanded_shape =
xla::SpanToVector(input_shape.dimensions());
expanded_shape[dim] /= 4;
expanded_shape.insert(expanded_shape.begin() + dim, 4);
expanded_shape.insert(expanded_shape.begin() + dim + 1, 4);
// Move the newly created dimension to the end with a transpose.
std::vector<int64> permutation;