From 27f7de6d83f6580361973528a9a2922b869972ed Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Tue, 18 Aug 2020 12:30:14 -0700 Subject: [PATCH] Fix NCHWToNCHW_VECT_C data formation conversion in tf2xla bridge Newly created dimension should be inserted after the old dim so that it can be moved at the end with the transform permutation. Results after this fix matches the inferred shape. PiperOrigin-RevId: 327283205 Change-Id: Iead28dbd32b9c78652cd1348ce683c38c8199d83 --- tensorflow/compiler/tests/unary_ops_test.py | 18 +++++++++--------- tensorflow/compiler/tf2xla/lib/data_format.cc | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index f0ac86d5444..3a678d8ea11 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -1118,10 +1118,10 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( make_op("NCHW_VECT_C"), np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)), - expected=np.array([[[[[0, 1], [8, 9]], [[16, 17], [24, 25]]], - [[[2, 3], [10, 11]], [[18, 19], [26, 27]]], - [[[4, 5], [12, 13]], [[20, 21], [28, 29]]], - [[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]], + expected=np.array([[[[[0, 1, 2, 3], [8, 9, 10, 11]], + [[16, 17, 18, 19], [24, 25, 26, 27]]], + [[[4, 5, 6, 7], [12, 13, 14, 15]], + [[20, 21, 22, 23], [28, 29, 30, 31]]]]], dtype=dtype)) @test_util.disable_mlir_bridge( @@ -1172,11 +1172,11 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( make_op("NCHW_VECT_C"), np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)), - expected=np.array([[[[[0, 1, 2, 3, 16, 17, 18, 19]]], - [[[4, 5, 6, 7, 20, 21, 22, 23]]], - [[[8, 9, 10, 11, 24, 25, 26, 27]]], - [[[12, 13, 14, 15, 28, 29, 30, 31]]]]], - dtype=dtype)) + expected=np.array( + [[[[[0, 1, 2, 3]]], [[[16, 17, 18, 19]]], [[[4, 5, 6, 7]]], + [[[20, 21, 22, 23]]], [[[8, 9, 10, 11]]], [[[24, 25, 26, 27]]], + [[[12, 13, 14, 15]]], [[[28, 29, 30, 31]]]]], + dtype=dtype)) def _assertSoftplusMatchesExpected(self, features, diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index e5913a8bbf3..eb1ab79d165 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -62,7 +62,7 @@ xla::StatusOr Expand(xla::XlaOp input, int64 dim) { std::vector expanded_shape = xla::SpanToVector(input_shape.dimensions()); expanded_shape[dim] /= 4; - expanded_shape.insert(expanded_shape.begin() + dim, 4); + expanded_shape.insert(expanded_shape.begin() + dim + 1, 4); // Move the newly created dimension to the end with a transpose. std::vector permutation;