diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index ae98b3f0f9d..47311d26301 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -387,6 +387,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "reshape_op_test", + size = "small", + srcs = ["reshape_op_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "@absl_py//absl/testing:parameterized", + ], +) + tf_xla_py_test( name = "dynamic_stitch_test", size = "small", diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py new file mode 100644 index 00000000000..8aa312cbc1e --- /dev/null +++ b/tensorflow/compiler/tests/reshape_op_test.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slicing.""" + +from __future__ import absolute_import + +from absl.testing import parameterized + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase): + + @parameterized.named_parameters(('32_bit_index', dtypes.int32), + ('64_bit_index', dtypes.int64)) + def testBasic(self, index_dtype): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[2, 3]) + with self.test_scope(): + shape = constant_op.constant([3, 2], dtype=index_dtype) + o = array_ops.reshape(i, shape) + params = { + i: [[1, 2, 3], [4, 5, 6]], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[1, 2], [3, 4], [5, 6]], result) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 121750a82a8..366ce42866e 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -41,8 +41,8 @@ class ReshapeOp : public XlaOpKernel { sizes_shape.DebugString())); const int64 num_dims = sizes_shape.num_elements(); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); + std::vector shape_input; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input)); // Compute the output shape. Determine product of specified // dimensions, and find the index of the unspecified one if there @@ -51,7 +51,7 @@ class ReshapeOp : public XlaOpKernel { int64 product = 1; int unknown_index = -1; for (int d = 0; d < num_dims; ++d) { - const int32 size = literal.Get({d}); + const int32 size = shape_input[d]; if (size == -1) { OP_REQUIRES( ctx, unknown_index == -1,