diff --git a/tensorflow/core/common_runtime/collective_util.cc b/tensorflow/core/common_runtime/collective_util.cc
index e6c16878e63..a94e6cb0a36 100644
--- a/tensorflow/core/common_runtime/collective_util.cc
+++ b/tensorflow/core/common_runtime/collective_util.cc
@@ -83,13 +83,11 @@ SubContext::SubContext(OpKernelContext* ctx, OpKernelContext::Params* params,
                        OpKernel* op, Tensor* output, Tensor* input)
     : sub_params_(*params),
       sub_inputs_({TensorValue(output), TensorValue(input)}),
-      sub_input_attr_({ctx->input_alloc_attr(0), ctx->input_alloc_attr(0)}),
-      sub_input_dc_(
-          {ctx->input_device_context(0), ctx->input_device_context(0)}) {
+      sub_input_attr_({ctx->input_alloc_attr(0), ctx->input_alloc_attr(0)}) {
   sub_params_.op_kernel = op;
   sub_params_.inputs = &sub_inputs_;
   sub_params_.input_alloc_attrs = &sub_input_attr_;
-  sub_params_.input_device_contexts = &sub_input_dc_;
+  sub_params_.op_device_context = ctx->op_device_context();
   sub_params_.eigen_gpu_device = nullptr;
   sub_params_.ensure_eigen_gpu_device();
   sub_params_.forward_from_array = &forward_from_;
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 95b806821a6..194aeb05f1a 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -251,15 +251,6 @@ Status KernelAndDeviceOp::Run(ScopedStepContainer* step_container,
                              tensorflow::HOST_MEMORY);
   }
 
-  gtl::InlinedVector<DeviceContext*, 4> input_device_contexts;
-  for (int i = 0; i < inputs.GetTensorValues()->size(); i++) {
-    DeviceContext* device_context = nullptr;
-    if (device_->tensorflow_gpu_device_info() != nullptr) {
-      device_context = device_->tensorflow_gpu_device_info()->default_context;
-    }
-    input_device_contexts.push_back(device_context);
-  }
-
   OpKernelContext::Params params;
   params.is_eager = true;
   params.device = device_;
@@ -296,7 +287,6 @@ Status KernelAndDeviceOp::Run(ScopedStepContainer* step_container,
   params.step_container = step_container;
   params.collective_executor =
       collective_executor_ ? collective_executor_->get() : nullptr;
-  params.input_device_contexts = &input_device_contexts;
 
   OpKernelContext context(&params);
 
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index f5afad7b2b6..4d738d881ba 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -275,7 +275,6 @@ struct NodeItem {
 };
 
 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
-typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
 
 // Immutable view of a Graph organized for efficient execution.
@@ -895,8 +894,7 @@ class ExecutorState {
           ref_mu(other.ref_mu),
           has_value(other.has_value),
           val_field_is_set(other.val_field_is_set),
-          alloc_attr(other.alloc_attr),
-          device_context(other.device_context) {
+          alloc_attr(other.alloc_attr) {
       if (val_field_is_set) {
         val.Init(*other.val);
       }
@@ -914,7 +912,6 @@ class ExecutorState {
       has_value = other.has_value;
       val_field_is_set = other.val_field_is_set;
       alloc_attr = other.alloc_attr;
-      device_context = other.device_context;
       if (val_field_is_set) {
         val.Init(*other.val);
       }
@@ -930,7 +927,6 @@ class ExecutorState {
       has_value = other.has_value;
       val_field_is_set = other.val_field_is_set;
       alloc_attr = other.alloc_attr;
-      device_context = other.device_context;
       if (val_field_is_set) {
         val.Init(std::move(*other.val));
       }
@@ -959,10 +955,6 @@ class ExecutorState {
 
     // The attributes of the allocator that creates the tensor.
     AllocatorAttributes alloc_attr;
-
-    // Every entry carries an optional DeviceContext containing
-    // Device-specific information about how the Tensor was produced.
-    DeviceContext* device_context = nullptr;
   };
 
   // Contains the device context assigned by the device at the beginning of a
@@ -1374,7 +1366,6 @@ class ExecutorState {
   // Before invoking item->kernel, fills in its "inputs".
   Status PrepareInputs(const NodeItem& item, Entry* first_input,
                        TensorValueVec* inputs,
-                       DeviceContextVec* input_device_contexts,
                        AllocatorAttributeVec* input_alloc_attrs,
                        bool* is_input_dead);
 
@@ -1600,9 +1591,8 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
 }
 
 // State kept alive for executing an asynchronous node in another
-// thread.  NOTE: We need to make a copy of p.input,
-// p.input_device_contexts, and p.input_alloc_attrs for asynchronous
-// kernels because OpKernelContext methods like input_type(i) needs
+// thread.  NOTE: We need to make a copy of p.input and p.input_alloc_attrs for
+// asynchronous kernels because OpKernelContext methods like input_type(i) needs
 // the param points to valid input type vector. It's not an issue for
 // sync kernels because these vectors are kept on the stack.
 struct ExecutorState::AsyncState {
@@ -1610,7 +1600,6 @@ struct ExecutorState::AsyncState {
              const NodeItem* _item, Entry* _first_input,
              NodeExecStatsInterface* _stats)
       : saved_inputs(*p.inputs),
-        saved_input_device_contexts(*p.input_device_contexts),
         saved_input_alloc_attrs(*p.input_alloc_attrs),
         params(p),
         tagged_node(_tagged_node),
@@ -1621,12 +1610,10 @@ struct ExecutorState::AsyncState {
         ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
         stats(_stats) {
     params.inputs = &saved_inputs;
-    params.input_device_contexts = &saved_input_device_contexts;
     params.input_alloc_attrs = &saved_input_alloc_attrs;
   }
 
   TensorValueVec saved_inputs;
-  DeviceContextVec saved_input_device_contexts;
   AllocatorAttributeVec saved_input_alloc_attrs;
   OpKernelContext::Params params;
   TaggedNode tagged_node;
@@ -1682,7 +1669,6 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
 
   // Parameters passed to OpKernel::Compute.
   TensorValueVec inputs;
-  DeviceContextVec input_device_contexts;
   AllocatorAttributeVec input_alloc_attrs;
 
   OpKernelContext::Params params;
@@ -1710,7 +1696,6 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
   params.step_container = step_container_;
   params.slice_reader_cache = slice_reader_cache_;
   params.inputs = &inputs;
-  params.input_device_contexts = &input_device_contexts;
   params.input_alloc_attrs = &input_alloc_attrs;
   params.runner = &runner_;
   params.stats_collector = stats_collector_;
@@ -1790,8 +1775,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
     } else {
       // Prepares inputs.
       bool is_input_dead = false;
-      s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
-                        &input_alloc_attrs, &is_input_dead);
+      s = PrepareInputs(item, first_input, &inputs, &input_alloc_attrs,
+                        &is_input_dead);
       if (!s.ok()) {
         // Clear inputs.
         int num_inputs = item.num_inputs;
@@ -1961,13 +1946,10 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
 
 Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
                                     TensorValueVec* inputs,
-                                    DeviceContextVec* input_device_contexts,
                                     AllocatorAttributeVec* input_alloc_attrs,
                                     bool* is_input_dead) {
   inputs->clear();
   inputs->resize(item.num_inputs);
-  input_device_contexts->clear();
-  input_device_contexts->resize(item.num_inputs);
   input_alloc_attrs->clear();
   input_alloc_attrs->resize(item.num_inputs);
 
@@ -1977,7 +1959,6 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
   for (int i = 0; i < item.num_inputs; ++i) {
     const bool expect_ref = IsRefType(item.input_type(i));
     Entry* entry = first_input + i;
-    (*input_device_contexts)[i] = entry->device_context;
     (*input_alloc_attrs)[i] = entry->alloc_attr;
 
     // i-th input.
@@ -2084,9 +2065,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
     return s;
   }
 
-  // Get the device_context for this node id, if it exists.
-  DeviceContext* device_context = device_context_;
-
   for (int i = 0; i < item.num_outputs; ++i) {
     const TensorValue val = ctx->release_output(i);
     if (val.tensor == nullptr) {
@@ -2099,9 +2077,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
     } else {
       Entry* out = &((*outputs)[i]);
 
-      // Set the device context of the output entry.
-      out->device_context = device_context;
-
       // Set the allocator attributes of the output entry.
       out->alloc_attr = ctx->output_alloc_attr(i);
 
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
index 450876b99d6..966956dd5ae 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -441,7 +441,6 @@ class EMBenchmarkHelper {
 
     params->step_container = nullptr;
     params->slice_reader_cache = nullptr;
-    params->input_device_contexts = nullptr;
     params->resource_manager = gpu_helper_->gpu()->resource_manager();
 
     params->stats_collector = nullptr;
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index bd8ce352389..9fcf75fb4d3 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -644,7 +644,6 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
       gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
           {AllocatorAttributes()});
       op_params.input_alloc_attrs = &input_aa;
-      gtl::InlinedVector<DeviceContext*, 4> input_dc;
       DeviceContext* dev_ctx = nullptr;
       auto* dev_info = device_->tensorflow_gpu_device_info();
       if (dev_info) {
@@ -653,8 +652,6 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
       } else {
         dev_ctx = new DeviceContext;
       }
-      input_dc.push_back(dev_ctx);
-      op_params.input_device_contexts = &input_dc;
       op_params.op_device_context = dev_ctx;
       int forward_from[] = {OpKernelContext::Params::kNeverForward};
       if (forward_input) forward_from[0] = 0;
diff --git a/tensorflow/core/common_runtime/ring_gatherer.cc b/tensorflow/core/common_runtime/ring_gatherer.cc
index f13f6175255..db096ba5d92 100644
--- a/tensorflow/core/common_runtime/ring_gatherer.cc
+++ b/tensorflow/core/common_runtime/ring_gatherer.cc
@@ -105,7 +105,7 @@ void RingGatherer::Run(StatusCallback done) {
     Status status;
     Tensor alias_chunk(ca_->ChunkAlias(col_params_->subdiv_rank[0]));
     CollectiveRemoteAccessLocal::MemCpyAsync(
-        col_ctx_->op_ctx->input_device_context(0),
+        col_ctx_->op_ctx->op_device_context(),
         col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
         col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
         col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input, &alias_chunk,
diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc
index 0e3cbf7356d..87a493c39b9 100644
--- a/tensorflow/core/common_runtime/ring_gatherer_test.cc
+++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc
@@ -456,7 +456,6 @@ class RingGathererTest : public ::testing::Test {
       gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
           {AllocatorAttributes()});
       op_params.input_alloc_attrs = &input_aa;
-      gtl::InlinedVector<DeviceContext*, 4> input_dc;
       DeviceContext* dev_ctx = nullptr;
       auto* dev_info = device_->tensorflow_gpu_device_info();
       if (dev_info) {
@@ -465,8 +464,6 @@ class RingGathererTest : public ::testing::Test {
       } else {
         dev_ctx = new DeviceContext;
       }
-      input_dc.push_back(dev_ctx);
-      op_params.input_device_contexts = &input_dc;
       op_params.op_device_context = dev_ctx;
       AllocatorAttributes generic_alloc_attr;
       op_params.output_attr_array = &generic_alloc_attr;
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index 57cd14d708e..ecba1139433 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -92,7 +92,7 @@ void RingReducer::Run(StatusCallback done) {
     Status status;
     profiler::TraceMe activity("MemCpyAsync", profiler::TraceMeLevel::kInfo);
     CollectiveRemoteAccessLocal::MemCpyAsync(
-        col_ctx_->op_ctx->input_device_context(0),
+        col_ctx_->op_ctx->op_device_context(),
         col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
         col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
         col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index 921b3c97e29..9c82712623e 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -485,7 +485,6 @@ class RingReducerTest : public ::testing::Test {
       gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
           {AllocatorAttributes()});
       op_params.input_alloc_attrs = &input_aa;
-      gtl::InlinedVector<DeviceContext*, 4> input_dc;
       DeviceContext* dev_ctx = nullptr;
       auto* dev_info = device_->tensorflow_gpu_device_info();
       if (dev_info) {
@@ -494,8 +493,6 @@ class RingReducerTest : public ::testing::Test {
       } else {
         dev_ctx = new DeviceContext;
       }
-      input_dc.push_back(dev_ctx);
-      op_params.input_device_contexts = &input_dc;
       op_params.op_device_context = dev_ctx;
       int forward_from = 0;
       op_params.forward_from_array = &forward_from;
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 5d8741461b6..8372359e7ae 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -703,9 +703,7 @@ class OpKernelContext {
     const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
         nullptr;
 
-    // Device contexts.
-    const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts =
-        nullptr;
+    // Device context.
     DeviceContext* op_device_context = nullptr;
 
     // Control-flow op supports.
@@ -1060,18 +1058,6 @@ class OpKernelContext {
   // Returns nullptr if allocate_output() or set_output() have not been called.
   Status mutable_output(StringPiece name, Tensor** tensor);
 
-  // Records device specific state about how the input tensors were
-  // computed.
-  //
-  // If using the templated function, the type must be a subclass
-  // of DeviceContext.
-  //
-  // Get the DeviceContext used for the index input.  Returns nullptr
-  // if no DeviceContext was provided.
-  template <typename T>
-  T* input_device_context(int index);
-  DeviceContext* input_device_context(int index);
-
   // Return the DeviceContext that should be used for this Op.
   //
   // If using the templated function, the type must be a subclass
@@ -1705,23 +1691,6 @@ T* OpKernelContext::op_device_context() {
   return static_cast<T*>(op_device_context());
 }
 
-template <typename T>
-T* OpKernelContext::input_device_context(int index) {
-  DCHECK_NE(params_->input_device_contexts, nullptr);
-  DCHECK_GE(index, 0);
-  DCHECK_LT(index, params_->input_device_contexts->size());
-  static_assert(std::is_base_of<DeviceContext, T>::value,
-                "T is not a subclass of DeviceContext");
-  return static_cast<T*>((*params_->input_device_contexts)[index]);
-}
-
-inline DeviceContext* OpKernelContext::input_device_context(int index) {
-  DCHECK_NE(params_->input_device_contexts, nullptr);
-  DCHECK_GE(index, 0);
-  DCHECK_LT(index, params_->input_device_contexts->size());
-  return (*params_->input_device_contexts)[index];
-}
-
 inline const Tensor& OpInputList::operator[](int i) const {
   DCHECK_GE(i, 0);
   DCHECK_LT(i, stop_ - start_);
diff --git a/tensorflow/core/kernels/collective_nccl_test.cc b/tensorflow/core/kernels/collective_nccl_test.cc
index 8000f5386ae..669d7c3321d 100644
--- a/tensorflow/core/kernels/collective_nccl_test.cc
+++ b/tensorflow/core/kernels/collective_nccl_test.cc
@@ -303,9 +303,6 @@ class NcclTestBase : public ::testing::Test {
       gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
           {AllocatorAttributes()});
       op_params.input_alloc_attrs = &input_aa;
-      gtl::InlinedVector<DeviceContext*, 4> input_dc;
-      input_dc.push_back(op_params.op_device_context);
-      op_params.input_device_contexts = &input_dc;
       int forward_from = 0;
       op_params.forward_from_array = &forward_from;
       AllocatorAttributes generic_alloc_attr;
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc
index 56ee12e444d..d26d05ed202 100644
--- a/tensorflow/core/kernels/data/single_threaded_executor.cc
+++ b/tensorflow/core/kernels/data/single_threaded_executor.cc
@@ -26,7 +26,6 @@ namespace data {
 namespace {
 
 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
-typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
 
 class SingleThreadedExecutorImpl : public Executor {
@@ -198,7 +197,6 @@ class SingleThreadedExecutorImpl : public Executor {
     // TODO(mrry): Can we avoid copying into these vectors? Consider modifying
     // OpKernelContext to take the TensorValueVec as a pointer into `inputs`.
     TensorValueVec node_inputs;
-    DeviceContextVec input_device_contexts;
     AllocatorAttributeVec input_alloc_attrs;
 
     // Prepare the parameters that will be the same for all kernels.
@@ -222,7 +220,6 @@ class SingleThreadedExecutorImpl : public Executor {
     params.step_container = args.step_container;
     params.slice_reader_cache = nullptr;  // TODO(mrry): Too severe?
     params.inputs = &node_inputs;
-    params.input_device_contexts = &input_device_contexts;
     params.input_alloc_attrs = &input_alloc_attrs;
 
     Args::Runner runner_copy = args.runner;
@@ -257,8 +254,6 @@ class SingleThreadedExecutorImpl : public Executor {
         input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j];
       }
       params.op_kernel = kernel_state.kernel;
-      input_device_contexts.clear();
-      input_device_contexts.resize(num_inputs);
       params.output_attr_array = kernel_state.output_alloc_attrs.data();
       OpKernelContext ctx(&params, num_outputs);