diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index f7821adc74c..d444c1d49dd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -213,6 +213,12 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( const Shape& shape, const Shape& operand_shape, llvm::IRBuilder<>* builder) const { CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape)); + // In case the bitcast is just a reshape, we can use SourceIndexOfReshape() + // instead. This will reuse linear() if possible, so we don't have to build a + // new 'linear_index'. + if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) { + return SourceIndexOfReshape(shape, operand_shape, builder); + } // First linearize the index coming from the output of the bitcast. We want // the physical index of the element in the buffer. This is like Linearize,