Enable test CompatibleUseLinearIndexWithReshape.

This requires adding a special case to SourceIndexOfBitcast if the bitcast is a
reshape.

PiperOrigin-RevId: 188324197
This commit is contained in:
A. Unique TensorFlower 2018-03-08 06:39:52 -08:00 committed by TensorFlower Gardener
parent 6d44c84bb2
commit ae03359f61

View File

@ -213,6 +213,12 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast(
const Shape& shape, const Shape& operand_shape,
llvm::IRBuilder<>* builder) const {
CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));
// In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
// instead. This will reuse linear() if possible, so we don't have to build a
// new 'linear_index'.
if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
return SourceIndexOfReshape(shape, operand_shape, builder);
}
// First linearize the index coming from the output of the bitcast. We want
// the physical index of the element in the buffer. This is like Linearize,