diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index fbc5f17a915..acd278d7a51 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/kernels/transpose_op.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -29,43 +28,15 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { -typedef Eigen::ThreadPoolDevice CPUDevice; - -namespace { - -template <typename T> -struct InvertPermutations { - static void Run(OpKernelContext* context, const Tensor& input, Tensor* out, - int start, int limit) { - auto input_tensor = input.matrix<T>(); - const T N = static_cast<T>( - input_tensor.dimension(1)); // Safe: bounds already checked. - auto output_tensor = out->matrix<T>(); - for (int64 i = start; i < limit; ++i) { - for (int j = 0; j < N; ++j) { - const T d = internal::SubtleMustCopy(input_tensor(i, j)); - OP_REQUIRES(context, FastBoundsCheck(d, N), - errors::InvalidArgument(d, " is not between 0 and ", N)); - OP_REQUIRES(context, output_tensor(i, d) == -1, - errors::InvalidArgument(d, " is duplicated in the input.")); - output_tensor(i, d) = j; - } - } - } -}; - -} // namespace - // inv = InvertPermutationOp(T<int32/int64> p) takes a permutation of // integers 0, 1, ..., n - 1 and returns the inverted // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n). // +// REQUIRES: input is a vector of int32 or int64. // REQUIRES: input is a permutation of 0, 1, ..., n-1. -// template <typename T> class InvertPermutationOp : public OpKernel { @@ -75,46 +46,28 @@ class InvertPermutationOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); - OP_REQUIRES(context, input.dims() > 0, - errors::InvalidArgument("Permutation must have at least rank 1 " - "but is rank ", - input.dims())); - - const int64 perm_size = input.dim_size(input.dims() - 1); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(input.shape()), + errors::InvalidArgument("invert_permutation expects a 1D vector.")); + auto Tin = input.vec<T>(); OP_REQUIRES(context, - FastBoundsCheck(perm_size, std::numeric_limits<int32>::max()), + FastBoundsCheck(Tin.size(), std::numeric_limits<int32>::max()), errors::InvalidArgument("permutation of nonnegative int32s " "must have <= int32 max elements")); - Tensor input_reshaped; - int64 batch_size = 1; - // The last dimension is the permutation dimension. - for (int i = 0; i < input.dims() - 1; ++i) { - batch_size *= input.shape().dim_size(i); - } - TensorShape batch_vectors = TensorShape({batch_size, perm_size}); - // Note that we always have a batch size, including the scalar case. - OP_REQUIRES(context, input_reshaped.CopyFrom(input, batch_vectors), - errors::Internal("Failed to reshape In[0] from ", - input.shape().DebugString())); - + const T N = static_cast<T>(Tin.size()); // Safe: bounds-checked above. Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); - output->flat<T>() = output->flat<T>().constant(T(-1)); - Tensor output_reshaped; - OP_REQUIRES(context, output_reshaped.CopyFrom(*output, batch_vectors), - errors::Internal("Failed to reshape Output[0] from ", - output->shape().DebugString())); - - const int64 cost_per_unit = perm_size; - // Parallelize over outer dimensions - auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); - Shard(worker_threads.num_threads, worker_threads.workers, batch_size, - cost_per_unit, - [&context, &input_reshaped, &output_reshaped](int start, int limit) { - InvertPermutations<T>::Run(context, input_reshaped, - &output_reshaped, start, limit); - }); + auto Tout = output->vec<T>(); + std::fill_n(Tout.data(), N, -1); + for (int i = 0; i < N; ++i) { + const T d = internal::SubtleMustCopy(Tin(i)); + OP_REQUIRES(context, FastBoundsCheck(d, N), + errors::InvalidArgument(d, " is not between 0 and ", N)); + OP_REQUIRES(context, Tout(d) == -1, + errors::InvalidArgument(d, " is duplicated in the input.")); + Tout(d) = i; + } } }; diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 602b51a46e2..60efdcb7a73 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1391,7 +1391,7 @@ REGISTER_OP("InvertPermutation") .Attr("T: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle x; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &x)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x)); c->set_output(0, x); return Status::OK(); }); diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index c4309f60039..718a34c07e6 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -399,9 +399,9 @@ TEST(ArrayOpsTest, UniqueWithCounts_ShapeFn) { TEST(ArrayOpsTest, InvertPermutation_ShapeFn) { ShapeInferenceTestOp op("InvertPermutation"); + INFER_OK(op, "?", "[?]"); INFER_OK(op, "[1]", "in0"); - INFER_OK(op, "[1,2,3]", "in0"); - INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]"); + INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[]"); } TEST(ArrayOpsTest, PadD_ShapeFn) { diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 31994d78f50..ce96ee4ad6d 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -44,7 +44,6 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import sort_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -1352,40 +1351,14 @@ class PadTest(test_util.TensorFlowTestCase): class InvertPermutationTest(test_util.TensorFlowTestCase): + @test_util.run_deprecated_v1 def testInvertPermutation(self): for dtype in [dtypes.int32, dtypes.int64]: - x = constant_op.constant([3, 4, 0, 2, 1], dtype=dtype) - y = array_ops.invert_permutation(x) - self.assertAllEqual(y.shape, [5]) - self.assertAllEqual(self.evaluate(y), [2, 4, 3, 0, 1]) - - def testInvertPermutationCheckRank(self): - for dtype in [dtypes.int32, dtypes.int64]: - x = constant_op.constant(3, dtype=dtype) - with self.assertRaisesRegexp(Exception, "at least rank 1"): - self.evaluate(array_ops.invert_permutation(x)) - - def testInvertPermutationBatch(self): - for dtype in [dtypes.int32, dtypes.int64]: - x = constant_op.constant([[[3, 4, 0, 2, 1], [2, 3, 4, 0, 1]]], - dtype=dtype) - y = array_ops.invert_permutation(x) - self.assertAllEqual(y.shape, [1, 2, 5]) - self.assertAllEqual( - self.evaluate(y), [[[2, 4, 3, 0, 1], [3, 4, 0, 1, 2]]]) - - @test_util.run_deprecated_v1 - def testInvertPermutationLargerBatch(self): - perm = np.array([np.random.permutation(20) for _ in range(10)], - dtype=np.int32) - - for dtype in [dtypes.int32, dtypes.int64]: - x = constant_op.constant(perm, dtype=dtype) - y = array_ops.invert_permutation(x) - # Argsort should be equivalent to invert permutation. - z = sort_ops.argsort(x, axis=-1) - self.assertAllEqual(y.shape, [10, 20]) - self.assertAllEqual(self.evaluate(y), self.evaluate(z)) + with self.cached_session(use_gpu=True): + x = constant_op.constant([3, 4, 0, 2, 1], dtype=dtype) + y = array_ops.invert_permutation(x) + self.assertAllEqual(y.get_shape(), [5]) + self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1]) class UnravelIndexTest(test_util.TensorFlowTestCase):