Add shape information for SourceIndexOfSlice.
Replace the unused "shape" parameter (the output shape of the slice) with an "operand_shape" parameter (the shape of the slice's operand). Use it to construct a source index that also carries shape information. PiperOrigin-RevId: 238191074
This commit is contained in:
parent
b3260dc6cc
commit
4bf7522056
@ -2011,7 +2011,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
|
||||
|
||||
llvm_ir::IrArray source_array = GetIrArrayFor(operand);
|
||||
const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
|
||||
/*shape=*/slice->shape(), /*starts=*/slice->slice_starts(),
|
||||
/*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(),
|
||||
/*strides=*/slice->slice_strides(), /*builder=*/&b_);
|
||||
|
||||
llvm::Value* memcpy_dest =
|
||||
|
@ -2357,7 +2357,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
return [this, hlo, &operand_to_generator](
|
||||
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||
IrArray::Index sliced_index = index.SourceIndexOfSlice(
|
||||
/*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(),
|
||||
/*operand_shape=*/hlo->operand(0)->shape(),
|
||||
/*starts=*/hlo->slice_starts(),
|
||||
/*strides=*/hlo->slice_strides(), /*builder=*/b_);
|
||||
return operand_to_generator.at(hlo->operand(0))(sliced_index);
|
||||
};
|
||||
|
@ -173,24 +173,25 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
|
||||
}
|
||||
|
||||
IrArray::Index IrArray::Index::SourceIndexOfSlice(
|
||||
const Shape& shape, absl::Span<const int64> starts,
|
||||
const Shape& operand_shape, absl::Span<const int64> starts,
|
||||
absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
|
||||
Index source_index(index_type_, multidim_.size());
|
||||
std::vector<llvm::Value*> source_multi_index(multidim_.size());
|
||||
for (int i = 0; i < multidim_.size(); ++i) {
|
||||
int64 stride = strides[i];
|
||||
auto type = multidim_[i]->getType();
|
||||
|
||||
if (stride != 1) {
|
||||
source_index[i] = builder->CreateAdd(
|
||||
source_multi_index[i] = builder->CreateAdd(
|
||||
builder->CreateMul(multidim_[i],
|
||||
llvm::ConstantInt::get(type, stride)),
|
||||
llvm::ConstantInt::get(type, starts[i]));
|
||||
} else {
|
||||
source_index[i] = builder->CreateAdd(
|
||||
source_multi_index[i] = builder->CreateAdd(
|
||||
multidim_[i], llvm::ConstantInt::get(type, starts[i]));
|
||||
}
|
||||
}
|
||||
return source_index;
|
||||
return Index(source_multi_index, /*linear=*/nullptr, operand_shape,
|
||||
index_type_);
|
||||
}
|
||||
|
||||
IrArray::Index IrArray::Index::SourceIndexOfTranspose(
|
||||
|
@ -133,8 +133,10 @@ class IrArray {
|
||||
// selects a value to be placed into index "this". The slice is described
|
||||
// by starting indices `starts` and stride values `strides`.
|
||||
//
|
||||
// Precondition: "this" is an index into a slice whose shape is `shape`.
|
||||
Index SourceIndexOfSlice(const Shape& shape, absl::Span<const int64> starts,
|
||||
// Precondition: "this" is an index into a slice whose operand shape is
|
||||
// `operand_shape`.
|
||||
Index SourceIndexOfSlice(const Shape& operand_shape,
|
||||
absl::Span<const int64> starts,
|
||||
absl::Span<const int64> strides,
|
||||
llvm::IRBuilder<>* builder) const;
|
||||
|
||||
|
@ -194,7 +194,7 @@ XLA_TEST_F(FusionTest, Test) {
|
||||
// (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
|
||||
// {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
|
||||
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
@ -234,7 +234,7 @@ XLA_TEST_F(FusionTest, Test) {
|
||||
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||
LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
|
||||
ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
|
||||
ExecuteNoHloPasses(std::move(hlo_module), {}), ErrorSpec(1e-4)));
|
||||
}
|
||||
|
||||
// Test whether we emit appropriate code for parameters of fusion instructions.
|
||||
|
Loading…
x
Reference in New Issue
Block a user