diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 24626099106..9a34cd8c6ae 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -139,6 +139,16 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
   return Status::OK();
 }
 
+template <typename T>
+static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
+  Tensor linspace(DataTypeToEnum<T>::v(), shape);
+  auto linspace_flat = linspace.flat<T>();
+  for (int64 i = 0; i < depth; ++i) {
+    linspace_flat(i) = i;
+  }
+  return linspace;
+}
+
 xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
                               int axis) {
   return ArgMinMax(input, output_type, axis, /*is_min=*/false);
@@ -162,17 +172,33 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
   // Build a Tensor populated with values 0, 1, 2, ... depth.
   std::vector<int64> linspace_dims(output_dims, 1);
   linspace_dims[axis] = depth;
-  xla::PrimitiveType linspace_type;
-  TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &linspace_type));
-  auto linspace_shape = xla::ShapeUtil::MakeShape(linspace_type, linspace_dims);
-  xla::XlaOp linspace = xla::Iota(builder, linspace_shape, axis);
+  TensorShape linspace_shape(linspace_dims);
+  Tensor linspace;
+  switch (index_type) {
+    case DT_UINT8:
+      linspace = MakeLinspaceTensor<uint8>(linspace_shape, depth);
+      break;
+    case DT_INT32:
+      linspace = MakeLinspaceTensor<int32>(linspace_shape, depth);
+      break;
+    case DT_INT64:
+      linspace = MakeLinspaceTensor<int64>(linspace_shape, depth);
+      break;
+    default:
+      return errors::InvalidArgument("Invalid argument type ",
+                                     DataTypeString(index_type));
+  }
+
+  xla::BorrowingLiteral linspace_literal;
+  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
 
   // Broadcast the linspace constant across the indices along the new axis,
   // and test equality at each position.
   std::vector<int64> broadcast_dims(indices_shape.dims());
   std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
   std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
-  xla::XlaOp one_hot_bool = xla::Eq(indices, linspace, broadcast_dims);
+  xla::XlaOp one_hot_bool = xla::Eq(
+      indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims);
 
   // Selects the user-provided off_value and on_value values.
   *one_hot = xla::Select(one_hot_bool,