diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc
index 366e8d49228..90585c9d98a 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.cc
+++ b/tensorflow/compiler/tf2xla/sharding_util.cc
@@ -80,6 +80,30 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
   return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
 }
 
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
+    const Edge& edge, int num_cores_per_replica) {
+  if (edge.src() == nullptr) {
+    return tensorflow::errors::InvalidArgument(
+        "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
+  }
+  TF_ASSIGN_OR_RETURN(
+      absl::optional<xla::OpSharding> sharding,
+      ParseShardingFromDevice(*edge.src(), num_cores_per_replica));
+  if (sharding.has_value() &&
+      sharding.value().type() == xla::OpSharding::TUPLE) {
+    if (edge.src_output() < 0 ||
+        edge.src_output() >= sharding.value().tuple_shardings_size()) {
+      return tensorflow::errors::InvalidArgument(
+          "Tuple index out of bounds: edge=", edge.DebugString(),
+          " sharding=", sharding->DebugString());
+    }
+    absl::optional<xla::OpSharding> subsharding =
+        sharding.value().tuple_shardings(edge.src_output());
+    return subsharding;
+  }
+  return sharding;
+}
+
 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
   string device_name = src.assigned_device_name();
   if (device_name.empty()) {
diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h
index 196434826f9..07657c656d3 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.h
+++ b/tensorflow/compiler/tf2xla/sharding_util.h
@@ -43,6 +43,9 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
     const NodeDef& node_def, int num_cores_per_replica);
 
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
+    const Edge& edge, int num_cores_per_replica);
+
 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
 
 // Get sharding inforamtion from node.
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index 075a1ec9069..5fdc74b79fc 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -1813,7 +1813,8 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
     } else if (sharding->type() != xla::OpSharding::REPLICATED &&
                sharding->type() != xla::OpSharding::OTHER) {
       return tensorflow::errors::InvalidArgument(
-          "Unsupported argument sharding: ", sharding->DebugString());
+          "Unsupported argument sharding (for arg ", n->DebugString(),
+          "): ", sharding->DebugString());
     }
     if (assigned_core.has_value()) {
       args_device_selector.ReportDeviceAssigned(*assigned_core, i);
@@ -1855,7 +1856,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
 
     TF_ASSIGN_OR_RETURN(
         absl::optional<xla::OpSharding> sharding,
-        ParseShardingFromDevice(*edge->src(), num_cores_per_replica));
+        ParseShardingFromEdgeSource(*edge, num_cores_per_replica));
 
     if (partitioned_output_nodes.contains(i)) {
       Node* output_node = partitioned_output_nodes[i];
@@ -1883,7 +1884,9 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
       } else if (sharding.value().type() != xla::OpSharding::REPLICATED &&
                  sharding.value().type() != xla::OpSharding::OTHER) {
         return tensorflow::errors::InvalidArgument(
-            "Unsupported argument sharding: ", sharding->DebugString());
+            "Unsupported argument sharding for retval ",
+            retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
+            sharding->DebugString());
       }
     } else {
       if (use_spmd) {
@@ -2472,7 +2475,8 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
 
 Status DistributedTPURewritePass::BuildExecuteNodes(
     const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
-    const Node& replicate_node, const DataTypeVector& arg_types,
+    const Node& replicate_node, const std::vector<std::string>& arg_names,
+    const DataTypeVector& arg_types,
     const std::vector<InferredShape>& arg_shapes,
     const DataTypeVector& retval_types,
     const std::vector<xla::OpSharding>& arg_shardings,
@@ -2595,7 +2599,9 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
       }
     } else {
       return tensorflow::errors::InvalidArgument(
-          "Unsupported argument sharding: ", sharding.DebugString());
+          "Unsupported argument sharding for arg=", arg_names[i],
+          " shape=", arg_shapes[i].shape.DebugString(), ": ",
+          sharding.DebugString());
     }
   }
   std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
@@ -3922,8 +3928,8 @@ Status DistributedTPURewritePass::FingerprintFunctionLibrary(
 
   std::vector<VariableWrite> variable_writes;
   TF_RETURN_IF_ERROR(BuildExecuteNodes(
-      params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_types,
-      arg_shapes, retval_types, arg_sharding, retval_sharding,
+      params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
+      arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
       tf_device_assignment, compile_node, variable_reads,
       control_after_compilation, control_after, &variable_writes, graph));
   bool contains_resource_write_op =
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
index 1931b4ac80f..a9692cc0edb 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
@@ -413,9 +413,10 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
   // * `num_cores_per_replica` is the number of cores which are dedicated to
   //    each replica.
   // * `replicate_node` is the original TPUReplicate node.
-  // * `arg_types` are the types of the arguments to the computation function
+  // * `arg_names` are the names of the arguments to the computation function
   //    passed as argument to TPUReplicate, including per-replica,
   //    broadcast, and variable arguments.
+  // * `arg_types` are the corresponding types of the arguments.
   // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
   //    applicable).
   // * `arg_shardings` and `retval_shardings` are mappings from
@@ -431,6 +432,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
   static Status BuildExecuteNodes(
       const ParameterInfo& params_info, int num_tasks,
       int num_cores_per_replica, const Node& replicate_node,
+      const std::vector<std::string>& arg_names,
       const DataTypeVector& arg_types,
       const std::vector<InferredShape>& arg_shapes,
       const DataTypeVector& retval_types,
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 6f74123131f..1336f52ed34 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -71,7 +71,6 @@ cc_library(
         "//tensorflow/core/tpu:tpu_api",
         "//tensorflow/core/tpu:tpu_configuration",
         "//tensorflow/core/tpu:tpu_defs",
-        "//tensorflow/stream_executor/tpu:status_helper",
         "//tensorflow/stream_executor/tpu:tpu_platform_interface",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
index 8bd45db2206..ce18e844e66 100644
--- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
@@ -117,7 +117,7 @@ Status SetPerCoreArgShapes(
   } else {
     TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED)
         << "Unsupported argument sharding: "
-        << proto_arg.sharding().DebugString();
+        << "proto_arg=" << proto_arg.DebugString();
     for (int core = 0; core < per_core_arg_shapes->size(); ++core) {
       (*arg_core_mapping)[arg_index].indices.push_back(
           (*per_core_arg_shapes)[core].size());