[XLA/GPU] Migrate unnested constants to mlir::GetGlobalMemrefOp.

PiperOrigin-RevId: 350839265
Change-Id: I69b1f61b9706446edb2c76b4ca5649f5f6ab4628
This commit is contained in:
Tim Shen 2021-01-08 14:39:16 -08:00 committed by TensorFlower Gardener
parent 52a2da52ad
commit f7f8dd1f1b
7 changed files with 164 additions and 24 deletions

View File

@ -67,6 +67,17 @@ StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
}
// Serializes every element of `data` into `output` as consecutive `T` values,
// copying each element's object representation with memcpy.
//
// `output` is resized to exactly data.getNumElements() * sizeof(T) bytes, so
// any previous contents are discarded. `T` must be a type that
// data.getValues<T>() can produce for the attribute's element type.
template <typename T>
void CopyDenseElementsBy(mlir::DenseElementsAttr data,
                         std::vector<uint8>* output) {
  output->resize(data.getNumElements() * sizeof(T));
  // The write position is a byte offset into `output`. It is kept as size_t
  // because an `int` offset overflows (undefined behavior) once the constant
  // reaches 2 GiB.
  std::size_t offset = 0;
  for (T element : data.getValues<T>()) {
    std::memcpy(output->data() + offset, &element, sizeof(T));
    offset += sizeof(T);
  }
}
} // namespace
StatusOr<mlir::MemRefType> ConvertTensorShapeToMemRefType(
@ -129,6 +140,75 @@ StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
}
}
// Writes the elements of `data` into `output` as raw bytes, choosing the host
// type used for the copy from the attribute's element type:
//   i1 -> bool, i8/i16/i32/i64 -> uint8/uint16/uint32/uint64,
//   bf16 -> bfloat16, f16 -> half, f32 -> float, f64 -> double,
//   complex<f32> -> complex64, complex<f64> -> complex128.
// Any other element type yields an Internal error.
Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
                                        std::vector<uint8>* output) {
  const mlir::Type elem_ty = data.getType().getElementType();

  // TODO(hinsu): Support remaining XLA primitive types.
  if (elem_ty.isInteger(1)) {
    CopyDenseElementsBy<bool>(data, output);
  } else if (elem_ty.isInteger(8)) {
    CopyDenseElementsBy<uint8>(data, output);
  } else if (elem_ty.isInteger(16)) {
    CopyDenseElementsBy<uint16>(data, output);
  } else if (elem_ty.isInteger(32)) {
    CopyDenseElementsBy<uint32>(data, output);
  } else if (elem_ty.isInteger(64)) {
    CopyDenseElementsBy<uint64>(data, output);
  } else if (elem_ty.isBF16()) {
    CopyDenseElementsBy<bfloat16>(data, output);
  } else if (elem_ty.isF16()) {
    CopyDenseElementsBy<half>(data, output);
  } else if (elem_ty.isF32()) {
    CopyDenseElementsBy<float>(data, output);
  } else if (elem_ty.isF64()) {
    CopyDenseElementsBy<double>(data, output);
  } else {
    // Only complex types with an f32 or f64 component are handled; everything
    // else (including other complex component types) is rejected below.
    auto as_complex = elem_ty.dyn_cast<mlir::ComplexType>();
    if (as_complex && as_complex.getElementType().isF32()) {
      CopyDenseElementsBy<complex64>(data, output);
    } else if (as_complex && as_complex.getElementType().isF64()) {
      CopyDenseElementsBy<complex128>(data, output);
    } else {
      return tensorflow::errors::Internal(
          "Unsupported type in CopyDenseElementsDataToXlaFormat");
    }
  }
  return Status::OK();
}
// Returns the number of bytes used to store one element of `type`.
//
//   * i1 is reported as 1 byte.
//   * A complex type is twice the size of its component type.
//   * Any other integer or float type is its bit width divided by 8; the bit
//     width must be a multiple of 8.
//
// Returns an error status for types that are neither complex, integer nor
// float, and for integer/float widths that are not a whole number of bytes.
StatusOr<int> GetElementTypeBytes(mlir::Type type) {
  if (type.isInteger(1)) {
    return 1;
  }
  if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
    TF_ASSIGN_OR_RETURN(int bytes,
                        GetElementTypeBytes(complex_type.getElementType()));
    return bytes * 2;
  }
  // getIntOrFloatBitWidth() asserts on any other kind of type (e.g. index),
  // so report those as an error instead of crashing.
  TF_RET_CHECK(type.isIntOrFloat());
  int width = type.getIntOrFloatBitWidth();
  TF_RET_CHECK(width % 8 == 0);
  return width / 8;
}
mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
const llvm::ArrayRef<int64> vector, mlir::Builder builder,
llvm::ArrayRef<int64_t> shape) {

View File

@ -30,6 +30,11 @@ namespace xla {
StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
const LiteralBase& literal, mlir::Builder builder);
// Copies the elements of `data` into `output` as raw bytes, one fixed-size
// host value per element; `output` is resized to fit. i1 elements are written
// as one `bool` each. Returns an Internal error for element types other than
// i1/i8/i16/i32/i64, bf16/f16/f32/f64 and complex<f32>/complex<f64>.
Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
std::vector<uint8>* output);
// Returns the size in bytes of one element of `type`: 1 for i1, twice the
// component size for complex types, and bit width / 8 for other integer and
// float types (the bit width must be a multiple of 8, otherwise an error is
// returned).
StatusOr<int> GetElementTypeBytes(mlir::Type type);
// Creates a DenseIntElementsAttr using the elements of the vector and the
// optional shape.
mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(

View File

@ -279,18 +279,13 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle));
for (const auto& info : constants_) {
const Literal& literal = info.content;
TF_ASSIGN_OR_RETURN(auto global, executor->GetUntypedSymbol(
info.symbol_name, module_handle));
VLOG(3) << "Resolved global " << info.symbol_name << " to "
<< global.opaque();
CHECK(literal.shape().IsArray());
if (!ShouldEmitLiteralInLlvmIr(literal)) {
VLOG(3) << "H2D memcpy for constant with shape "
<< ShapeUtil::HumanString(literal.shape());
stream->ThenMemcpy(&global, literal.untyped_data(), literal.size_bytes());
if (!info.content.empty()) {
stream->ThenMemcpy(&global, info.content.data(), info.content.size());
}
if (info.allocation_index != -1) {

View File

@ -51,7 +51,7 @@ class GpuExecutable : public Executable {
public:
// Describes one constant global that GpuExecutable looks up by name in the
// loaded module when resolving constant globals.
struct ConstantInfo {
// Symbol name used to locate the global in the loaded module.
std::string symbol_name;
// Data of the constant. When the byte vector is non-empty it is copied into
// the resolved global with a host-to-device memcpy.
// NOTE(review): this view lists two declarations of `content` (xla::Literal
// and std::vector<uint8>), which look like the before/after lines of a diff —
// confirm which one the file actually carries.
xla::Literal content;
std::vector<uint8> content;
// Defaults to -1; consumers only act on it when it differs from -1.
// NOTE(review): presumably the index of the buffer allocation backing this
// constant, with -1 meaning "none" — confirm against the users.
int allocation_index = -1;
};

View File

@ -145,7 +145,11 @@ Status IrEmitter::EmitConstants(const HloComputation& computation,
GpuExecutable::ConstantInfo info;
info.symbol_name = global_name;
info.content = literal.Clone();
if (!should_emit_initializer) {
auto base = static_cast<const uint8*>(literal.untyped_data());
info.content.assign(base, base + literal.size_bytes());
}
if (lookup_indices) {
auto maybe_slice =
ir_emitter_context_->buffer_assignment().GetUniqueSlice(instr, {});

View File

@ -769,6 +769,71 @@ Status IrEmitterUnnested::EmitUsingElementalIrEmitter(MlirEmitterInput input) {
return EmitLoopFusionFromMlir(input, output_shape);
}
// Handles an HLO constant by looking up its MLIR emitter input and forwarding
// it to EmitConstant. Any error from either step is propagated to the caller.
Status IrEmitterUnnested::HandleConstant(HloInstruction* constant) {
  // There is deliberately no unconditional early return here: returning
  // Status::OK() first would leave the EmitConstant dispatch unreachable.
  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(constant));
  return EmitConstant(input);
}
// Emits an LLVM global variable for the constant referenced by the
// mlir::GetGlobalMemrefOp in `mlir_input.op`, and appends a matching
// GpuExecutable::ConstantInfo to the emitter context's constant list.
//
// Constants with more than one element get their bytes embedded in the LLVM
// IR as the global's initializer and the global is marked constant. All other
// constants get a zero initializer, and their bytes are stored in
// ConstantInfo::content instead.
//
// Returns an error if the referenced global cannot be found, has no dense
// initial value, lacks an integer "lmhlo.alloc" attribute, or has an element
// type that cannot be serialized.
Status IrEmitterUnnested::EmitConstant(MlirEmitterInput mlir_input) {
  auto get_global = mlir::cast<mlir::GetGlobalMemrefOp>(mlir_input.op);
  auto module = get_global->getParentOfType<mlir::ModuleOp>();

  // The symbol lookup may come back empty or resolve to a different kind of
  // op; report that as an error rather than asserting inside mlir::cast.
  auto global = mlir::dyn_cast_or_null<mlir::GlobalMemrefOp>(
      module.lookupSymbol(get_global.name()));
  TF_RET_CHECK(global);

  // The initial value is dereferenced below, so make sure it is present
  // before touching it.
  auto initial_value = global.initial_value();
  const bool has_initial_value = static_cast<bool>(initial_value);
  TF_RET_CHECK(has_initial_value);
  auto literal = initial_value->dyn_cast<mlir::DenseElementsAttr>();
  TF_RET_CHECK(literal);

  // Validate the allocation attribute before the LLVM module is modified, so
  // a failure here does not leave a half-registered global behind.
  auto alloc_attr = global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc");
  TF_RET_CHECK(alloc_attr);

  const bool should_emit_initializer = literal.getType().getNumElements() > 1;

  TF_ASSIGN_OR_RETURN(int element_bytes,
                      GetElementTypeBytes(literal.getType().getElementType()));
  // The global is typed as a plain byte array covering the whole constant.
  llvm::ArrayType* global_type = llvm::ArrayType::get(
      b_.getInt8Ty(), literal.getType().getNumElements() * element_bytes);

  GpuExecutable::ConstantInfo info;
  llvm::Constant* initializer = nullptr;
  if (should_emit_initializer) {
    // Bytes live in the LLVM IR; info.content stays empty.
    std::vector<uint8> content;
    TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content));
    initializer = llvm::ConstantDataArray::get<uint8>(
        ir_emitter_context_->llvm_module()->getContext(), content);
  } else {
    // Bytes travel with the ConstantInfo; the global itself is zero-filled.
    TF_RETURN_IF_ERROR(
        CopyDenseElementsDataToXlaFormat(literal, &info.content));
    initializer = llvm::ConstantAggregateZero::get(global_type);
  }

  // These globals will be looked up by name by GpuExecutable so we need to
  // give them an external linkage.  Not all of their uses are visible in
  // the LLVM IR (e.g. TupleThunk) so we can't give them a linkage that
  // merely preserves their names (like available_externally), we also need
  // to ensure that they stick around even if they're "unused".
  //
  // We may have to be more clever here in the future if we notice that we're
  // keeping around too many globals because of their linkage.
  unsigned global_address_space =
      llvm_ir::GetGlobalMemoryAddressSpace(*ir_emitter_context_->llvm_module());

  llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
      global_type, /*isConstant=*/should_emit_initializer,
      llvm::GlobalValue::ExternalLinkage,
      /*Initializer=*/initializer, global.sym_name(),
      /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
      /*AddressSpace=*/global_address_space,
      /*isExternallyInitialized=*/false);
  global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
  ir_emitter_context_->llvm_module()->getGlobalList().push_back(
      global_for_const);

  info.symbol_name.assign(global.sym_name().begin(), global.sym_name().end());
  info.allocation_index = alloc_attr.getInt();
  ir_emitter_context_->constants().push_back(std::move(info));
  return Status::OK();
}
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional));
AddThunkToThunkSequence(std::move(thunk));
@ -3597,21 +3662,9 @@ IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Value init_value,
}
if (const_init) {
Shape init_shape = TypeToShape(init_value.getType());
CHECK(ShapeUtil::IsScalar(init_shape));
int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_shape);
bool bool_init;
absl::Span<const uint8> literal_bytes(
reinterpret_cast<const uint8*>(const_init.getRawData().data()),
num_bytes);
auto init_type = init_value.getType().dyn_cast<mlir::ShapedType>();
if (init_shape.element_type() == PRED) {
TF_RET_CHECK(num_bytes == 1);
TF_RET_CHECK(init_type.getElementTypeBitWidth() == 1);
bool_init = *const_init.getBoolValues().begin();
literal_bytes =
absl::MakeSpan(reinterpret_cast<const uint8_t*>(&bool_init), 1);
}
std::vector<uint8> literal_bytes;
TF_RETURN_IF_ERROR(
CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes));
TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSliceForMlir(dest));

View File

@ -162,6 +162,9 @@ class IrEmitterUnnested : public IrEmitter,
// IrEmitterUnnested handles the following instructions differently from
// IrEmitter. It also mixes in some special handling for custom kernels
// via the ThunkEmitter.
Status HandleConstant(HloInstruction* constant) override;
Status EmitConstant(MlirEmitterInput mlir_input);
Status HandleCopy(HloInstruction* copy) override;
Status EmitCopyForMlir(MlirEmitterInput input);