parent
938a3b7779
commit
a2119c8189
@ -387,19 +387,6 @@ tf_xla_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_xla_py_test(
|
|
||||||
name = "reshape_op_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["reshape_op_test.py"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/compiler/tests:xla_test",
|
|
||||||
"//tensorflow/compiler/tf2xla/python:xla",
|
|
||||||
"//tensorflow/python:array_ops",
|
|
||||||
"//tensorflow/python:dtypes",
|
|
||||||
"@absl_py//absl/testing:parameterized",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_xla_py_test(
|
tf_xla_py_test(
|
||||||
name = "dynamic_stitch_test",
|
name = "dynamic_stitch_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -1,48 +0,0 @@
|
|||||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Tests for slicing."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
from absl.testing import parameterized
|
|
||||||
|
|
||||||
from tensorflow.compiler.tests import xla_test
|
|
||||||
from tensorflow.python.framework import constant_op
|
|
||||||
from tensorflow.python.framework import dtypes
|
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.platform import googletest
|
|
||||||
|
|
||||||
|
|
||||||
class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase):
|
|
||||||
|
|
||||||
@parameterized.named_parameters(('32_bit_index', dtypes.int32),
|
|
||||||
('64_bit_index', dtypes.int64))
|
|
||||||
def testBasic(self, index_dtype):
|
|
||||||
for dtype in self.numeric_types:
|
|
||||||
with self.test_session():
|
|
||||||
i = array_ops.placeholder(dtype, shape=[2, 3])
|
|
||||||
with self.test_scope():
|
|
||||||
shape = constant_op.constant([3, 2], dtype=index_dtype)
|
|
||||||
o = array_ops.reshape(i, shape)
|
|
||||||
params = {
|
|
||||||
i: [[1, 2, 3], [4, 5, 6]],
|
|
||||||
}
|
|
||||||
result = o.eval(feed_dict=params)
|
|
||||||
|
|
||||||
self.assertAllEqual([[1, 2], [3, 4], [5, 6]], result)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
googletest.main()
|
|
@ -41,8 +41,8 @@ class ReshapeOp : public XlaOpKernel {
|
|||||||
sizes_shape.DebugString()));
|
sizes_shape.DebugString()));
|
||||||
const int64 num_dims = sizes_shape.num_elements();
|
const int64 num_dims = sizes_shape.num_elements();
|
||||||
|
|
||||||
std::vector<int64> shape_input;
|
xla::Literal literal;
|
||||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));
|
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
|
||||||
|
|
||||||
// Compute the output shape. Determine product of specified
|
// Compute the output shape. Determine product of specified
|
||||||
// dimensions, and find the index of the unspecified one if there
|
// dimensions, and find the index of the unspecified one if there
|
||||||
@ -51,7 +51,7 @@ class ReshapeOp : public XlaOpKernel {
|
|||||||
int64 product = 1;
|
int64 product = 1;
|
||||||
int unknown_index = -1;
|
int unknown_index = -1;
|
||||||
for (int d = 0; d < num_dims; ++d) {
|
for (int d = 0; d < num_dims; ++d) {
|
||||||
const int32 size = shape_input[d];
|
const int32 size = literal.Get<int>({d});
|
||||||
if (size == -1) {
|
if (size == -1) {
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, unknown_index == -1,
|
ctx, unknown_index == -1,
|
||||||
|
Loading…
Reference in New Issue
Block a user