[XLA/GPU] Migrate fused DynamicUpdateSlice.
PiperOrigin-RevId: 354163573
Change-Id: Ie464cdb8bbc9a7dbef739652284f09276ed32af5
parent f40e2d37c5
commit f03ee0a61d
@@ -830,6 +830,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
         "//tensorflow/core:lib",
@@ -23,7 +23,6 @@ limitations under the License.
 #include "llvm/IR/Module.h"
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/xla/layout_util.h"
@@ -719,5 +718,100 @@ bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) {
   });
 }
 
+static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) {
+  // For i1 memrefs, the underlying allocation is 8 bits per element.
+  if (type.getElementType().isInteger(/*width=*/1)) {
+    return type.getNumElements();
+  } else {
+    return type.getSizeInBits() / CHAR_BIT;
+  }
+}
+
+static int64_t GetAllocationIndex(mlir::BlockArgument func_arg) {
+  auto func_op =
+      mlir::cast<mlir::FuncOp>(func_arg.getParentRegion()->getParentOp());
+  return func_op
+      .getArgAttrOfType<mlir::IntegerAttr>(func_arg.getArgNumber(),
+                                           "lmhlo.alloc")
+      .getValue()
+      .getSExtValue();
+}
+
+StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
+    mlir::Value v, absl::Span<const BufferAllocation> allocations) {
+  int64 size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());
+
+  if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
+    return BufferAllocation::Slice(&allocations[GetAllocationIndex(arg)], 0,
+                                   size);
+  }
+
+  // We match the following patterns here:
+  //  base := ViewOp(arg) | get_global_memref (global_memref)
+  //  root := base | MemRefReinterpretCastOp(base)
+
+  if (mlir::Operation* op = v.getDefiningOp()) {
+    if (auto cast = mlir::dyn_cast<mlir::MemRefReinterpretCastOp>(op)) {
+      mlir::Value source = cast.getViewSource();
+      op = source.getDefiningOp();
+      if (!op) {
+        return Unimplemented("MemRefReinterpretCastOp has to wrap an op");
+      }
+    }
+    if (auto view = mlir::dyn_cast<mlir::ViewOp>(op)) {
+      return BufferAllocation::Slice(
+          &allocations[GetAllocationIndex(
+              view.source().cast<mlir::BlockArgument>())],
+          mlir::cast<mlir::ConstantOp>(view.byte_shift().getDefiningOp())
+              .value()
+              .cast<mlir::IntegerAttr>()
+              .getValue()
+              .getSExtValue(),
+          size);
+    } else if (auto get_global = mlir::dyn_cast<mlir::GetGlobalMemrefOp>(op)) {
+      auto module = get_global->getParentOfType<mlir::ModuleOp>();
+      auto global = mlir::cast<mlir::GlobalMemrefOp>(
+          module.lookupSymbol(get_global.name()));
+      int64_t index =
+          global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
+      return BufferAllocation::Slice(&allocations[index], 0,
+                                     allocations[index].size());
+    }
+    return Unimplemented("MemRefReinterpretCastOp has to wrap a ViewOp");
+  }
+
+  return Unimplemented(
+      "Operand has to be in the form of ViewOp(arg) or "
+      "StaticMemRefCastOp(ViewOp(arg))");
+}
+
+bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
+    mlir::lmhlo::FusionOp fusion,
+    absl::Span<const BufferAllocation> allocations) {
+  auto results = fusion.getFusionResults();
+  if (results.size() != 1) {
+    return false;
+  }
+  auto dus = mlir::dyn_cast<mlir::mhlo::DynamicUpdateSliceOp>(
+      results[0].getDefiningOp());
+  if (!dus) {
+    return false;
+  }
+
+  auto output_buffers = fusion.getOutputBuffers();
+  CHECK_EQ(1, output_buffers.size());
+  auto parameter =
+      mlir::dyn_cast<mlir::TensorLoadOp>(dus.operand().getDefiningOp());
+
+  if (!parameter) {
+    return false;
+  }
+
+  auto maybe_lhs = GetAllocationSliceForMlir(parameter.memref(), allocations);
+  auto maybe_rhs = GetAllocationSliceForMlir(output_buffers[0], allocations);
+  return maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs;
+}
+
 } // namespace gpu
 } // namespace xla
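A standalone sketch (not part of this change; plain-C++ stand-ins for the MLIR types) of the sizing rule GetMemRefSizeInBytes encodes: i1 buffers are allocated one byte per element instead of being bit-packed, while every other element type uses its exact bit width.

#include <cassert>
#include <climits>
#include <cstdint>

// Mirrors the branch in GetMemRefSizeInBytes: one byte per i1 element,
// otherwise total bits divided by CHAR_BIT.
static int64_t MemRefSizeInBytes(int64_t num_elements, int64_t element_bits) {
  if (element_bits == 1) return num_elements;
  return num_elements * element_bits / CHAR_BIT;
}

int main() {
  assert(MemRefSizeInBytes(17, 1) == 17);  // memref<17xi1>: 17 bytes, not 3
  assert(MemRefSizeInBytes(4, 32) == 16);  // memref<4xf32>: 16 bytes
  return 0;
}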
@@ -23,6 +23,8 @@ limitations under the License.
 #include "llvm/IR/Value.h"
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
@@ -258,6 +260,13 @@ std::vector<T> ToStdVector(const llvm::SmallVectorImpl<T>& v) {
   return std::vector<T>(v.begin(), v.end());
 }
 
+StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
+    mlir::Value v, absl::Span<const BufferAllocation> allocations);
+
+bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
+    mlir::lmhlo::FusionOp fusion,
+    absl::Span<const BufferAllocation> allocations);
+
 } // namespace gpu
 } // namespace xla
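For intuition, the contract behind the CanEmitFusedDynamicUpdateSliceInPlaceForGpu declaration above reduces to an aliasing test. A standalone sketch with invented plain-C++ types (XLA's real BufferAllocation::Slice carries more state):

#include <cstdint>
#include <tuple>

// Stand-in for BufferAllocation::Slice: identity is (allocation, offset, size).
struct Slice {
  int64_t allocation, offset, size;
  bool operator==(const Slice& o) const {
    return std::tie(allocation, offset, size) ==
           std::tie(o.allocation, o.offset, o.size);
  }
};

// In-place emission is only legal when the operand being updated and the
// fusion output resolve to exactly the same slice.
static bool CanUpdateInPlace(const Slice& operand, const Slice& output) {
  return operand == output;
}

int main() {
  Slice a{0, 256, 1024}, b{0, 256, 1024}, c{1, 0, 1024};
  return (CanUpdateInPlace(a, b) && !CanUpdateInPlace(a, c)) ? 0 : 1;
}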
@@ -177,73 +177,6 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
           llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
 }
 
-int64_t GetAllocationIndex(mlir::BlockArgument func_arg) {
-  auto func_op =
-      mlir::cast<mlir::FuncOp>(func_arg.getParentRegion()->getParentOp());
-  return func_op
-      .getArgAttrOfType<mlir::IntegerAttr>(func_arg.getArgNumber(),
-                                           "lmhlo.alloc")
-      .getValue()
-      .getSExtValue();
-}
-
-static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) {
-  // For i1 memrefs, the underlying allocation is 8 bits.
-  if (type.getElementType().isInteger(/*width=*/1)) {
-    return type.getNumElements();
-  } else {
-    return type.getSizeInBits() / CHAR_BIT;
-  }
-}
-
-StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
-    mlir::Value v, absl::Span<const BufferAllocation> allocations) {
-  int64 size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());
-
-  if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
-    return BufferAllocation::Slice(&allocations[GetAllocationIndex(arg)], 0,
-                                   size);
-  }
-
-  // We match the following patterns here:
-  //  base := ViewOp(arg) | get_global_memref (global_memref)
-  //  root := base | MemRefReinterpretCastOp(base)
-
-  if (mlir::Operation* op = v.getDefiningOp()) {
-    if (auto cast = mlir::dyn_cast<mlir::MemRefReinterpretCastOp>(op)) {
-      mlir::Value source = cast.getViewSource();
-      op = source.getDefiningOp();
-      if (!op) {
-        return Unimplemented("MemRefReinterpretCastOp has to wrap an op");
-      }
-    }
-    if (auto view = mlir::dyn_cast<mlir::ViewOp>(op)) {
-      return BufferAllocation::Slice(
-          &allocations[GetAllocationIndex(
-              view.source().cast<mlir::BlockArgument>())],
-          mlir::cast<mlir::ConstantOp>(view.byte_shift().getDefiningOp())
-              .value()
-              .cast<mlir::IntegerAttr>()
-              .getValue()
-              .getSExtValue(),
-          size);
-    } else if (auto get_global = mlir::dyn_cast<mlir::GetGlobalMemrefOp>(op)) {
-      auto module = get_global->getParentOfType<mlir::ModuleOp>();
-      auto global = mlir::cast<mlir::GlobalMemrefOp>(
-          module.lookupSymbol(get_global.name()));
-      int64_t index =
-          global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
-      return BufferAllocation::Slice(&allocations[index], 0,
-                                     allocations[index].size());
-    }
-    return Unimplemented("MemRefReinterpretCastOp has to wrap a ViewOp");
-  }
-
-  return Unimplemented(
-      "Operand has to be in the form of ViewOp(arg) or "
-      "StaticMemRefCastOp(ViewOp(arg))");
-}
-
 bool BinarySearchDenseElementsAttr(::mlir::DenseIntElementsAttr elements,
                                    int64 v) {
   ::mlir::APInt value(sizeof(int64) * 8, v, /*isSigned=*/true);
@@ -1837,6 +1770,7 @@ Status IrEmitterUnnested::EmitLoopFusionFromMlir(
 
 Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
   TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(fusion));
+  auto fusion_op = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);
 
   HloInstruction* root = fusion->fused_expression_root();
   if (fusion->IsInputFusion()) {
@@ -1938,16 +1872,21 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
       LOG(FATAL) << "Bad opcode for input fusion: "
                  << fusion->fused_expression_root()->opcode();
     }
-  } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(
-                 fusion, ir_emitter_context_->buffer_assignment())) {
+  } else if (CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
+                 fusion_op, ir_emitter_context_->allocations())) {
     // Fusion node with dynamic-update-slice as the root where the op's input
     // (i.e. array to update) shares the same slice as its output. In this case
     // we have a special algorithm that modifies the output in place without
     // touching the un-updated elements.
+    CHECK_EQ(1, GetHloOutputs(mlir_input.op).size());
 
     // Set up kernel thunk and fused ir emitter.
-    std::unique_ptr<KernelThunk> fusion_thunk =
-        BuildKernelThunk(fusion, /*implements_whole_instruction=*/true);
+    std::vector<llvm_ir::IrArray> ir_arrays;
+    TF_ASSIGN_OR_RETURN(
+        auto fusion_thunk,
+        BuildKernelThunkForMlir(fusion_op, mlir_input.thunk_info,
+                                mlir_input.extra_slice, &ir_arrays));
 
     GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
                                             ir_emitter_context_->llvm_module(),
                                             &b_, GetNestedComputer());
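The comment in the hunk above is the heart of this path: when input and output alias, dynamic-update-slice can write only the update window and leave everything else untouched. A minimal 1-D illustration in plain C++ (invented names; the real emitter operates on multi-dimensional IrArrays in LLVM IR):

#include <cassert>
#include <cstddef>
#include <vector>

// Writes `update` into `operand` at `start`, touching nothing else.
// Assumes update.size() <= operand.size(), as XLA requires.
static void DynamicUpdateSliceInPlace(std::vector<float>& operand,
                                      const std::vector<float>& update,
                                      std::size_t start) {
  // Clamp the start index so the window stays in bounds, matching
  // dynamic-update-slice semantics.
  if (start + update.size() > operand.size()) {
    start = operand.size() - update.size();
  }
  for (std::size_t i = 0; i < update.size(); ++i) {
    operand[start + i] = update[i];  // only the update window is written
  }
}

int main() {
  std::vector<float> buf = {0, 0, 0, 0, 0};
  DynamicUpdateSliceInPlace(buf, {7, 8}, 3);
  assert(buf[0] == 0 && buf[3] == 7 && buf[4] == 8);
  return 0;
}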
@@ -1957,7 +1896,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
 
     // Array to write into. Because this is an in-place operation, this is the
     // same as operand 0's array.
-    IrArray output_array = GetIrArray(*fusion, *fusion);
+    const IrArray& output_array = ir_arrays.back();
 
     LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
         update_shape, ir_emitter_context_->gpu_device_info());
@@ -1966,10 +1905,23 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
     AddThunkToThunkSequence(std::move(fusion_thunk));
 
     FusedIrEmitter fused_emitter(&elemental_emitter);
-    BindFusionArguments(fusion, &fused_emitter);
+
+    TF_ASSIGN_OR_RETURN(
+        const HloComputation* fused_computation,
+        GetOrCreateSubComputationFromRegion(&fusion_op.region(),
+                                            /*is_fusion=*/true));
+
+    for (int i = 0; i < fused_computation->num_parameters(); i++) {
+      fused_emitter.BindGenerator(
+          fused_computation->parameter_instruction(i),
+          [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
+            return ir_arrays[i].EmitReadArrayElement(index, &b_);
+          });
+    }
 
     return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
-        fusion, output_array, &fused_emitter, launch_dimensions, &b_);
+        fused_computation, output_array, &fused_emitter, launch_dimensions,
+        &b_);
   }
 
   CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
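The new for-loop above replaces BindFusionArguments: each fusion parameter is bound to a generator that lazily reads from the matching kernel-argument array. A standalone sketch of that capture pattern (plain C++, invented names):

#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Index = std::size_t;
using Generator = std::function<float(Index)>;

int main() {
  // Stand-ins for the per-parameter ir_arrays.
  std::vector<std::vector<float>> arrays = {{1, 2}, {3, 4}};
  std::vector<Generator> generators;
  for (std::size_t i = 0; i < arrays.size(); ++i) {
    // Capture the parameter number i by value, like the lambda in the diff.
    generators.push_back(
        [&arrays, i](Index index) { return arrays[i][index]; });
  }
  assert(generators[1](0) == 3);  // parameter 1, element 0
  return 0;
}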
@@ -3,10 +3,10 @@
 // CHECK-LABEL: entry:
 // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
 // CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [50 x [96 x [1024 x half]]]*
-// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_1]], i64 0
-// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_3]] to [50 x [96 x [1024 x half]]]*
-// CHECK: %[[VAL_5:.*]] = getelementptr inbounds i8, i8* %[[VAL_6:.*]], i64 0
-// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [1 x [96 x [1024 x half]]]*
+// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
+// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1 x [96 x [1024 x half]]]*
+// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_1]], i64 0
+// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_6]] to [50 x [96 x [1024 x half]]]*
 // CHECK: %[[VAL_8:.*]] = load i32, i32* bitcast ([4 x i8]* @0 to i32*), align 4
 // CHECK: %[[VAL_9:.*]] = icmp sge i32 0, %[[VAL_8]]
 // CHECK: %[[VAL_10:.*]] = select i1 %[[VAL_9]], i32 0, i32 %[[VAL_8]]
@@ -37,19 +37,19 @@
 // CHECK: %[[VAL_34:.*]] = udiv i64 %[[VAL_28]], 98304
 // CHECK: %[[VAL_35:.*]] = icmp ult i64 %[[VAL_28]], 98304
 // CHECK: br i1 %[[VAL_35]], label %[[VAL_36:.*]], label %[[VAL_37:.*]]
-// CHECK: f1.in_bounds-after: ; preds = %[[VAL_36]], %[[VAL_38:.*]]
+// CHECK: dynamic-update-slice.4.in_bounds-after: ; preds = %[[VAL_36]], %[[VAL_38:.*]]
 // CHECK: ret void
-// CHECK: f1.in_bounds-true: ; preds = %[[VAL_38]]
+// CHECK: dynamic-update-slice.4.in_bounds-true: ; preds = %[[VAL_38]]
 // CHECK: %[[VAL_39:.*]] = sext i32 %[[VAL_12]] to i64
 // CHECK: %[[VAL_40:.*]] = add i64 %[[VAL_39]], %[[VAL_34]]
 // CHECK: %[[VAL_41:.*]] = sext i32 %[[VAL_17]] to i64
 // CHECK: %[[VAL_42:.*]] = add i64 %[[VAL_41]], %[[VAL_33]]
 // CHECK: %[[VAL_43:.*]] = sext i32 %[[VAL_22]] to i64
 // CHECK: %[[VAL_44:.*]] = add i64 %[[VAL_43]], %[[VAL_31]]
-// CHECK: %[[VAL_45:.*]] = bitcast [1 x [96 x [1024 x half]]]* %[[VAL_7]] to half*
+// CHECK: %[[VAL_45:.*]] = bitcast [1 x [96 x [1024 x half]]]* %[[VAL_5]] to half*
 // CHECK: %[[VAL_46:.*]] = getelementptr inbounds half, half* %[[VAL_45]], i64 %[[VAL_28]]
 // CHECK: %[[VAL_47:.*]] = load half, half* %[[VAL_46]], align 2, !invariant.load !4
-// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [50 x [96 x [1024 x half]]], [50 x [96 x [1024 x half]]]* %[[VAL_2]], i64 0, i64 %[[VAL_40]], i64 %[[VAL_42]], i64 %[[VAL_44]]
+// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [50 x [96 x [1024 x half]]], [50 x [96 x [1024 x half]]]* %[[VAL_7]], i64 0, i64 %[[VAL_40]], i64 %[[VAL_42]], i64 %[[VAL_44]]
 // CHECK: store half %[[VAL_47]], half* %[[VAL_48]], align 2
 // CHECK: br label %[[VAL_37]]
@@ -189,14 +189,12 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
 //
 // Emits a sequential loop if launch_dimensions is null.
 static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
-    HloInstruction* fusion, const IrArray& fusion_output_array,
+    const HloComputation* fusion, const IrArray& fusion_output_array,
     FusedIrEmitter* fused_emitter,
     const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
-  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
-  VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for "
-          << fusion->ToShortString();
+  VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for " << fusion->ToString();
 
-  auto* dynamic_update_slice = fusion->fused_expression_root();
+  auto* dynamic_update_slice = fusion->root_instruction();
 
   const auto* update = dynamic_update_slice->operand(1);
   const auto* start_indices = dynamic_update_slice->operand(2);
@@ -215,8 +213,8 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
   // through the chain of ops that gives us the update operand and use the
   // layout of its source buffer(s). But this is no worse than we do with
   // fusion elsewhere.)
-  TF_RETURN_IF_ERROR(
-      LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape));
+  TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+      dynamic_update_slice->shape(), &update_shape));
 
   // Create element generators for update and start_indices.
   TF_ASSIGN_OR_RETURN(ElementGenerator update_array_generator,
@@ -232,7 +230,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
   bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
   return EmitDynamicUpdateSliceInPlaceImpl(
       update_shape, start_indices_generator, is_signed, update_array_generator,
-      fusion_output_array, launch_dimensions, IrName(fusion), b);
+      fusion_output_array, launch_dimensions, IrName(dynamic_update_slice), b);
 }
 
 Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
@@ -240,12 +238,12 @@ Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
                                           FusedIrEmitter* fused_emitter,
                                           llvm::IRBuilder<>* b) {
   return EmitFusedDynamicUpdateSliceInPlaceImpl(
-      fusion, fusion_output_array, fused_emitter,
+      fusion->called_computations()[0], fusion_output_array, fused_emitter,
       /*launch_dimensions=*/nullptr, b);
 }
 
 Status EmitParallelFusedDynamicUpdateSliceInPlace(
-    HloInstruction* fusion, const IrArray& fusion_output_array,
+    const HloComputation* fusion, const IrArray& fusion_output_array,
     FusedIrEmitter* fused_emitter,
     const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
   return EmitFusedDynamicUpdateSliceInPlaceImpl(
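Both entry points above now forward an HloComputation to the shared impl and differ only in launch dimensions, where nullptr selects the sequential loop. A standalone sketch of that dispatch shape (plain C++, invented names):

#include <cstdio>

struct LaunchDims {
  int blocks, threads;
};

// Shared impl: a null LaunchDims* means "emit a sequential loop".
static void EmitImpl(const LaunchDims* dims) {
  if (dims == nullptr) {
    std::printf("sequential loop\n");
  } else {
    std::printf("parallel loop: %d x %d\n", dims->blocks, dims->threads);
  }
}

static void EmitSequential() { EmitImpl(nullptr); }
static void EmitParallel(const LaunchDims& dims) { EmitImpl(&dims); }

int main() {
  EmitSequential();
  EmitParallel({160, 256});
  return 0;
}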
@@ -80,7 +80,7 @@ Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
 // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with
 // the given launch dimensions.
 Status EmitParallelFusedDynamicUpdateSliceInPlace(
-    HloInstruction* fusion, const IrArray& fusion_output_array,
+    const HloComputation* fusion, const IrArray& fusion_output_array,
     FusedIrEmitter* fused_emitter,
     const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b);