Refactor ElementalIrEmitter's slice index finding code into

IrArray::Index::SourceIndexOfSlice().

PiperOrigin-RevId: 161140653
This commit is contained in:
A. Unique TensorFlower 2017-07-06 15:36:56 -07:00 committed by TensorFlower Gardener
parent ba297aec99
commit f9c9cacb06
3 changed files with 35 additions and 17 deletions

View File

@ -948,23 +948,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kSlice:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
IrArray::Index sliced_index(index.size());
for (int i = 0; i < index.size(); ++i) {
int64 stride = hlo->slice_stride(i);
if (stride != 1) {
sliced_index[i] = ir_builder_->CreateAdd(
ir_builder_->CreateMul(
index[i], llvm::ConstantInt::get(index[i]->getType(),
stride)),
llvm::ConstantInt::get(index[i]->getType(),
hlo->slice_starts(i)));
} else {
sliced_index[i] = ir_builder_->CreateAdd(
index[i],
llvm::ConstantInt::get(index[i]->getType(),
hlo->slice_starts(i)));
}
}
IrArray::Index sliced_index = index.SourceIndexOfSlice(
/*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(),
/*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_);
return operand_to_generator.at(hlo->operand(0))(sliced_index);
};
case HloOpcode::kDynamicSlice:

View File

@ -153,6 +153,28 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
return Index(source_multidim_index);
}
IrArray::Index IrArray::Index::SourceIndexOfSlice(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> starts,
tensorflow::gtl::ArraySlice<int64> strides,
llvm::IRBuilder<>* builder) const {
Index source_index(multidim_.size());
for (int i = 0; i < multidim_.size(); ++i) {
int64 stride = strides[i];
auto type = multidim_[i]->getType();
if (stride != 1) {
source_index[i] = builder->CreateAdd(
builder->CreateMul(multidim_[i],
llvm::ConstantInt::get(type, stride)),
llvm::ConstantInt::get(type, starts[i]));
} else {
source_index[i] = builder->CreateAdd(
multidim_[i], llvm::ConstantInt::get(type, starts[i]));
}
}
return source_index;
}
IrArray::Index IrArray::Index::SourceIndexOfTranspose(
const Shape& shape, const Shape& operand_shape,
tensorflow::gtl::ArraySlice<int64> dimension_mapping,

View File

@ -115,6 +115,16 @@ class IrArray {
Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape,
llvm::IRBuilder<>* builder) const;
// Returns the index into the source operand from which a slice operation
// selects a value to be placed into index "this". The slice is described
// by starting indices `starts` and stride values `strides`.
//
// Precondition: "this" is an index into a slice whose shape is `shape`.
Index SourceIndexOfSlice(const Shape& shape,
tensorflow::gtl::ArraySlice<int64> starts,
tensorflow::gtl::ArraySlice<int64> strides,
llvm::IRBuilder<>* builder) const;
// Given that "this" is the target index of a transpose from `operand_shape`
// to `shape` with the given dimension mapping, returns the source index.
Index SourceIndexOfTranspose(