diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index e94aef641e3..6449a399573 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -100,9 +100,9 @@ typedef Eigen::SyclDevice SYCLDevice;
 #endif  // TENSORFLOW_USE_SYCL
 
 // Concatenates 'inputs' into a single tensor along the zeroth dimension.
-// Requires that all elements of 'inputs' have element type T. Writes to the
-// op's output at position 'output_index', using 'context' for the allocation to
-// ensure proper device placement.
+// Requires that all elements of 'inputs' have element type T. Writes to
+// 'output' using 'context' for the allocation to ensure proper device
+// placement.
 template <typename T>
 Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
               Tensor* output) {
@@ -157,6 +157,25 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
   return Status::OK();
 }
 
+// Same as 'Concat' above, but handles Tensor dtype deduction automatically.
+Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
+              Tensor* output) {
+  const DataType type = inputs[0].dtype();
+  Status concat_status;
+  switch (type) {
+#define CASE(type)                                         \
+  case DataTypeToEnum<type>::value:                        \
+    concat_status = Concat<type>(context, inputs, output); \
+    break;
+    TF_CALL_ALL_TYPES(CASE);
+#undef CASE
+    default:
+      concat_status = errors::InvalidArgument("Unsupported data type: ", type);
+      break;
+  }
+  return concat_status;
+}
+
 // The Split*() functions split 'input' with element type T into 'sizes.size()'
 // tensors along the zeroth dimension, with the ith split having zeroth-
 // dimension size 'sizes[i]'. They allocate the output tensors using 'context',
@@ -268,6 +287,25 @@ Status Split(OpKernelContext* context, const Tensor& input,
   return SplitCPU<T>(context, input, sizes, outputs);
 }
 
+// Same as 'Split' above, but handles Tensor dtype automatically.
+Status Split(OpKernelContext* context, const Tensor& input,
+             const gtl::ArraySlice<int64> sizes, std::vector<Tensor>* outputs) {
+  const DataType type = input.dtype();
+  Status split_status;
+  switch (type) {
+#define CASE(type)                                              \
+  case DataTypeToEnum<type>::value:                             \
+    split_status = Split<type>(context, input, sizes, outputs); \
+    break;
+    TF_CALL_ALL_TYPES(CASE);
+#undef CASE
+    default:
+      split_status = errors::InvalidArgument("Unsupported data type: ", type);
+      break;
+  }
+  return split_status;
+}
+
 // A class encapsulating the state and logic for batching tensors.
 class BatchResource : public ResourceBase {
  public:
@@ -449,22 +487,9 @@ class BatchResource : public ResourceBase {
         }
       }
 
-      const DataType type = to_concatenate[0].dtype();
-      Status concat_status;
       Tensor concatenated_tensor;
-      switch (type) {
-#define CASE(type)                                                   \
-  case DataTypeToEnum<type>::value:                                  \
-    concat_status =                                                  \
-        Concat<type>(context, to_concatenate, &concatenated_tensor); \
-    break;
-        TF_CALL_ALL_TYPES(CASE);
-#undef CASE
-        default:
-          concat_status =
-              errors::InvalidArgument("Unsupported data type: ", type);
-          break;
-      }
+      Status concat_status =
+          Concat(context, to_concatenate, &concatenated_tensor);
       TF_RETURN_IF_ERROR(concat_status);
       concatenated_tensors->push_back(concatenated_tensor);
     }
@@ -1001,17 +1026,7 @@ class UnbatchResource : public ResourceBase {
         batch_keys.push_back(batch_indices(i, 0));
       }
 
-      const DataType type = data_t.dtype();
-      switch (type) {
-#define CASE(type)                                                          \
-  case DataTypeToEnum<type>::value:                                         \
-    TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
-    break;
-        TF_CALL_ALL_TYPES(CASE);
-#undef CASE
-        default:
-          return errors::InvalidArgument("Unsupported data type: ", type);
-      }
+      TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
     }
 
     // Critical section.