diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index ded8aa38c23..e7b6b9c0019 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -3762,6 +3762,7 @@ tf_cc_tests(
         "//third_party/eigen3",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/types:optional",
         "@zlib_archive//:zlib",
     ],
 )
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index e1a27d4f5a6..f65c395b4ac 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
+#include "absl/types/optional.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/platform/context.h"
@@ -117,8 +118,8 @@ void ThreadPool::Schedule(std::function<void()> fn) {
   underlying_threadpool_->Schedule(std::move(fn));
 }
 
-int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
-    const int64 block_size, const int64 total) {
+int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling(
+    const int64 total, const int64 block_size) {
   if (block_size <= 0 || total <= 1 || total <= block_size ||
       NumThreads() == 1) {
     return 1;
@@ -126,13 +127,47 @@ int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
   return (total + block_size - 1) / block_size;
 }
 
-// This functionality is similar to parallelFor, except that reasoning about
-// the number of shards used is significantly easier.
+int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
+    const int64 block_size, const int64 total) {
+  return NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
+}
+
+void ThreadPool::ParallelFor(int64 total,
+                             const SchedulingParams& scheduling_params,
+                             const std::function<void(int64, int64)>& fn) {
+  switch (scheduling_params.strategy()) {
+    case SchedulingStrategy::kAdaptive: {
+      if (scheduling_params.cost_per_unit().has_value()) {
+        ParallelFor(total, *scheduling_params.cost_per_unit(), fn);
+      }
+      break;
+    }
+    case SchedulingStrategy::kFixedBlockSize: {
+      if (scheduling_params.block_size().has_value()) {
+        ParallelForFixedBlockSizeScheduling(
+            total, *scheduling_params.block_size(), fn);
+      }
+      break;
+    }
+  }
+}
+
 void ThreadPool::TransformRangeConcurrently(
     const int64 block_size, const int64 total,
     const std::function<void(int64, int64)>& fn) {
+  ParallelFor(total,
+              SchedulingParams(SchedulingStrategy::kFixedBlockSize,
+                               absl::nullopt /* cost_per_unit */, block_size),
+              fn);
+}
+
+// This functionality is similar to parallelFor, except that reasoning about
+// the number of shards used is significantly easier.
+void ThreadPool::ParallelForFixedBlockSizeScheduling(
+    const int64 total, const int64 block_size,
+    const std::function<void(int64, int64)>& fn) {
   const int num_shards_used =
-      NumShardsUsedByTransformRangeConcurrently(block_size, total);
+      NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
   if (num_shards_used == 1) {
     fn(0, total);
     return;
@@ -166,7 +201,7 @@ void ThreadPool::TransformRangeConcurrently(
 }
 
 void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
-                             std::function<void(int64, int64)> fn) {
+                             const std::function<void(int64, int64)>& fn) {
   CHECK_GE(total, 0);
   CHECK_EQ(total, (int64)(Eigen::Index)total);
   threadpool_device_->parallelFor(
@@ -193,6 +228,18 @@ void ThreadPool::ParallelForWithWorkerId(
                                   });
 }
 
+void ThreadPool::ParallelForWithWorkerId(
+    int64 total, const SchedulingParams& scheduling_params,
+    const std::function<void(int64, int64, int)>& fn) {
+  ParallelFor(total, scheduling_params, [this, &fn](int64 start, int64 limit) {
+    // We may use the current thread to do some work synchronously.
+    // When calling CurrentThreadId() from outside of the thread
+    // pool, we get -1, so we can shift every id up by 1.
+    int id = CurrentThreadId() + 1;
+    fn(start, limit, id);
+  });
+}
+
 int ThreadPool::NumThreads() const {
   return underlying_threadpool_->NumThreads();
 }
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index 51aa83cc625..d168faef670 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <functional>
 #include <memory>
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/lib/core/threadpool_interface.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/macros.h"
@@ -40,6 +41,64 @@ struct EigenEnvironment;
 
 class ThreadPool {
  public:
+  // Scheduling strategies for ParallelFor. The strategy governs how the given
+  // units of work are distributed among the available threads in the
+  // threadpool.
+  enum class SchedulingStrategy {
+    // The Adaptive scheduling strategy adaptively chooses the shard sizes based
+    // on the cost of each unit of work, and the cost model of the underlying
+    // threadpool device.
+    //
+    // The 'cost_per_unit' is an estimate of the number of CPU cycles (or
+    // nanoseconds if not CPU-bound) to complete a unit of work. Overestimating
+    // creates too many shards and CPU time will be dominated by per-shard
+    // overhead, such as Context creation. Underestimating may not fully make
+    // use of the specified parallelism, and may also cause inefficiencies due
+    // to load balancing issues and stragglers.
+    kAdaptive,
+    // The Fixed Block Size scheduling strategy shards the given units of work
+    // into shards of fixed size. In case the total number of units is not
+    // evenly divisible by 'block_size', at most one of the shards may be of
+    // smaller size. The exact number of shards may be found by a call to
+    // NumShardsUsedByFixedBlockSizeScheduling.
+    //
+    // Each shard may be executed on a different thread in parallel, depending
+    // on the number of threads available in the pool. Note that when there
+    // aren't enough threads in the pool to achieve full parallelism, function
+    // calls will be automatically queued.
+    kFixedBlockSize
+  };
+
+  // Contains additional parameters for either the Adaptive or the Fixed Block
+  // Size scheduling strategy.
+  class SchedulingParams {
+   public:
+    explicit SchedulingParams(SchedulingStrategy strategy,
+                              absl::optional<int64> cost_per_unit,
+                              absl::optional<int64> block_size)
+        : strategy_(strategy),
+          cost_per_unit_(cost_per_unit),
+          block_size_(block_size) {}
+
+    SchedulingStrategy strategy() const { return strategy_; }
+    absl::optional<int64> cost_per_unit() const { return cost_per_unit_; }
+    absl::optional<int64> block_size() const { return block_size_; }
+
+   private:
+    // The underlying Scheduling Strategy for which this instance contains
+    // additional parameters.
+    SchedulingStrategy strategy_;
+
+    // The estimated cost per unit of work in number of CPU cycles (or
+    // nanoseconds if not CPU-bound). Only applicable for Adaptive scheduling
+    // strategy.
+    absl::optional<int64> cost_per_unit_;
+
+    // The block size of each shard. Only applicable for Fixed Block Size
+    // scheduling strategy.
+    absl::optional<int64> block_size_;
+  };
+
   // Constructs a pool that contains "num_threads" threads with specified
   // "name". env->StartThread() is used to create individual threads with the
   // given ThreadOptions. If "low_latency_hint" is true the thread pool
@@ -83,17 +142,15 @@ class ThreadPool {
       const std::vector<std::pair<unsigned, unsigned>>& partitions);
 
   void ScheduleWithHint(std::function<void()> fn, int start, int limit);
-  // Requires 0 < block_size <= total.
-  // Spawns k threads and calls fn(i*block_size, (i+1)*block_size) from the
-  // ith thread (i>=0). When (i+1)*block_size > total, fn(i*block_size, total)
-  // is called instead. k = NumShardsUsedByTransformRangeConcurrently(...).
-  // Note that when there aren't enough threads in the pool to achieve full
-  // parallelism, function calls will be automatically queued.
-  void TransformRangeConcurrently(const int64 block_size, const int64 total,
-                                  const std::function<void(int64, int64)>& fn);
+
+  // Returns the number of shards used by ParallelForFixedBlockSizeScheduling
+  // with these parameters.
+  int NumShardsUsedByFixedBlockSizeScheduling(const int64 total,
+                                              const int64 block_size);
 
   // Returns the number of threads spawned by calling TransformRangeConcurrently
   // with these parameters.
+  // Deprecated. Use NumShardsUsedByFixedBlockSizeScheduling.
   int NumShardsUsedByTransformRangeConcurrently(const int64 block_size,
                                                 const int64 total);
 
@@ -106,9 +163,20 @@ class ThreadPool {
   // if not CPU-bound) to complete a unit of work. Overestimating creates too
   // many shards and CPU time will be dominated by per-shard overhead, such as
   // Context creation. Underestimating may not fully make use of the specified
-  // parallelism.
+  // parallelism, and may also cause inefficiencies due to load balancing
+  // issues and stragglers.
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn);
+                   const std::function<void(int64, int64)>& fn);
+
+  // Similar to ParallelFor above, but takes the specified scheduling strategy
+  // into account.
+  void ParallelFor(int64 total, const SchedulingParams& scheduling_params,
+                   const std::function<void(int64, int64)>& fn);
+
+  // Same as ParallelFor with Fixed Block Size scheduling strategy.
+  // Deprecated. Prefer ParallelFor with a SchedulingStrategy argument.
+  void TransformRangeConcurrently(const int64 block_size, const int64 total,
+                                  const std::function<void(int64, int64)>& fn);
 
   // Shards the "total" units of work. For more details, see "ParallelFor".
   //
@@ -129,6 +197,12 @@ class ThreadPool {
       int64 total, int64 cost_per_unit,
       const std::function<void(int64, int64, int)>& fn);
 
+  // Similar to ParallelForWithWorkerId above, but takes the specified
+  // scheduling strategy into account.
+  void ParallelForWithWorkerId(
+      int64 total, const SchedulingParams& scheduling_params,
+      const std::function<void(int64, int64, int)>& fn);
+
   // Returns the number of threads in the pool.
   int NumThreads() const;
 
@@ -142,6 +216,17 @@ class ThreadPool {
   Eigen::ThreadPoolInterface* AsEigenThreadPool() const;
 
  private:
+  // Divides the work represented by the range [0, total) into k shards.
+  // Calls fn(i*block_size, (i+1)*block_size) from the ith shard (0 <= i < k).
+  // Each shard may be executed on a different thread in parallel, depending on
+  // the number of threads available in the pool.
+  // When (i+1)*block_size > total, fn(i*block_size, total) is called instead.
+  // Here, k = NumShardsUsedByFixedBlockSizeScheduling(total, block_size).
+  // Requires 0 < block_size <= total.
+  void ParallelForFixedBlockSizeScheduling(
+      const int64 total, const int64 block_size,
+      const std::function<void(int64, int64)>& fn);
+
   // underlying_threadpool_ is the user_threadpool if user_threadpool is
   // provided in the constructor. Otherwise it is the eigen_threadpool_.
   Eigen::ThreadPoolInterface* underlying_threadpool_;
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
index f972fb4fb47..911645f04f1 100644
--- a/tensorflow/core/lib/core/threadpool_test.cc
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "absl/synchronization/barrier.h"
 #include "absl/synchronization/blocking_counter.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/platform/context.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -62,7 +63,59 @@ TEST(ThreadPool, DoWork) {
   }
 }
 
-void RunSharding(int64 block_size, int64 total, ThreadPool* threads) {
+void RunWithFixedBlockSize(int64 block_size, int64 total, ThreadPool* threads) {
+  mutex mu;
+  int64 num_shards = 0;
+  int64 num_done_work = 0;
+  std::vector<std::atomic<bool>> work(total);
+  for (int i = 0; i < total; i++) {
+    work[i] = false;
+  }
+  threads->ParallelFor(
+      total,
+      ThreadPool::SchedulingParams(
+          ThreadPool::SchedulingStrategy::kFixedBlockSize /* strategy */,
+          absl::nullopt /* cost_per_unit */, block_size /* block_size */),
+      [=, &mu, &num_shards, &num_done_work, &work](int64 start, int64 end) {
+        VLOG(1) << "Shard [" << start << "," << end << ")";
+        EXPECT_GE(start, 0);
+        EXPECT_LE(end, total);
+        mutex_lock l(mu);
+        ++num_shards;
+        for (; start < end; ++start) {
+          EXPECT_FALSE(work[start].exchange(true));  // No duplicate
+          ++num_done_work;
+        }
+      });
+  EXPECT_EQ(num_done_work, total);
+  for (int i = 0; i < total; i++) {
+    ASSERT_TRUE(work[i]);
+  }
+  const int64 num_workers = (total + block_size - 1) / block_size;
+  if (num_workers < threads->NumThreads()) {
+    // If the intention is to limit the parallelism explicitly, we'd
+    // better honor it. Ideally, even if per_thread_max_parallelism >
+    // num_workers, we should expect that Shard() implementation do
+    // not over-shard. Unfortunately, ThreadPoolDevice::parallelFor
+    // tends to over-shard.
+    EXPECT_LE(num_shards, 1 + num_workers);
+  }
+}
+
+// Adapted from work_sharder_test.cc
+TEST(ThreadPoolTest, ParallelForFixedBlockSizeScheduling) {
+  ThreadPool threads(Env::Default(), "test", 16);
+  for (auto block_size : {1, 7, 10, 64, 100, 256, 1000, 9999}) {
+    for (auto diff : {0, 1, 11, 102, 1003, 10005, 1000007}) {
+      const int64 total = block_size + diff;
+      RunWithFixedBlockSize(block_size, total, &threads);
+    }
+  }
+}
+
+void RunWithFixedBlockSizeTransformRangeConcurrently(int64 block_size,
+                                                     int64 total,
+                                                     ThreadPool* threads) {
   mutex mu;
   int64 num_shards = 0;
   int64 num_done_work = 0;
@@ -83,7 +136,6 @@ void RunSharding(int64 block_size, int64 total, ThreadPool* threads) {
           ++num_done_work;
         }
       });
-  LOG(INFO) << block_size << " " << total;
   EXPECT_EQ(num_done_work, total);
   for (int i = 0; i < total; i++) {
     ASSERT_TRUE(work[i]);
@@ -100,18 +152,39 @@ void RunSharding(int64 block_size, int64 total, ThreadPool* threads) {
 }
 
 // Adapted from work_sharder_test.cc
-TEST(SparseUtilsTest, TransformRangeConcurrently) {
+TEST(ThreadPoolTest, TransformRangeConcurrently) {
   ThreadPool threads(Env::Default(), "test", 16);
   for (auto block_size : {1, 7, 10, 64, 100, 256, 1000, 9999}) {
     for (auto diff : {0, 1, 11, 102, 1003, 10005, 1000007}) {
       const int64 total = block_size + diff;
-      RunSharding(block_size, total, &threads);
+      RunWithFixedBlockSizeTransformRangeConcurrently(block_size, total,
+                                                      &threads);
     }
   }
 }
 
-TEST(SparseUtilsTest, NumShardsUsedByTransformRangeConcurrently) {
+TEST(ThreadPoolTest, NumShardsUsedByFixedBlockSizeScheduling) {
   ThreadPool threads(Env::Default(), "test", 16);
+
+  EXPECT_EQ(1, threads.NumShardsUsedByFixedBlockSizeScheduling(
+                   3 /* total */, 3 /* block_size */));
+  EXPECT_EQ(2, threads.NumShardsUsedByFixedBlockSizeScheduling(
+                   4 /* total */, 3 /* block_size */));
+  EXPECT_EQ(2, threads.NumShardsUsedByFixedBlockSizeScheduling(
+                   5 /* total */, 3 /* block_size */));
+  EXPECT_EQ(2, threads.NumShardsUsedByFixedBlockSizeScheduling(
+                   6 /* total */, 3 /* block_size */));
+  EXPECT_EQ(3, threads.NumShardsUsedByFixedBlockSizeScheduling(
+                   7 /* total */, 3 /* block_size */));
+  EXPECT_EQ(7, threads.NumShardsUsedByFixedBlockSizeScheduling(
+                   7 /* total */, 1 /* block_size */));
+  EXPECT_EQ(1, threads.NumShardsUsedByFixedBlockSizeScheduling(
+                   7 /* total */, 0 /* block_size */));
+}
+
+TEST(ThreadPoolTest, NumShardsUsedByTransformRangeConcurrently) {
+  ThreadPool threads(Env::Default(), "test", 16);
+
   EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
                    3 /* block_size */, 3 /* total */));
   EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
@@ -128,6 +201,63 @@ TEST(SparseUtilsTest, NumShardsUsedByTransformRangeConcurrently) {
                    0 /* block_size */, 7 /* total */));
 }
 
+void RunFixedBlockSizeShardingWithWorkerId(int64 block_size, int64 total,
+                                           ThreadPool* threads) {
+  mutex mu;
+  int64 num_done_work = 0;
+  std::vector<std::atomic<bool>> work(total);
+  for (int i = 0; i < total; i++) {
+    work[i] = false;
+  }
+  const int64 num_threads = threads->NumThreads();
+  std::vector<std::atomic<bool>> threads_running(num_threads + 1);
+  for (int i = 0; i < num_threads + 1; i++) {
+    threads_running[i] = false;
+  }
+
+  threads->ParallelForWithWorkerId(
+      total,
+      ThreadPool::SchedulingParams(
+          ThreadPool::SchedulingStrategy::kFixedBlockSize /* strategy */,
+          absl::nullopt /* cost_per_unit */, block_size /* block_size */),
+      [=, &mu, &num_done_work, &work, &threads_running](int64 start, int64 end,
+                                                        int id) {
+        VLOG(1) << "Shard [" << start << "," << end << ")";
+        EXPECT_GE(start, 0);
+        EXPECT_LE(end, total);
+
+        // Store true for the current thread, and assert that another thread
+        // is not running with the same id.
+        EXPECT_GE(id, 0);
+        EXPECT_LE(id, num_threads);
+        EXPECT_FALSE(threads_running[id].exchange(true));
+
+        mutex_lock l(mu);
+        for (; start < end; ++start) {
+          EXPECT_FALSE(work[start].exchange(true));  // No duplicate
+          ++num_done_work;
+        }
+        EXPECT_TRUE(threads_running[id].exchange(false));
+      });
+
+  EXPECT_EQ(num_done_work, total);
+  for (int i = 0; i < total; i++) {
+    EXPECT_TRUE(work[i]);
+  }
+}
+
+TEST(ThreadPoolTest, ParallelForFixedBlockSizeSchedulingWithWorkerId) {
+  for (int32 num_threads : {1, 2, 3, 9, 16, 31}) {
+    ThreadPool threads(Env::Default(), "test", num_threads);
+    for (int64 block_size : {1, 7, 10, 64, 100, 256, 1000}) {
+      for (int64 diff : {0, 1, 11, 102, 1003}) {
+        const int64 total = block_size + diff;
+        RunFixedBlockSizeShardingWithWorkerId(block_size, total, &threads);
+      }
+    }
+  }
+}
+
 TEST(ThreadPool, ParallelFor) {
   Context outer_context(ContextKind::kThread);
   // Make ParallelFor use as many threads as possible.
@@ -154,6 +284,36 @@ TEST(ThreadPool, ParallelFor) {
   }
 }
 
+TEST(ThreadPool, ParallelForWithAdaptiveSchedulingStrategy) {
+  Context outer_context(ContextKind::kThread);
+  // Make ParallelFor use as many threads as possible.
+  int64 kHugeCost = 1 << 30;
+  for (int num_threads = 1; num_threads < kNumThreads; num_threads++) {
+    fprintf(stderr, "Testing with %d threads\n", num_threads);
+    const int kWorkItems = 15;
+    std::atomic<bool> work[kWorkItems];
+    ThreadPool pool(Env::Default(), "test", num_threads);
+    for (int i = 0; i < kWorkItems; i++) {
+      work[i] = false;
+    }
+    pool.ParallelFor(
+        kWorkItems,
+        ThreadPool::SchedulingParams(
+            ThreadPool::SchedulingStrategy::kAdaptive /* strategy */,
+            kHugeCost /* cost_per_unit */, absl::nullopt /* block_size */),
+        [&outer_context, &work](int64 begin, int64 end) {
+          Context inner_context(ContextKind::kThread);
+          ASSERT_EQ(outer_context, inner_context);
+          for (int64 i = begin; i < end; ++i) {
+            ASSERT_FALSE(work[i].exchange(true));
+          }
+        });
+    for (int i = 0; i < kWorkItems; i++) {
+      ASSERT_TRUE(work[i]);
+    }
+  }
+}
+
 TEST(ThreadPool, ParallelForWithWorkerId) {
   // Make ParallelForWithWorkerId use as many threads as possible.
   int64 kHugeCost = 1 << 30;
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index 74f0713a618..58808e3a636 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -45,13 +45,15 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
     workers->ParallelFor(total, cost_per_unit, work);
     return;
   }
-  Sharder::Do(total, cost_per_unit, work,
-              [&workers](Sharder::Closure c) { workers->Schedule(c); },
-              max_parallelism);
+  Sharder::Do(
+      total, cost_per_unit, work,
+      [&workers](Sharder::Closure c) { workers->Schedule(c); },
+      max_parallelism);
 }
 
-// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
-// to directly specify the shard size.
+// DEPRECATED: Prefer threadpool->ParallelFor with SchedulingStrategy, which
+// allows you to specify the strategy for choosing shard sizes, including using
+// a fixed shard size.
 void Sharder::Do(int64 total, int64 cost_per_unit, const Work& work,
                  const Runner& runner, int max_parallelism) {
   cost_per_unit = std::max(int64{1}, cost_per_unit);
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
index 9db85a54c6c..92d1dc698b1 100644
--- a/tensorflow/core/util/work_sharder.h
+++ b/tensorflow/core/util/work_sharder.h
@@ -23,9 +23,11 @@ limitations under the License.
 
 namespace tensorflow {
 
-// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
-// to directly specify the shard size. Use this function only if you want to
-// manually cap parallelism.
+// DEPRECATED: Prefer threadpool->ParallelFor with SchedulingStrategy, which
+// allows you to specify the strategy for choosing shard sizes, including using
+// a fixed shard size. Use this function only if you want to manually cap
+// parallelism.
+//
 // Shards the "total" unit of work assuming each unit of work having
 // roughly "cost_per_unit". Each unit of work is indexed 0, 1, ...,
 // total - 1. Each shard contains 1 or more units of work and the