From 4bf7522056039f9850695690240de313c7af1829 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 13 Mar 2019 02:43:15 -0700 Subject: [PATCH] Add shape information for SourceIndexOfSlice. Replace the unused "shape" parameter of the output shape with an "operand_shape" parameter of the operand shape. Use this to create a source index that also has shape information. PiperOrigin-RevId: 238191074 --- tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 2 +- .../compiler/xla/service/elemental_ir_emitter.cc | 3 ++- tensorflow/compiler/xla/service/llvm_ir/ir_array.cc | 11 ++++++----- tensorflow/compiler/xla/service/llvm_ir/ir_array.h | 6 ++++-- tensorflow/compiler/xla/tests/fusion_test.cc | 4 ++-- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 52759ebd4ce..63ca3ea935f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -2011,7 +2011,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { llvm_ir::IrArray source_array = GetIrArrayFor(operand); const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( - /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(), + /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(), /*strides=*/slice->slice_strides(), /*builder=*/&b_); llvm::Value* memcpy_dest = diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index e0f3f9ff5b2..b547d66359c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -2357,7 +2357,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr<llvm::Value*> { IrArray::Index sliced_index = index.SourceIndexOfSlice( - /*shape=*/hlo->shape(), 
/*starts=*/hlo->slice_starts(), + /*operand_shape=*/hlo->operand(0)->shape(), + /*starts=*/hlo->slice_starts(), /*strides=*/hlo->slice_strides(), /*builder=*/b_); return operand_to_generator.at(hlo->operand(0))(sliced_index); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 8122651fb4f..744df21feb7 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -173,24 +173,25 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( } IrArray::Index IrArray::Index::SourceIndexOfSlice( - const Shape& shape, absl::Span<const int64> starts, + const Shape& operand_shape, absl::Span<const int64> starts, absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const { - Index source_index(index_type_, multidim_.size()); + std::vector<llvm::Value*> source_multi_index(multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; auto type = multidim_[i]->getType(); if (stride != 1) { - source_index[i] = builder->CreateAdd( + source_multi_index[i] = builder->CreateAdd( builder->CreateMul(multidim_[i], llvm::ConstantInt::get(type, stride)), llvm::ConstantInt::get(type, starts[i])); } else { - source_index[i] = builder->CreateAdd( + source_multi_index[i] = builder->CreateAdd( multidim_[i], llvm::ConstantInt::get(type, starts[i])); } } - return source_index; + return Index(source_multi_index, /*linear=*/nullptr, operand_shape, + index_type_); } IrArray::Index IrArray::Index::SourceIndexOfTranspose( diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 6bc65fc7d1c..49f9b8e1591 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -133,8 +133,10 @@ class IrArray { // selects a value to be placed into index "this". The slice is described // by starting indices `starts` and stride values `strides`. 
// // Precondition: "this" is an index into a slice whose shape is `shape`. - Index SourceIndexOfSlice(const Shape& shape, absl::Span<const int64> starts, + // Precondition: "this" is an index into a slice whose operand shape is + // `operand_shape`. + Index SourceIndexOfSlice(const Shape& operand_shape, + absl::Span<const int64> starts, absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const; diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 189736effb1..78908d0a449 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -194,7 +194,7 @@ XLA_TEST_F(FusionTest, Test) { // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -234,7 +234,7 @@ XLA_TEST_F(FusionTest, Test) { EXPECT_TRUE(LiteralTestUtil::Near( LiteralUtil::CreateR2<float>({{0.5}, {2.72}}), - ExecuteNoHloPasses(std::move(hlo_module), {}), ErrorSpec(1e-4))); + ExecuteNoHloPasses(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions.