Reapply "Allow DT_INT64 input shapes for ReshapeOp." with fix.

BEGIN_PUBLIC
Reapply "Allow DT_INT64 input shapes for ReshapeOp." with fix.
END_PUBLIC

*** Reason for rollback ***

Reapply with fix.

*** Original change description ***

Automated rollback of commit 9f59beb676

PiperOrigin-RevId: 209629073
This commit is contained in:
Sanjoy Das 2018-08-21 11:25:10 -07:00 committed by TensorFlower Gardener
parent 138bc155d2
commit d81e875dd6
3 changed files with 66 additions and 3 deletions

View File

@ -387,6 +387,19 @@ tf_xla_py_test(
], ],
) )
tf_xla_py_test(
name = "reshape_op_test",
size = "small",
srcs = ["reshape_op_test.py"],
deps = [
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"@absl_py//absl/testing:parameterized",
],
)
tf_xla_py_test( tf_xla_py_test(
name = "dynamic_stitch_test", name = "dynamic_stitch_test",
size = "small", size = "small",

View File

@ -0,0 +1,50 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slicing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase):
@parameterized.named_parameters(('32_bit_index', dtypes.int32),
('64_bit_index', dtypes.int64))
def testBasic(self, index_dtype):
for dtype in self.numeric_types:
with self.test_session():
i = array_ops.placeholder(dtype, shape=[2, 3])
with self.test_scope():
shape = constant_op.constant([3, 2], dtype=index_dtype)
o = array_ops.reshape(i, shape)
params = {
i: [[1, 2, 3], [4, 5, 6]],
}
result = o.eval(feed_dict=params)
self.assertAllEqual([[1, 2], [3, 4], [5, 6]], result)
if __name__ == '__main__':
googletest.main()

View File

@ -41,8 +41,8 @@ class ReshapeOp : public XlaOpKernel {
sizes_shape.DebugString())); sizes_shape.DebugString()));
const int64 num_dims = sizes_shape.num_elements(); const int64 num_dims = sizes_shape.num_elements();
xla::Literal literal; std::vector<int64> shape_input;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));
// Compute the output shape. Determine product of specified // Compute the output shape. Determine product of specified
// dimensions, and find the index of the unspecified one if there // dimensions, and find the index of the unspecified one if there
@ -51,7 +51,7 @@ class ReshapeOp : public XlaOpKernel {
int64 product = 1; int64 product = 1;
int unknown_index = -1; int unknown_index = -1;
for (int d = 0; d < num_dims; ++d) { for (int d = 0; d < num_dims; ++d) {
const int32 size = literal.Get<int>({d}); const int32 size = shape_input[d];
if (size == -1) { if (size == -1) {
OP_REQUIRES( OP_REQUIRES(
ctx, unknown_index == -1, ctx, unknown_index == -1,