Add shape information for SourceIndexOfSlice.

Replace the unused "shape" parameter of the output shape with an
"operand_shape" parameter of the operand shape. Use this to create a source
index that also has shape information.

PiperOrigin-RevId: 238191074
This commit is contained in:
Adrian Kuegel 2019-03-13 02:43:15 -07:00 committed by TensorFlower Gardener
parent b3260dc6cc
commit 4bf7522056
5 changed files with 15 additions and 11 deletions

View File

@ -2011,7 +2011,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
llvm_ir::IrArray source_array = GetIrArrayFor(operand);
const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
/*shape=*/slice->shape(), /*starts=*/slice->slice_starts(),
/*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(),
/*strides=*/slice->slice_strides(), /*builder=*/&b_);
llvm::Value* memcpy_dest =

View File

@ -2357,7 +2357,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
IrArray::Index sliced_index = index.SourceIndexOfSlice(
/*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(),
/*operand_shape=*/hlo->operand(0)->shape(),
/*starts=*/hlo->slice_starts(),
/*strides=*/hlo->slice_strides(), /*builder=*/b_);
return operand_to_generator.at(hlo->operand(0))(sliced_index);
};

View File

@ -173,24 +173,25 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
}
IrArray::Index IrArray::Index::SourceIndexOfSlice(
const Shape& shape, absl::Span<const int64> starts,
const Shape& operand_shape, absl::Span<const int64> starts,
absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
Index source_index(index_type_, multidim_.size());
std::vector<llvm::Value*> source_multi_index(multidim_.size());
for (int i = 0; i < multidim_.size(); ++i) {
int64 stride = strides[i];
auto type = multidim_[i]->getType();
if (stride != 1) {
source_index[i] = builder->CreateAdd(
source_multi_index[i] = builder->CreateAdd(
builder->CreateMul(multidim_[i],
llvm::ConstantInt::get(type, stride)),
llvm::ConstantInt::get(type, starts[i]));
} else {
source_index[i] = builder->CreateAdd(
source_multi_index[i] = builder->CreateAdd(
multidim_[i], llvm::ConstantInt::get(type, starts[i]));
}
}
return source_index;
return Index(source_multi_index, /*linear=*/nullptr, operand_shape,
index_type_);
}
IrArray::Index IrArray::Index::SourceIndexOfTranspose(

View File

@ -133,8 +133,10 @@ class IrArray {
// selects a value to be placed into index "this". The slice is described
// by starting indices `starts` and stride values `strides`.
//
// Precondition: "this" is an index into a slice whose shape is `shape`.
Index SourceIndexOfSlice(const Shape& shape, absl::Span<const int64> starts,
// Precondition: "this" is an index into a slice whose operand shape is
// `operand_shape`.
Index SourceIndexOfSlice(const Shape& operand_shape,
absl::Span<const int64> starts,
absl::Span<const int64> strides,
llvm::IRBuilder<>* builder) const;

View File

@ -194,7 +194,7 @@ XLA_TEST_F(FusionTest, Test) {
// (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
// {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
@ -234,7 +234,7 @@ XLA_TEST_F(FusionTest, Test) {
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
ExecuteNoHloPasses(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
// Test whether we emit appropriate code for parameters of fusion instructions.