[XLA/GPU] Migrate all unnested elementwise emitters.

PiperOrigin-RevId: 346559170
Change-Id: I990590eb45fa5d9f866d05d66d27efcb5211fe42
This commit is contained in:
Tim Shen 2020-12-09 08:42:11 -08:00 committed by TensorFlower Gardener
parent e7365d08b2
commit a393d15808
6 changed files with 672 additions and 582 deletions

View File

@ -83,6 +83,11 @@ enum ScalarLimit {
// Requires `ty` to be either FloatType or IntegerType.
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
// Given `op_name` from LMHLO, returns the corresponding op name in MHLO.
// Returns empty string if no such op exists.
std::string LmhloToMhloOpName(llvm::StringRef op_name,
mlir::MLIRContext* context);
} // namespace hlo
} // namespace mlir

View File

@ -132,5 +132,13 @@ DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
llvm_unreachable("unsupported type");
}
std::string LmhloToMhloOpName(llvm::StringRef op_name,
                              mlir::MLIRContext *context) {
  assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");
  // Dropping the leading 'l' rewrites the "lmhlo." prefix into "mhlo.".
  std::string candidate(op_name.drop_front(1));
  // Only return the name if such an MHLO op is actually registered in this
  // context; otherwise signal "no counterpart" with an empty string.
  return context->isOperationRegistered(candidate) ? candidate : "";
}
} // namespace hlo
} // namespace mlir

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
@ -40,11 +41,14 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
@ -330,12 +334,22 @@ bool MayPreventVectorization(mlir::Operation* op) {
return true;
}
// Returns the distinct defining ops of the fusion's results, preserving the
// order in which they are first encountered.
std::vector<mlir::Operation*> GetOutputOps(mlir::lmhlo::FusionOp fusion) {
  llvm::SetVector<mlir::Operation*> unique_ops;
  for (mlir::Value result : fusion.getFusionResults()) {
    unique_ops.insert(result.getDefiningOp());
  }
  return std::vector<mlir::Operation*>(unique_ops.begin(), unique_ops.end());
}
// Computes the maximum valid unroll factor for a given instruction.
int ComputeMaxUnrollFactor(const Shape& shape,
const HloModuleConfig& hlo_module_config) {
int max_unroll_factor =
hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
// Find the largest possible power of two to unroll by.
// TODO(kramerb): Make this smarter.
int64 num_elements = ShapeUtil::ElementsIn(shape);
for (int i = max_unroll_factor; i > 1; i /= 2) {
if (num_elements % i == 0) {
@ -349,14 +363,39 @@ int ComputeMaxUnrollFactor(const Shape& shape,
// Computes the maximum valid unroll factor for a given instruction.
int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
  // For a multi-output fusion, use the shape of the first tuple element;
  // otherwise use the instruction's own shape.
  const Shape& shape = hlo->shape();
  const Shape& element_shape =
      hlo->IsMultiOutputFusion() ? ShapeUtil::GetSubshape(shape, {0}) : shape;
  return ComputeMaxUnrollFactor(element_shape, hlo->GetModule()->config());
}
// Computes the maximum valid unroll factor for a given instruction.
int ComputeMaxUnrollFactor(mlir::Operation* op,
                           const HloModuleConfig& hlo_module_config) {
  std::vector<Shape> result_shapes;
  // Detect multi-output fusion. Notice that for a reduce in the fusion that
  // returns a tuple, we don't want to treat it as multi-output fusion. We
  // want to pass that tuple into ComputeMaxUnrollFactor below. For an actual
  // MOF, just pass the first element of the root tuple.
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    mlir::Operation* first_output = GetOutputOps(fusion)[0];
    for (mlir::Value result : first_output->getResults()) {
      result_shapes.push_back(TypeToShape(result.getType()));
    }
  } else {
    for (mlir::Value result : op->getResults()) {
      result_shapes.push_back(TypeToShape(result.getType()));
    }
  }
  // Multiple result shapes are wrapped back into a tuple before computing the
  // unroll factor.
  const Shape element_shape = result_shapes.size() > 1
                                  ? ShapeUtil::MakeTupleShape(result_shapes)
                                  : result_shapes[0];
  return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
}
// Returns the llvm type for the indices used in the kernel that contains the
// hlo instruction. Such indices include the index for the parallel loop and
// the indices for the tensors accessed by the kernel. The return type is i32
@ -613,10 +652,14 @@ StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSliceForMlir(
}
Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
  // Non-elementwise ops fall back to the base emitter; elementwise ops are
  // routed through the MLIR emitter input and the elemental IR emitter.
  if (!hlo->IsElementwise()) {
    return IrEmitter::DefaultAction(hlo);
  }
  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
  return EmitUsingElementalIrEmitter(input);
}
Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
Status IrEmitterUnnested::EmitUsingElementalIrEmitter(MlirEmitterInput input) {
// Replace unnested op with a fused nested op.
//
// TODO(timshen): Ultimately this should be a pass. It's currently not a pass,
@ -670,19 +713,54 @@ Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
output_shape = ShapeUtil::MakeTupleShape(output_shapes);
}
} else {
LOG(FATAL) << "Unimplemented default action for mlir op: "
<< MlirToString(input.op);
// Try to generically convert any LMHLO ops to LMHLO fusion + the
// corresponding MHLO op. Currently we've only looked at elementwise ops and
// they seem to be well covered.
//
// TODO(timshen): Moving forward, we should make it cover all ops if
// possible, and only special-case the ones it can't.
std::vector<mlir::Value> outputs;
mlir::Operation* new_op;
{
std::vector<mlir::Value> operands;
for (auto buffer : input.op->getOperands()) {
if (WritesMlirBuffer(input.op, buffer)) {
outputs.push_back(buffer);
} else {
operands.push_back(buffer);
}
}
TF_RET_CHECK(outputs.size() == 1);
std::vector<mlir::Value> loads = load_memrefs(operands);
std::string mhlo_op_name = mlir::hlo::LmhloToMhloOpName(
input.op->getName().getStringRef(), input.op->getContext());
TF_RET_CHECK(!mhlo_op_name.empty())
<< "No corresponding MHLO op for given LMHLO op: "
<< MlirToString(input.op);
mlir::OperationState op_state(loc, mhlo_op_name);
mlir::BlockAndValueMapping mapper;
for (mlir::Region& region : input.op->getRegions()) {
mlir::Region* new_region = op_state.addRegion();
region.cloneInto(new_region, mapper);
}
op_state.addOperands(loads);
op_state.addAttributes(input.op->getAttrs());
op_state.addTypes({mlir::RankedTensorType::get(
outputs[0].getType().cast<mlir::MemRefType>().getShape(),
outputs[0].getType().cast<mlir::MemRefType>().getElementType())});
new_op = b.createOperation(op_state);
}
TF_RET_CHECK(mlir::succeeded(mlir::verify(new_op)));
output_shape = TypeToShape(outputs[0].getType());
HloFunctionImporter::SetLayoutForMlir(new_op, output_shape);
b.create<mlir::TensorStoreOp>(loc, new_op->getResult(0), outputs[0]);
}
input.op->erase();
input.op = fusion;
int unroll_factor = 1;
// TODO(timshen): Port MayPreventVectorization as we add more ops into this
// function.
if (output_shape.IsArray()) {
unroll_factor = ComputeMaxUnrollFactor(output_shape, hlo_module_config_);
}
auto ret = EmitLoopFusionFromMlir(input, output_shape, unroll_factor);
return ret;
return EmitLoopFusionFromMlir(input, output_shape);
}
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
@ -1210,8 +1288,7 @@ StatusOr<MlirEmitterInput> IrEmitterUnnested::GetMlirEmitterInput(
// This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
// subclass. The logic is de-virtualized and less scattered.
Status IrEmitterUnnested::EmitLoopFusionFromMlir(MlirEmitterInput input,
const Shape& output_shape,
int unroll_factor) {
const Shape& output_shape) {
auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(input.op);
MlirEmitterContext context;
context.SetOperation(fusion);
@ -1258,6 +1335,11 @@ Status IrEmitterUnnested::EmitLoopFusionFromMlir(MlirEmitterInput input,
auto element_generator,
fused_emitter.GetGenerator(fused_computation->root_instruction()));
int unroll_factor = 1;
if (!MayPreventVectorization(fusion)) {
unroll_factor = ComputeMaxUnrollFactor(fusion, hlo_module_config_);
}
Shape element_shape = context.output_shapes[0];
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
@ -1436,12 +1518,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
return Status::OK();
}
int unroll_factor = 1;
if (!MayPreventVectorization(*fusion)) {
unroll_factor = ComputeMaxUnrollFactor(fusion);
}
return EmitLoopFusionFromMlir(mlir_input, fusion->shape(), unroll_factor);
return EmitLoopFusionFromMlir(mlir_input, fusion->shape());
}
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
@ -1476,7 +1553,7 @@ Status IrEmitterUnnested::EmitCopyForMlir(MlirEmitterInput input) {
return Status::OK();
}
return DefaultActionForMlir(input);
return EmitUsingElementalIrEmitter(input);
}
Status IrEmitterUnnested::EmitExtraOutputsForReduce(
@ -1507,7 +1584,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
return EmitReductionFromOrToContiguousDimensions(mlir_input);
}
return DefaultActionForMlir(mlir_input);
return EmitUsingElementalIrEmitter(mlir_input);
}
Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {

View File

@ -157,7 +157,7 @@ class IrEmitterUnnested : public IrEmitter,
}
Status DefaultAction(HloInstruction* hlo) override;
Status DefaultActionForMlir(MlirEmitterInput input);
Status EmitUsingElementalIrEmitter(MlirEmitterInput input);
// IrEmitterUnnested handles the following instructions differently from
// IrEmitter. It also mixes in some special handling for custom kernels
@ -175,7 +175,7 @@ class IrEmitterUnnested : public IrEmitter,
Status HandleFft(HloInstruction* fft) override;
Status HandleFusion(HloInstruction* fusion) override;
Status EmitLoopFusionFromMlir(MlirEmitterInput input,
const Shape& output_shape, int unroll_factor);
const Shape& output_shape);
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleSelectAndScatter(HloInstruction* instruction) override;

File diff suppressed because it is too large Load Diff

View File

@ -56,7 +56,7 @@ void LlvmIrGenTestBase::CompileAndVerifyIr(
StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
TF_ASSERT_OK(filecheck_result.status());
EXPECT_TRUE(filecheck_result.ValueOrDie());
EXPECT_TRUE(filecheck_result.ValueOrDie()) << "Full IR: " << ir_;
}
void LlvmIrGenTestBase::CompileAndVerifyIr(const string& hlo_text,
@ -80,7 +80,7 @@ void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr(
StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
ASSERT_TRUE(filecheck_result.ok());
EXPECT_TRUE(filecheck_result.ValueOrDie());
EXPECT_TRUE(filecheck_result.ValueOrDie()) << "Full IR: " << ir_;
}
void LlvmIrGenTestBase::MatchOptimizedHlo(absl::string_view hlo,