Fix NCHWToNCHW_VECT_C data formation conversion in tf2xla bridge
Newly created dimension should be inserted after the old dim so that it can be moved at the end with the transform permutation. Results after this fix matches the inferred shape. PiperOrigin-RevId: 327283205 Change-Id: Iead28dbd32b9c78652cd1348ce683c38c8199d83
This commit is contained in:
parent
7f01242aa1
commit
27f7de6d83
@ -1118,10 +1118,10 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
|||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
make_op("NCHW_VECT_C"),
|
make_op("NCHW_VECT_C"),
|
||||||
np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)),
|
np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)),
|
||||||
expected=np.array([[[[[0, 1], [8, 9]], [[16, 17], [24, 25]]],
|
expected=np.array([[[[[0, 1, 2, 3], [8, 9, 10, 11]],
|
||||||
[[[2, 3], [10, 11]], [[18, 19], [26, 27]]],
|
[[16, 17, 18, 19], [24, 25, 26, 27]]],
|
||||||
[[[4, 5], [12, 13]], [[20, 21], [28, 29]]],
|
[[[4, 5, 6, 7], [12, 13, 14, 15]],
|
||||||
[[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]],
|
[[20, 21, 22, 23], [28, 29, 30, 31]]]]],
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge(
|
@test_util.disable_mlir_bridge(
|
||||||
@ -1172,11 +1172,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
|||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
make_op("NCHW_VECT_C"),
|
make_op("NCHW_VECT_C"),
|
||||||
np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)),
|
np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)),
|
||||||
expected=np.array([[[[[0, 1, 2, 3, 16, 17, 18, 19]]],
|
expected=np.array(
|
||||||
[[[4, 5, 6, 7, 20, 21, 22, 23]]],
|
[[[[[0, 1, 2, 3]]], [[[16, 17, 18, 19]]], [[[4, 5, 6, 7]]],
|
||||||
[[[8, 9, 10, 11, 24, 25, 26, 27]]],
|
[[[20, 21, 22, 23]]], [[[8, 9, 10, 11]]], [[[24, 25, 26, 27]]],
|
||||||
[[[12, 13, 14, 15, 28, 29, 30, 31]]]]],
|
[[[12, 13, 14, 15]]], [[[28, 29, 30, 31]]]]],
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
def _assertSoftplusMatchesExpected(self,
|
def _assertSoftplusMatchesExpected(self,
|
||||||
features,
|
features,
|
||||||
|
@ -62,7 +62,7 @@ xla::StatusOr<xla::XlaOp> Expand(xla::XlaOp input, int64 dim) {
|
|||||||
std::vector<int64> expanded_shape =
|
std::vector<int64> expanded_shape =
|
||||||
xla::SpanToVector(input_shape.dimensions());
|
xla::SpanToVector(input_shape.dimensions());
|
||||||
expanded_shape[dim] /= 4;
|
expanded_shape[dim] /= 4;
|
||||||
expanded_shape.insert(expanded_shape.begin() + dim, 4);
|
expanded_shape.insert(expanded_shape.begin() + dim + 1, 4);
|
||||||
|
|
||||||
// Move the newly created dimension to the end with a transpose.
|
// Move the newly created dimension to the end with a transpose.
|
||||||
std::vector<int64> permutation;
|
std::vector<int64> permutation;
|
||||||
|
Loading…
Reference in New Issue
Block a user