commit b49b04b9cc
Merge pull request #42683 from nouiz:upstream_master_grid_size

PiperOrigin-RevId: 335609206
Change-Id: Iad371a188dd774bf1293eb126921189a54f5ffba
@@ -31,9 +31,15 @@ ParallelLoopEmitter::ParallelLoopEmitter(
 
 std::vector<llvm_ir::IrArray::Index>
 ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
-                                                   llvm::Type* index_type) {
+                                                   llvm::Type* index_type,
+                                                   llvm::Value* base_index) {
   CHECK_NE(index_type, nullptr);
 
+  CHECK_EQ(base_index, nullptr)
+      << "XLA CPU implementation of"
+      << " ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
+      << " base_index, but it was requested.";
+
   CHECK(!shape_.IsTuple());
   CHECK(!ShapeUtil::IsScalar(shape_));
 
@@ -61,7 +61,8 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
   ~ParallelLoopEmitter() override = default;
 
   std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type) override;
+      absl::string_view loop_name, llvm::Type* index_type,
+      llvm::Value* base_index) override;
 
  private:
   const DynamicLoopBounds* dynamic_loop_bounds_;
@@ -323,6 +323,7 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+        "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
         "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
@@ -610,6 +610,11 @@ static StatusOr<bool> DeviceCompare(se::Stream* stream,
       executor->GetDeviceDescription().threads_per_block_limit();
   gpu_device_info.threads_per_warp =
       executor->GetDeviceDescription().threads_per_warp();
+  gpu_device_info.shared_memory_per_block =
+      executor->GetDeviceDescription().shared_memory_per_block();
+  gpu_device_info.threads_per_core_limit =
+      executor->GetDeviceDescription().threads_per_core_limit();
+  gpu_device_info.core_count = executor->GetDeviceDescription().core_count();
   LaunchDimensions dim =
       CalculateLaunchDimensions(buffer_shape, gpu_device_info);
 
@@ -617,6 +617,9 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
       stream_exec->GetDeviceDescription().threads_per_warp();
   gpu_device_info.shared_memory_per_block =
       stream_exec->GetDeviceDescription().shared_memory_per_block();
+  gpu_device_info.threads_per_core_limit =
+      stream_exec->GetDeviceDescription().threads_per_core_limit();
+  gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count();
 
   absl::optional<CudaComputeCapability> cuda_compute_capability =
       [&]() -> absl::optional<CudaComputeCapability> {
@@ -32,6 +32,8 @@ struct GpuDeviceInfo {
   int threads_per_block_limit;
   int threads_per_warp;
   int shared_memory_per_block;
+  int threads_per_core_limit;
+  int core_count;
 };
 }  // namespace gpu
 }  // namespace xla
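The two new fields capture the device's occupancy envelope: threads_per_core_limit is the maximum number of resident threads per multiprocessor (SM), and core_count is the number of SMs, so their product is the largest number of threads the device can keep in flight at once, i.e. one full wave. A minimal sketch with hypothetical values (not taken from the diff):

    #include <cstdint>
    #include <iostream>

    int main() {
      // Hypothetical device limits; the real values come from
      // GetDeviceDescription() as in the hunks above.
      int threads_per_core_limit = 2048;  // max resident threads per SM
      int core_count = 80;                // number of SMs
      int64_t one_wave =
          static_cast<int64_t>(core_count) * threads_per_core_limit;
      std::cout << one_wave << "\n";  // 163840 threads in one wave
    }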
@@ -2303,7 +2303,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(
 Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
     const HloInstruction& hlo,
     const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk,
-    int unroll_factor) {
+    int unroll_factor, bool few_waves) {
   VLOG(3) << bindings_.ToString();
 
   bool multi_output = hlo.shape().IsTuple();
@@ -2314,7 +2314,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
           << ShapeUtil::HumanStringWithLayout(hlo.shape())
           << " for unroll_factor " << unroll_factor;
   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
+      element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor,
+      few_waves);
   UpdateLaunchDimensions(launch_dimensions, thunk,
                          ir_emitter_context_->llvm_module());
   if (!multi_output) {
@@ -2400,8 +2401,27 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
 
   std::unique_ptr<KernelThunk> kernel_thunk =
       BuildKernelThunk(&hlo, /*implements_whole_instruction=*/true);
 
+  // Check if we want to schedule a grid size that has fewer SM waves.
+  // This speeds up computations in some cases.
+  bool few_waves = false;
+  auto few_waves_allow_instr = [](const HloInstruction* instr) {
+    return instr->IsElementwise() || instr->opcode() == HloOpcode::kParameter ||
+           // We need to make the codegen broadcast-aware before enabling
+           // more broadcast patterns.
+           (instr->opcode() == HloOpcode::kBroadcast &&
+            instr->dimensions().empty());
+  };
+  if (hlo.opcode() == HloOpcode::kFusion) {
+    few_waves =
+        absl::c_all_of(hlo.fused_instructions_computation()->instructions(),
+                       few_waves_allow_instr);
+  } else {
+    few_waves = few_waves_allow_instr(&hlo);
+  }
+
   Status emit_status = EmitTargetElementLoopInThunk(
-      hlo, body_emitter, kernel_thunk.get(), unroll_factor);
+      hlo, body_emitter, kernel_thunk.get(), unroll_factor, few_waves);
   thunk_sequence_.emplace_back(std::move(kernel_thunk));
 
   return emit_status;
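In other words, the few-waves grid is only requested when every instruction in the kernel does a uniform, cheap amount of work per element: elementwise ops, parameters, and broadcasts of scalars (dimensions().empty() holds only when the broadcast operand is a scalar). For example, a fusion of multiply, add, and a scalar broadcast qualifies, while a fusion containing a non-scalar broadcast, or any other non-elementwise op, keeps the default grid.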
@@ -178,7 +178,7 @@ class IrEmitterUnnested : public IrEmitter,
   // `unroll_factor` is greater than one.
   Status EmitTargetElementLoopInThunk(
       const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
-      KernelThunk* thunk, int unroll_factor);
+      KernelThunk* thunk, int unroll_factor, bool few_waves = false);
 
   Status Postprocess(HloInstruction* hlo) override;
 
@@ -56,7 +56,7 @@ static int64 ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) {
 // Calculates the launch dimensions used to invoke `hlo`.
 LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
                                            GpuDeviceInfo gpu_device_info,
-                                           int unroll_factor) {
+                                           int unroll_factor, bool few_waves) {
   int64 num_elements = ShapeUtil::ElementsIn(shape);
   if (num_elements <= 1) {
     return LaunchDimensions();
@@ -90,6 +90,11 @@ LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
   }
 
   int64 block_count = CeilOfRatio(num_elements, threads_per_block);
+  if (few_waves) {
+    threads_per_block = std::min(threads_per_block, int64{128});
+    block_count = gpu_device_info.core_count *
+                  (gpu_device_info.threads_per_core_limit / threads_per_block);
+  }
   VLOG(2) << absl::StrFormat(
       "Initialized the block count to ceil(# of elements / threads per "
       "block) = ceil(%d/%d) = %d",
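To make the effect concrete, here is the same arithmetic on a hypothetical device with 80 SMs and 2048 resident threads per SM (illustrative values, not from the diff):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t num_elements = 100'000'000;
      int64_t threads_per_block = 256;        // value before the cap
      int64_t core_count = 80;                // hypothetical SM count
      int64_t threads_per_core_limit = 2048;  // hypothetical resident limit
      // Default sizing: one thread per element (CeilOfRatio equivalent).
      int64_t default_blocks =
          (num_elements + threads_per_block - 1) / threads_per_block;
      // few_waves sizing: cap the block at 128 threads and launch just
      // enough blocks to fill every SM to its resident-thread limit.
      threads_per_block = std::min(threads_per_block, int64_t{128});
      int64_t few_waves_blocks =
          core_count * (threads_per_core_limit / threads_per_block);
      std::cout << default_blocks << " vs " << few_waves_blocks << "\n";
      // Prints "390625 vs 1280": the grid shrinks from hundreds of
      // successive waves to a single wave.
    }

The elements no longer covered one-per-thread by the smaller grid are picked up by the in-kernel loop added to ParallelLoopEmitter::EmitLoop further down.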
@@ -67,7 +67,8 @@ std::ostream& operator<<(std::ostream& out,
 
 LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
                                            GpuDeviceInfo gpu_device_info,
-                                           int unroll_factor = 1);
+                                           int unroll_factor = 1,
+                                           bool few_waves = false);
 
 }  // namespace gpu
 }  // namespace xla
@@ -23,6 +23,7 @@ limitations under the License.
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Value.h"
 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -58,7 +59,8 @@ ParallelLoopEmitter::ParallelLoopEmitter(
 
 std::vector<llvm_ir::IrArray::Index>
 ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
-                                                   llvm::Type* index_type) {
+                                                   llvm::Type* index_type,
+                                                   llvm::Value* base_index) {
   // Emit the following code in LLVM IR:
   //   linear_index = blockIdx.x * blockDim.x + threadIdx.x;
   //   if (linear_index < num_elements) {
@@ -122,6 +124,12 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
         "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true);
   }
 
+  if (base_index != nullptr) {
+    linear_index_base =
+        b_->CreateAdd(linear_index_base, base_index, "linear_index_plus_base",
+                      /*HasNUW=*/true, /*HasNSW=*/true);
+  }
+
   array_indices.emplace_back(linear_index_base, shape_, b_);
   for (int i = 1; i < unroll_factor_; ++i) {
     llvm::Value* linear_index =
@@ -147,5 +155,43 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
   return array_indices;
 }
 
+Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name,
+                                     llvm::Type* index_type) {
+  if (index_type == nullptr) {
+    index_type = b_->getInt64Ty();
+  }
+  int64 total_threads = launch_dimensions_.launch_bound();
+  int64 num_elements = ShapeUtil::ElementsIn(shape_);
+  // If all the elements are handled by the current threads, no need
+  // to add a loop inside the kernel.
+  if (total_threads * unroll_factor_ >= num_elements) {
+    VLOG(1) << "ParallelLoopEmitter::EmitLoop fallback";
+    return LoopEmitter::EmitLoop(loop_name, index_type);
+  }
+
+  KernelSupportLibrary ksl(b_, llvm_ir::UnrollMode::kDefaultUnroll);
+  auto constant = [&](int64 val) {
+    return llvm::ConstantInt::get(index_type, val);
+  };
+
+  TF_RETURN_IF_ERROR(ksl.ForWithStatus(
+      "loop", constant(0), constant(num_elements),
+      constant(total_threads * unroll_factor_), [&](llvm::Value* base_indvar) {
+        for (const llvm_ir::IrArray::Index& array_index :
+             EmitIndexAndSetExitBasicBlock(loop_name, index_type,
+                                           base_indvar)) {
+          TF_RETURN_IF_ERROR(body_emitter_(array_index));
+        }
+        return Status::OK();
+      }));
+
+  // Set the insertion point of b_ to the loop exit, so that
+  // code emitted for later instructions will be correctly placed.
+  if (exit_bb_ != nullptr) {
+    b_->SetInsertPoint(exit_bb_);
+  }
+  return Status::OK();
+}
+
 }  // namespace gpu
 }  // namespace xla
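The emitted structure is a grid-stride loop: the outer loop steps a base index by the whole launch bound (times the unroll factor), and EmitIndexAndSetExitBasicBlock adds each thread's linear index on top of that base. A minimal C++ model of the control flow, with illustrative names (the real emitter produces LLVM IR, not C++):

    #include <cstdint>

    // Stands in for body_emitter_: work done for one element.
    void EmitBody(int64_t linear_index) { /* ... */ }

    void GridStrideLoopModel(int64_t thread_id, int64_t total_threads,
                             int64_t unroll_factor, int64_t num_elements) {
      // Mirrors ksl.ForWithStatus("loop", 0, num_elements,
      //                           total_threads * unroll_factor, ...).
      for (int64_t base = 0; base < num_elements;
           base += total_threads * unroll_factor) {
        // linear_index_plus_base = per-thread linear index + base_index.
        int64_t linear_index = base + thread_id * unroll_factor;
        for (int64_t i = 0; i < unroll_factor; ++i) {
          if (linear_index + i < num_elements) {  // bounds check, as emitted
            EmitBody(linear_index + i);
          }
        }
      }
    }

When the launch already covers every element (total_threads * unroll_factor >= num_elements), the code above degenerates to a single outer iteration, which is why EmitLoop falls back to the plain LoopEmitter::EmitLoop in that case.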
@@ -57,7 +57,11 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
   ~ParallelLoopEmitter() override = default;
 
   std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type) override;
+      absl::string_view loop_name, llvm::Type* index_type,
+      llvm::Value* base_index) override;
+
+  Status EmitLoop(absl::string_view loop_name = "",
+                  llvm::Type* index_type = nullptr);
 
  private:
   // The thread and block dimension to parallelize the loop on.
@@ -148,8 +148,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedSine) {
     HloModule test_module
 
     ENTRY SineFunc {
-      p0 = f32[160000]{0} parameter(0)
-      ROOT s = f32[160000]{0} sine(p0)
+      p0 = f32[1600000]{0} parameter(0)
+      ROOT s = f32[1600000]{0} sine(p0)
     })";
   auto hlo_module =
       ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie();
@@ -182,8 +182,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedCosine) {
    HloModule test_module
 
    ENTRY SineFunc {
-      p0 = f32[160000]{0} parameter(0)
-      ROOT s = f32[160000]{0} cosine(p0)
+      p0 = f32[1600000]{0} parameter(0)
+      ROOT s = f32[1600000]{0} cosine(p0)
     })";
   auto hlo_module =
       ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie();
@@ -216,8 +216,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedPower) {
    HloModule test_module
 
    ENTRY SineFunc {
-      p0 = f32[160000]{0} parameter(0)
-      ROOT s = f32[160000]{0} power(p0, p0)
+      p0 = f32[1600000]{0} parameter(0)
+      ROOT s = f32[1600000]{0} power(p0, p0)
     })";
   auto hlo_module =
       ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie();
@@ -241,8 +241,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedAtan2) {
    HloModule test_module
 
    ENTRY SineFunc {
-      p0 = f32[160000]{0} parameter(0)
-      ROOT s = f32[160000]{0} atan2(p0, p0)
+      p0 = f32[1600000]{0} parameter(0)
+      ROOT s = f32[1600000]{0} atan2(p0, p0)
     })";
   auto hlo_module =
       ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie();
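The test inputs grow from 160,000 to 1,600,000 elements, plausibly so the shape no longer fits in a single wave: on a device like the hypothetical one above (80 SMs x 2048 threads = 163,840 resident threads), 160,000 elements would be covered without any in-kernel looping, while 1,600,000 elements force the new ParallelLoopEmitter::EmitLoop path to iterate.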
@@ -130,8 +130,14 @@ IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest,
 }
 
 std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
-    absl::string_view loop_name, llvm::Type* index_type) {
+    absl::string_view loop_name, llvm::Type* index_type,
+    llvm::Value* base_index) {
   CHECK_NE(index_type, nullptr);
+  CHECK_EQ(base_index, nullptr)
+      << "XLA CPU implementation of"
+      << " LoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
+      << " base_index, but it was requested.";
+
   if (ShapeUtil::IsScalar(shape_)) {
     // No loop needed, so set exit_bb_ to nullptr.
     exit_bb_ = nullptr;
@@ -164,7 +170,8 @@ Status LoopEmitter::EmitLoop(absl::string_view loop_name,
   }
 
   for (const IrArray::Index& array_index :
-       EmitIndexAndSetExitBasicBlock(loop_name, index_type)) {
+       EmitIndexAndSetExitBasicBlock(loop_name, index_type,
+                                     /*base_index*/ nullptr)) {
     TF_RETURN_IF_ERROR(body_emitter_(array_index));
   }
 
@@ -71,11 +71,13 @@ class LoopEmitter {
   // specifies the element, will return multiple indices if the loop is
   // unrolled.
   std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock() {
-    return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", b_->getInt64Ty());
+    return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", b_->getInt64Ty(),
+                                         /*base_index*/ nullptr);
   }
 
   virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type);
+      absl::string_view loop_name, llvm::Type* index_type,
+      llvm::Value* base_index);
 
   // Emits a complete loop nest for every element in the given shape.
   Status EmitLoop(absl::string_view loop_name = "",