From a2119c81894e99160978a444f2e8d9431d0f7abb Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Tue, 21 Aug 2018 10:49:16 -0700 Subject: [PATCH] Automated rollback of commit 9f59beb67643953d87e7673fa0000cc775562693 PiperOrigin-RevId: 209621853 --- tensorflow/compiler/tests/BUILD | 13 ----- tensorflow/compiler/tests/reshape_op_test.py | 48 ------------------- .../compiler/tf2xla/kernels/reshape_op.cc | 6 +-- 3 files changed, 3 insertions(+), 64 deletions(-) delete mode 100644 tensorflow/compiler/tests/reshape_op_test.py diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 47311d26301..ae98b3f0f9d 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -387,19 +387,6 @@ tf_xla_py_test( ], ) -tf_xla_py_test( - name = "reshape_op_test", - size = "small", - srcs = ["reshape_op_test.py"], - deps = [ - "//tensorflow/compiler/tests:xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "@absl_py//absl/testing:parameterized", - ], -) - tf_xla_py_test( name = "dynamic_stitch_test", size = "small", diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py deleted file mode 100644 index 8aa312cbc1e..00000000000 --- a/tensorflow/compiler/tests/reshape_op_test.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for slicing.""" - -from __future__ import absolute_import - -from absl.testing import parameterized - -from tensorflow.compiler.tests import xla_test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import googletest - - -class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase): - - @parameterized.named_parameters(('32_bit_index', dtypes.int32), - ('64_bit_index', dtypes.int64)) - def testBasic(self, index_dtype): - for dtype in self.numeric_types: - with self.test_session(): - i = array_ops.placeholder(dtype, shape=[2, 3]) - with self.test_scope(): - shape = constant_op.constant([3, 2], dtype=index_dtype) - o = array_ops.reshape(i, shape) - params = { - i: [[1, 2, 3], [4, 5, 6]], - } - result = o.eval(feed_dict=params) - - self.assertAllEqual([[1, 2], [3, 4], [5, 6]], result) - - -if __name__ == '__main__': - googletest.main() diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 366ce42866e..121750a82a8 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -41,8 +41,8 @@ class ReshapeOp : public XlaOpKernel { sizes_shape.DebugString())); const int64 num_dims = sizes_shape.num_elements(); - std::vector shape_input; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input)); + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); // Compute the output shape. Determine product of specified // dimensions, and find the index of the unspecified one if there @@ -51,7 +51,7 @@ class ReshapeOp : public XlaOpKernel { int64 product = 1; int unknown_index = -1; for (int d = 0; d < num_dims; ++d) { - const int32 size = shape_input[d]; + const int32 size = literal.Get({d}); if (size == -1) { OP_REQUIRES( ctx, unknown_index == -1,