diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index fcf8b4b4e9d..011eb07d3bd 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -336,7 +336,6 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // element_type is the data type for the binary operation. llvm::Type* element_type = output_address_type->getPointerElementType(); int element_size = llvm_ir::GetSizeInBits(element_type); - llvm::Type* element_address_type = element_type->getPointerTo(); int atomic_size = (element_size < 32) ? 32 : element_size; llvm::Type* atomic_type = b_.getIntNTy(atomic_size); @@ -346,10 +345,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // cas_old_output_address and cas_new_output_address point to the scratch // memory where we store the old and new values for the repeated atomicCAS // operations. - llvm::Value* cas_old_output_address = - Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); - llvm::Value* cas_new_output_address = - Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); + llvm::Value* cas_old_output_address = llvm_ir::EmitAllocaAtFunctionEntry( + atomic_type, "cas_old_output_address", &b_); + llvm::Value* cas_new_output_address = llvm_ir::EmitAllocaAtFunctionEntry( + atomic_type, "cas_new_output_address", &b_); // Emit preparation code to the preheader. llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock(); @@ -372,11 +371,19 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, IntToPtr(atomic_memory_address, atomic_address_type); binop_output_address = Add(PtrToInt(cas_new_output_address, address_int_type), offset); - binop_output_address = IntToPtr(binop_output_address, element_address_type); + binop_output_address = IntToPtr( + binop_output_address, + llvm::PointerType::get( + element_type, + cas_new_output_address->getType()->getPointerAddressSpace())); } else { - atomic_memory_address = BitCast(output_address, atomic_address_type); - binop_output_address = - BitCast(cas_new_output_address, element_address_type); + atomic_memory_address = b_.CreatePointerBitCastOrAddrSpaceCast( + output_address, atomic_address_type); + binop_output_address = b_.CreatePointerBitCastOrAddrSpaceCast( + cas_new_output_address, + llvm::PointerType::get( + element_type, + cas_new_output_address->getType()->getPointerAddressSpace())); } // Use the value from the memory that atomicCAS operates on to initialize