Refactor ElementalIrEmitter's slice index finding code into
IrArray::Index::SourceIndexOfSlice(). PiperOrigin-RevId: 161140653
This commit is contained in:
parent
ba297aec99
commit
f9c9cacb06
@ -948,23 +948,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
case HloOpcode::kSlice:
|
||||
return [this, hlo, &operand_to_generator](
|
||||
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||
IrArray::Index sliced_index(index.size());
|
||||
for (int i = 0; i < index.size(); ++i) {
|
||||
int64 stride = hlo->slice_stride(i);
|
||||
if (stride != 1) {
|
||||
sliced_index[i] = ir_builder_->CreateAdd(
|
||||
ir_builder_->CreateMul(
|
||||
index[i], llvm::ConstantInt::get(index[i]->getType(),
|
||||
stride)),
|
||||
llvm::ConstantInt::get(index[i]->getType(),
|
||||
hlo->slice_starts(i)));
|
||||
} else {
|
||||
sliced_index[i] = ir_builder_->CreateAdd(
|
||||
index[i],
|
||||
llvm::ConstantInt::get(index[i]->getType(),
|
||||
hlo->slice_starts(i)));
|
||||
}
|
||||
}
|
||||
IrArray::Index sliced_index = index.SourceIndexOfSlice(
|
||||
/*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(),
|
||||
/*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_);
|
||||
return operand_to_generator.at(hlo->operand(0))(sliced_index);
|
||||
};
|
||||
case HloOpcode::kDynamicSlice:
|
||||
|
@ -153,6 +153,28 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
|
||||
return Index(source_multidim_index);
|
||||
}
|
||||
|
||||
IrArray::Index IrArray::Index::SourceIndexOfSlice(
|
||||
const Shape& shape, tensorflow::gtl::ArraySlice<int64> starts,
|
||||
tensorflow::gtl::ArraySlice<int64> strides,
|
||||
llvm::IRBuilder<>* builder) const {
|
||||
Index source_index(multidim_.size());
|
||||
for (int i = 0; i < multidim_.size(); ++i) {
|
||||
int64 stride = strides[i];
|
||||
auto type = multidim_[i]->getType();
|
||||
|
||||
if (stride != 1) {
|
||||
source_index[i] = builder->CreateAdd(
|
||||
builder->CreateMul(multidim_[i],
|
||||
llvm::ConstantInt::get(type, stride)),
|
||||
llvm::ConstantInt::get(type, starts[i]));
|
||||
} else {
|
||||
source_index[i] = builder->CreateAdd(
|
||||
multidim_[i], llvm::ConstantInt::get(type, starts[i]));
|
||||
}
|
||||
}
|
||||
return source_index;
|
||||
}
|
||||
|
||||
IrArray::Index IrArray::Index::SourceIndexOfTranspose(
|
||||
const Shape& shape, const Shape& operand_shape,
|
||||
tensorflow::gtl::ArraySlice<int64> dimension_mapping,
|
||||
|
@ -115,6 +115,16 @@ class IrArray {
|
||||
Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape,
|
||||
llvm::IRBuilder<>* builder) const;
|
||||
|
||||
// Returns the index into the source operand from which a slice operation
|
||||
// selects a value to be placed into index "this". The slice is described
|
||||
// by starting indices `starts` and stride values `strides`.
|
||||
//
|
||||
// Precondition: "this" is an index into a slice whose shape is `shape`.
|
||||
Index SourceIndexOfSlice(const Shape& shape,
|
||||
tensorflow::gtl::ArraySlice<int64> starts,
|
||||
tensorflow::gtl::ArraySlice<int64> strides,
|
||||
llvm::IRBuilder<>* builder) const;
|
||||
|
||||
// Given that "this" is the target index of a transpose from `operand_shape`
|
||||
// to `shape` with the given dimension mapping, returns the source index.
|
||||
Index SourceIndexOfTranspose(
|
||||
|
Loading…
x
Reference in New Issue
Block a user