diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index fcf8b4b4e9d..011eb07d3bd 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -336,7 +336,6 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
   // element_type is the data type for the binary operation.
   llvm::Type* element_type = output_address_type->getPointerElementType();
   int element_size = llvm_ir::GetSizeInBits(element_type);
-  llvm::Type* element_address_type = element_type->getPointerTo();
 
   int atomic_size = (element_size < 32) ? 32 : element_size;
   llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
@@ -346,10 +345,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
   // cas_old_output_address and cas_new_output_address point to the scratch
   // memory where we store the old and new values for the repeated atomicCAS
   // operations.
-  llvm::Value* cas_old_output_address =
-      Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
-  llvm::Value* cas_new_output_address =
-      Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");
+  llvm::Value* cas_old_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
+      atomic_type, "cas_old_output_address", &b_);
+  llvm::Value* cas_new_output_address = llvm_ir::EmitAllocaAtFunctionEntry(
+      atomic_type, "cas_new_output_address", &b_);
 
   // Emit preparation code to the preheader.
   llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();
@@ -372,11 +371,19 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
         IntToPtr(atomic_memory_address, atomic_address_type);
     binop_output_address =
         Add(PtrToInt(cas_new_output_address, address_int_type), offset);
-    binop_output_address = IntToPtr(binop_output_address, element_address_type);
+    binop_output_address = IntToPtr(
+        binop_output_address,
+        llvm::PointerType::get(
+            element_type,
+            cas_new_output_address->getType()->getPointerAddressSpace()));
   } else {
-    atomic_memory_address = BitCast(output_address, atomic_address_type);
-    binop_output_address =
-        BitCast(cas_new_output_address, element_address_type);
+    atomic_memory_address = b_.CreatePointerBitCastOrAddrSpaceCast(
+        output_address, atomic_address_type);
+    binop_output_address = b_.CreatePointerBitCastOrAddrSpaceCast(
+        cas_new_output_address,
+        llvm::PointerType::get(
+            element_type,
+            cas_new_output_address->getType()->getPointerAddressSpace()));
   }
 
   // Use the value from the memory that atomicCAS operates on to initialize