[XLA/GPU] Migrate all unnested elementwise emitters.
PiperOrigin-RevId: 346559170 Change-Id: I990590eb45fa5d9f866d05d66d27efcb5211fe42
This commit is contained in:
parent
e7365d08b2
commit
a393d15808
@ -83,6 +83,11 @@ enum ScalarLimit {
|
||||
// Requires `ty` to be either FloatType or IntegerType.
|
||||
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
|
||||
|
||||
// Given `op_name` from LMHLO, returns the corresponding op name in MHLO.
|
||||
// Returns empty string if no such op exists.
|
||||
std::string LmhloToMhloOpName(llvm::StringRef op_name,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -132,5 +132,13 @@ DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
|
||||
llvm_unreachable("unsupported type");
|
||||
}
|
||||
|
||||
// Maps an LMHLO operation name onto the name of its MHLO counterpart.
//
// `op_name` is asserted to begin with "lmhlo.". The two dialect prefixes differ
// only by the leading 'l', so removing the first character of `op_name` yields
// the candidate MHLO name. The candidate is returned only when `context`
// reports it as a registered operation; otherwise an empty string is returned.
std::string LmhloToMhloOpName(llvm::StringRef op_name,
                              mlir::MLIRContext *context) {
  assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");

  // "lmhlo.<op>" -> "mhlo.<op>".
  std::string candidate(op_name.drop_front(1));
  if (!context->isOperationRegistered(candidate)) {
    return "";
  }
  return candidate;
}
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
@ -40,11 +41,14 @@ limitations under the License.
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Verifier.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/utils/name_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||
@ -330,12 +334,22 @@ bool MayPreventVectorization(mlir::Operation* op) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the operations that define the values reported by
// `fusion.getFusionResults()`. The defining ops are gathered through an
// llvm::SetVector, so an op that defines several results is listed only once.
std::vector<mlir::Operation*> GetOutputOps(mlir::lmhlo::FusionOp fusion) {
  llvm::SetVector<mlir::Operation*> defining_ops;
  for (mlir::Value fusion_result : fusion.getFusionResults()) {
    mlir::Operation* producer = fusion_result.getDefiningOp();
    defining_ops.insert(producer);
  }
  std::vector<mlir::Operation*> output_ops(defining_ops.begin(),
                                           defining_ops.end());
  return output_ops;
}
|
||||
|
||||
// Computes the maximum valid unroll factor for a given instruction.
|
||||
int ComputeMaxUnrollFactor(const Shape& shape,
|
||||
const HloModuleConfig& hlo_module_config) {
|
||||
int max_unroll_factor =
|
||||
hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
|
||||
|
||||
// Find the largest possible power of two to unroll by.
|
||||
// TODO(kramerb): Make this smarter.
|
||||
int64 num_elements = ShapeUtil::ElementsIn(shape);
|
||||
for (int i = max_unroll_factor; i > 1; i /= 2) {
|
||||
if (num_elements % i == 0) {
|
||||
@ -349,14 +363,39 @@ int ComputeMaxUnrollFactor(const Shape& shape,
|
||||
|
||||
// Computes the maximum valid unroll factor for a given instruction.
|
||||
int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
|
||||
// Find the largest possible power of two to unroll by.
|
||||
// TODO(kramerb): Make this smarter.
|
||||
const Shape& element_shape = hlo->IsMultiOutputFusion()
|
||||
? ShapeUtil::GetSubshape(hlo->shape(), {0})
|
||||
: hlo->shape();
|
||||
return ComputeMaxUnrollFactor(element_shape, hlo->GetModule()->config());
|
||||
}
|
||||
|
||||
// Computes the maximum valid unroll factor for a given instruction.
|
||||
int ComputeMaxUnrollFactor(mlir::Operation* op,
|
||||
const HloModuleConfig& hlo_module_config) {
|
||||
Shape element_shape = [&] {
|
||||
std::vector<Shape> shapes;
|
||||
// Detect multi-output fusion. Notice that for a reduce in the fusion that
|
||||
// returns a tuple, we don't want to treat it as multi-output fusion. We
|
||||
// want to pass that tuple into ComputeMaxUnrollFactor below. For an actual
|
||||
// MOF, just pass the first element of the root tuple.
|
||||
if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
|
||||
std::vector<mlir::Operation*> fusion_outputs = GetOutputOps(fusion);
|
||||
for (mlir::Value result : fusion_outputs[0]->getResults()) {
|
||||
shapes.push_back(TypeToShape(result.getType()));
|
||||
}
|
||||
} else {
|
||||
for (mlir::Value result : op->getResults()) {
|
||||
shapes.push_back(TypeToShape(result.getType()));
|
||||
}
|
||||
}
|
||||
if (shapes.size() > 1) {
|
||||
return ShapeUtil::MakeTupleShape(shapes);
|
||||
}
|
||||
return shapes[0];
|
||||
}();
|
||||
return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
|
||||
}
|
||||
|
||||
// Returns the llvm type for the indices used in the kernel that contains the
|
||||
// hlo instruction. Such indices include the index for the parallel loop and
|
||||
// the indices for the tensors accessed by the kernel. The return type is i32
|
||||
@ -613,10 +652,14 @@ StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSliceForMlir(
|
||||
}
|
||||
|
||||
// Fallback handler for instructions without a dedicated Handle* method.
//
// Elementwise instructions are converted to an MlirEmitterInput and emitted
// through EmitUsingElementalIrEmitter(); every other instruction is forwarded
// to the base-class IrEmitter::DefaultAction(). A failure to build the MLIR
// input is propagated to the caller.
Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
  if (!hlo->IsElementwise()) {
    return IrEmitter::DefaultAction(hlo);
  }
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(hlo));
  return EmitUsingElementalIrEmitter(mlir_input);
}
|
||||
|
||||
Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
|
||||
Status IrEmitterUnnested::EmitUsingElementalIrEmitter(MlirEmitterInput input) {
|
||||
// Replace unnested op with a fused nested op.
|
||||
//
|
||||
// TODO(timshen): Ultimately this should be a pass. It's currently not a pass,
|
||||
@ -670,19 +713,54 @@ Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
|
||||
output_shape = ShapeUtil::MakeTupleShape(output_shapes);
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "Unimplemented default action for mlir op: "
|
||||
<< MlirToString(input.op);
|
||||
// Try to generically convert any LMHLO ops to LMHLO fusion + the
|
||||
// corresponding MHLO op. Currently we've only looked at elementwise ops and
|
||||
// they seem to be well covered.
|
||||
//
|
||||
// TODO(timshen): Moving forward, we should make it cover all ops if
|
||||
// possible, and only special-case the ones it can't.
|
||||
std::vector<mlir::Value> outputs;
|
||||
mlir::Operation* new_op;
|
||||
{
|
||||
std::vector<mlir::Value> operands;
|
||||
for (auto buffer : input.op->getOperands()) {
|
||||
if (WritesMlirBuffer(input.op, buffer)) {
|
||||
outputs.push_back(buffer);
|
||||
} else {
|
||||
operands.push_back(buffer);
|
||||
}
|
||||
}
|
||||
TF_RET_CHECK(outputs.size() == 1);
|
||||
|
||||
std::vector<mlir::Value> loads = load_memrefs(operands);
|
||||
std::string mhlo_op_name = mlir::hlo::LmhloToMhloOpName(
|
||||
input.op->getName().getStringRef(), input.op->getContext());
|
||||
TF_RET_CHECK(!mhlo_op_name.empty())
|
||||
<< "No corresponding MHLO op for given LMHLO op: "
|
||||
<< MlirToString(input.op);
|
||||
mlir::OperationState op_state(loc, mhlo_op_name);
|
||||
|
||||
mlir::BlockAndValueMapping mapper;
|
||||
for (mlir::Region& region : input.op->getRegions()) {
|
||||
mlir::Region* new_region = op_state.addRegion();
|
||||
region.cloneInto(new_region, mapper);
|
||||
}
|
||||
|
||||
op_state.addOperands(loads);
|
||||
op_state.addAttributes(input.op->getAttrs());
|
||||
op_state.addTypes({mlir::RankedTensorType::get(
|
||||
outputs[0].getType().cast<mlir::MemRefType>().getShape(),
|
||||
outputs[0].getType().cast<mlir::MemRefType>().getElementType())});
|
||||
new_op = b.createOperation(op_state);
|
||||
}
|
||||
TF_RET_CHECK(mlir::succeeded(mlir::verify(new_op)));
|
||||
output_shape = TypeToShape(outputs[0].getType());
|
||||
HloFunctionImporter::SetLayoutForMlir(new_op, output_shape);
|
||||
b.create<mlir::TensorStoreOp>(loc, new_op->getResult(0), outputs[0]);
|
||||
}
|
||||
input.op->erase();
|
||||
input.op = fusion;
|
||||
int unroll_factor = 1;
|
||||
// TODO(timshen): Port MayPreventVectorization as we add more ops into this
|
||||
// function.
|
||||
if (output_shape.IsArray()) {
|
||||
unroll_factor = ComputeMaxUnrollFactor(output_shape, hlo_module_config_);
|
||||
}
|
||||
auto ret = EmitLoopFusionFromMlir(input, output_shape, unroll_factor);
|
||||
return ret;
|
||||
return EmitLoopFusionFromMlir(input, output_shape);
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
|
||||
@ -1210,8 +1288,7 @@ StatusOr<MlirEmitterInput> IrEmitterUnnested::GetMlirEmitterInput(
|
||||
// This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
|
||||
// subclass. The logic is de-virtualized and less scattered.
|
||||
Status IrEmitterUnnested::EmitLoopFusionFromMlir(MlirEmitterInput input,
|
||||
const Shape& output_shape,
|
||||
int unroll_factor) {
|
||||
const Shape& output_shape) {
|
||||
auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(input.op);
|
||||
MlirEmitterContext context;
|
||||
context.SetOperation(fusion);
|
||||
@ -1258,6 +1335,11 @@ Status IrEmitterUnnested::EmitLoopFusionFromMlir(MlirEmitterInput input,
|
||||
auto element_generator,
|
||||
fused_emitter.GetGenerator(fused_computation->root_instruction()));
|
||||
|
||||
int unroll_factor = 1;
|
||||
if (!MayPreventVectorization(fusion)) {
|
||||
unroll_factor = ComputeMaxUnrollFactor(fusion, hlo_module_config_);
|
||||
}
|
||||
|
||||
Shape element_shape = context.output_shapes[0];
|
||||
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
|
||||
element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
|
||||
@ -1436,12 +1518,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int unroll_factor = 1;
|
||||
if (!MayPreventVectorization(*fusion)) {
|
||||
unroll_factor = ComputeMaxUnrollFactor(fusion);
|
||||
}
|
||||
|
||||
return EmitLoopFusionFromMlir(mlir_input, fusion->shape(), unroll_factor);
|
||||
return EmitLoopFusionFromMlir(mlir_input, fusion->shape());
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
|
||||
@ -1476,7 +1553,7 @@ Status IrEmitterUnnested::EmitCopyForMlir(MlirEmitterInput input) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return DefaultActionForMlir(input);
|
||||
return EmitUsingElementalIrEmitter(input);
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::EmitExtraOutputsForReduce(
|
||||
@ -1507,7 +1584,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
|
||||
return EmitReductionFromOrToContiguousDimensions(mlir_input);
|
||||
}
|
||||
|
||||
return DefaultActionForMlir(mlir_input);
|
||||
return EmitUsingElementalIrEmitter(mlir_input);
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
|
||||
|
@ -157,7 +157,7 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
}
|
||||
|
||||
Status DefaultAction(HloInstruction* hlo) override;
|
||||
Status DefaultActionForMlir(MlirEmitterInput input);
|
||||
Status EmitUsingElementalIrEmitter(MlirEmitterInput input);
|
||||
|
||||
// IrEmitterUnnested handles the following instructions differently from
|
||||
// IrEmitter. It also mixes in some special handling for custom kernels
|
||||
@ -175,7 +175,7 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
Status HandleFft(HloInstruction* fft) override;
|
||||
Status HandleFusion(HloInstruction* fusion) override;
|
||||
Status EmitLoopFusionFromMlir(MlirEmitterInput input,
|
||||
const Shape& output_shape, int unroll_factor);
|
||||
const Shape& output_shape);
|
||||
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
|
||||
Status HandleReduce(HloInstruction* reduce) override;
|
||||
Status HandleSelectAndScatter(HloInstruction* instruction) override;
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -56,7 +56,7 @@ void LlvmIrGenTestBase::CompileAndVerifyIr(
|
||||
|
||||
StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
|
||||
TF_ASSERT_OK(filecheck_result.status());
|
||||
EXPECT_TRUE(filecheck_result.ValueOrDie());
|
||||
EXPECT_TRUE(filecheck_result.ValueOrDie()) << "Full IR: " << ir_;
|
||||
}
|
||||
|
||||
void LlvmIrGenTestBase::CompileAndVerifyIr(const string& hlo_text,
|
||||
@ -80,7 +80,7 @@ void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr(
|
||||
|
||||
StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
|
||||
ASSERT_TRUE(filecheck_result.ok());
|
||||
EXPECT_TRUE(filecheck_result.ValueOrDie());
|
||||
EXPECT_TRUE(filecheck_result.ValueOrDie()) << "Full IR: " << ir_;
|
||||
}
|
||||
|
||||
void LlvmIrGenTestBase::MatchOptimizedHlo(absl::string_view hlo,
|
||||
|
Loading…
x
Reference in New Issue
Block a user