diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index acd278d7a51..fbc5f17a915 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "tensorflow/core/kernels/transpose_op.h"
 
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -28,15 +29,43 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/work_sharder.h"
 
 namespace tensorflow {
 
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+namespace {
+
+template <typename T>
+struct InvertPermutations {
+  static void Run(OpKernelContext* context, const Tensor& input, Tensor* out,
+                  int start, int limit) {
+    auto input_tensor = input.matrix<T>();
+    const T N = static_cast<T>(
+        input_tensor.dimension(1));  // Safe: bounds already checked.
+    auto output_tensor = out->matrix<T>();
+    for (int64 i = start; i < limit; ++i) {
+      for (int j = 0; j < N; ++j) {
+        const T d = internal::SubtleMustCopy(input_tensor(i, j));
+        OP_REQUIRES(context, FastBoundsCheck(d, N),
+                    errors::InvalidArgument(d, " is not between 0 and ", N));
+        OP_REQUIRES(context, output_tensor(i, d) == -1,
+                    errors::InvalidArgument(d, " is duplicated in the input."));
+        output_tensor(i, d) = j;
+      }
+    }
+  }
+};
+
+}  // namespace
+
 // inv = InvertPermutationOp(T<int32/int64> p) takes a permutation of
 // integers 0, 1, ..., n - 1 and returns the inverted
 // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
 //
-// REQUIRES: input is a vector of int32 or int64.
 // REQUIRES: input is a permutation of 0, 1, ..., n-1.
+//
 
 template <typename T>
 class InvertPermutationOp : public OpKernel {
@@ -46,28 +75,46 @@ class InvertPermutationOp : public OpKernel {
 
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
-    OP_REQUIRES(
-        context, TensorShapeUtils::IsVector(input.shape()),
-        errors::InvalidArgument("invert_permutation expects a 1D vector."));
-    auto Tin = input.vec<T>();
+    OP_REQUIRES(context, input.dims() > 0,
+                errors::InvalidArgument("Permutation must have at least rank 1 "
+                                        "but is rank ",
+                                        input.dims()));
+
+    const int64 perm_size = input.dim_size(input.dims() - 1);
     OP_REQUIRES(context,
-                FastBoundsCheck(Tin.size(), std::numeric_limits<int32>::max()),
+                FastBoundsCheck(perm_size, std::numeric_limits<int32>::max()),
                 errors::InvalidArgument("permutation of nonnegative int32s "
                                         "must have <= int32 max elements"));
-    const T N = static_cast<T>(Tin.size());  // Safe: bounds-checked above.
+    Tensor input_reshaped;
+    int64 batch_size = 1;
+    // The last dimension is the permutation dimension.
+    for (int i = 0; i < input.dims() - 1; ++i) {
+      batch_size *= input.shape().dim_size(i);
+    }
+    TensorShape batch_vectors = TensorShape({batch_size, perm_size});
+    // Note that we always have a batch size, including the scalar case.
+    OP_REQUIRES(context, input_reshaped.CopyFrom(input, batch_vectors),
+                errors::Internal("Failed to reshape In[0] from ",
+                                 input.shape().DebugString()));
+
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, input.shape(), &output));
-    auto Tout = output->vec<T>();
-    std::fill_n(Tout.data(), N, -1);
-    for (int i = 0; i < N; ++i) {
-      const T d = internal::SubtleMustCopy(Tin(i));
-      OP_REQUIRES(context, FastBoundsCheck(d, N),
-                  errors::InvalidArgument(d, " is not between 0 and ", N));
-      OP_REQUIRES(context, Tout(d) == -1,
-                  errors::InvalidArgument(d, " is duplicated in the input."));
-      Tout(d) = i;
-    }
+    output->flat<T>() = output->flat<T>().constant(T(-1));
+    Tensor output_reshaped;
+    OP_REQUIRES(context, output_reshaped.CopyFrom(*output, batch_vectors),
+                errors::Internal("Failed to reshape Output[0] from ",
+                                 output->shape().DebugString()));
+
+    const int64 cost_per_unit = perm_size;
+    // Parallelize over outer dimensions
+    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
+          cost_per_unit,
+          [&context, &input_reshaped, &output_reshaped](int start, int limit) {
+            InvertPermutations<T>::Run(context, input_reshaped,
+                                       &output_reshaped, start, limit);
+          });
   }
 };
 
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 60efdcb7a73..602b51a46e2 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1391,7 +1391,7 @@ REGISTER_OP("InvertPermutation")
     .Attr("T: {int32, int64} = DT_INT32")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle x;
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &x));
       c->set_output(0, x);
       return Status::OK();
     });
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 718a34c07e6..c4309f60039 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -399,9 +399,9 @@ TEST(ArrayOpsTest, UniqueWithCounts_ShapeFn) {
 
 TEST(ArrayOpsTest, InvertPermutation_ShapeFn) {
   ShapeInferenceTestOp op("InvertPermutation");
-  INFER_OK(op, "?", "[?]");
   INFER_OK(op, "[1]", "in0");
-  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[]");
+  INFER_OK(op, "[1,2,3]", "in0");
+  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
 }
 
 TEST(ArrayOpsTest, PadD_ShapeFn) {
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index ce96ee4ad6d..31994d78f50 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import sort_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
@@ -1351,14 +1352,40 @@ class PadTest(test_util.TensorFlowTestCase):
 
 class InvertPermutationTest(test_util.TensorFlowTestCase):
 
-  @test_util.run_deprecated_v1
   def testInvertPermutation(self):
     for dtype in [dtypes.int32, dtypes.int64]:
-      with self.cached_session(use_gpu=True):
-        x = constant_op.constant([3, 4, 0, 2, 1], dtype=dtype)
-        y = array_ops.invert_permutation(x)
-        self.assertAllEqual(y.get_shape(), [5])
-        self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1])
+      x = constant_op.constant([3, 4, 0, 2, 1], dtype=dtype)
+      y = array_ops.invert_permutation(x)
+      self.assertAllEqual(y.shape, [5])
+      self.assertAllEqual(self.evaluate(y), [2, 4, 3, 0, 1])
+
+  def testInvertPermutationCheckRank(self):
+    for dtype in [dtypes.int32, dtypes.int64]:
+      x = constant_op.constant(3, dtype=dtype)
+      with self.assertRaisesRegexp(Exception, "at least rank 1"):
+        self.evaluate(array_ops.invert_permutation(x))
+
+  def testInvertPermutationBatch(self):
+    for dtype in [dtypes.int32, dtypes.int64]:
+      x = constant_op.constant([[[3, 4, 0, 2, 1], [2, 3, 4, 0, 1]]],
+                               dtype=dtype)
+      y = array_ops.invert_permutation(x)
+      self.assertAllEqual(y.shape, [1, 2, 5])
+      self.assertAllEqual(
+          self.evaluate(y), [[[2, 4, 3, 0, 1], [3, 4, 0, 1, 2]]])
+
+  @test_util.run_deprecated_v1
+  def testInvertPermutationLargerBatch(self):
+    perm = np.array([np.random.permutation(20) for _ in range(10)],
+                    dtype=np.int32)
+
+    for dtype in [dtypes.int32, dtypes.int64]:
+      x = constant_op.constant(perm, dtype=dtype)
+      y = array_ops.invert_permutation(x)
+      # Argsort should be equivalent to invert permutation.
+      z = sort_ops.argsort(x, axis=-1)
+      self.assertAllEqual(y.shape, [10, 20])
+      self.assertAllEqual(self.evaluate(y), self.evaluate(z))
 
 
 class UnravelIndexTest(test_util.TensorFlowTestCase):