[XLA/GPU] Migrate unnested constants to mlir::GetGlobalMemrefOp.
PiperOrigin-RevId: 350839265
Change-Id: I69b1f61b9706446edb2c76b4ca5649f5f6ab4628
parent 52a2da52ad
commit f7f8dd1f1b
@@ -67,6 +67,17 @@ StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
       makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
 }
 
+template <typename T>
+void CopyDenseElementsBy(mlir::DenseElementsAttr data,
+                         std::vector<uint8>* output) {
+  output->resize(data.getNumElements() * sizeof(T));
+  int i = 0;
+  for (T element : data.getValues<T>()) {
+    std::memcpy(&(*output)[i], &element, sizeof(T));
+    i += sizeof(T);
+  }
+}
+
 }  // namespace
 
 StatusOr<mlir::MemRefType> ConvertTensorShapeToMemRefType(
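
The helper added above packs typed attribute elements into a raw byte buffer in iteration order. A minimal standalone sketch of the same pattern, using a plain std::vector<T> in place of mlir::DenseElementsAttr (names here are illustrative, not part of this change):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Packs the elements of `values` into a contiguous byte buffer, mirroring
    // what CopyDenseElementsBy does for the values of a DenseElementsAttr.
    template <typename T>
    std::vector<uint8_t> PackToBytes(const std::vector<T>& values) {
      std::vector<uint8_t> bytes(values.size() * sizeof(T));
      int i = 0;
      for (T element : values) {
        std::memcpy(&bytes[i], &element, sizeof(T));
        i += sizeof(T);
      }
      return bytes;
    }

    // Usage: PackToBytes<float>({1.0f, 2.0f}) yields 8 bytes (two IEEE-754 floats).
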
@@ -129,6 +140,75 @@ StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
   }
 }
 
+Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
+                                        std::vector<uint8>* output) {
+  mlir::Type element_type = data.getType().getElementType();
+
+  // TODO(hinsu): Support remaining XLA primitive types.
+  if (element_type.isInteger(1)) {
+    CopyDenseElementsBy<bool>(data, output);
+    return Status::OK();
+  }
+  if (element_type.isInteger(8)) {
+    CopyDenseElementsBy<uint8>(data, output);
+    return Status::OK();
+  }
+  if (element_type.isInteger(16)) {
+    CopyDenseElementsBy<uint16>(data, output);
+    return Status::OK();
+  }
+  if (element_type.isInteger(32)) {
+    CopyDenseElementsBy<uint32>(data, output);
+    return Status::OK();
+  }
+  if (element_type.isInteger(64)) {
+    CopyDenseElementsBy<uint64>(data, output);
+    return Status::OK();
+  }
+  if (element_type.isBF16()) {
+    CopyDenseElementsBy<bfloat16>(data, output);
+    return Status::OK();
+  }
+  if (element_type.isF16()) {
+    CopyDenseElementsBy<half>(data, output);
+    return Status::OK();
+  }
+  if (element_type.isF32()) {
+    CopyDenseElementsBy<float>(data, output);
+    return Status::OK();
+  }
+  if (element_type.isF64()) {
+    CopyDenseElementsBy<double>(data, output);
+    return Status::OK();
+  }
+  if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
+    if (complex_type.getElementType().isF32()) {
+      CopyDenseElementsBy<complex64>(data, output);
+      return Status::OK();
+    }
+    if (complex_type.getElementType().isF64()) {
+      CopyDenseElementsBy<complex128>(data, output);
+      return Status::OK();
+    }
+  }
+  return tensorflow::errors::Internal(
+      "Unsupported type in CopyDenseElementsDataToXlaFormat");
+}
+
+StatusOr<int> GetElementTypeBytes(mlir::Type type) {
+  if (type.isInteger(1)) {
+    return 1;
+  }
+  if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
+    TF_ASSIGN_OR_RETURN(int bytes,
+                        GetElementTypeBytes(complex_type.getElementType()));
+    return bytes * 2;
+  }
+  int width = type.getIntOrFloatBitWidth();
+  TF_RET_CHECK(width % 8 == 0);
+  return width / 8;
+}
+
 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
     const llvm::ArrayRef<int64> vector, mlir::Builder builder,
     llvm::ArrayRef<int64_t> shape) {
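
A hedged usage sketch for the two helpers added above. The MLIR builder calls and header paths below are assumptions about the surrounding APIs of that era, not part of this change; the helpers themselves are assumed visible via the header they are declared in.

    // Sketch only: uint8, Status and StatusOr are the aliases used in the diff.
    #include <vector>

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinTypes.h"  // or StandardTypes.h in older MLIR
    #include "mlir/IR/MLIRContext.h"

    void Demo() {
      mlir::MLIRContext context;
      mlir::Builder builder(&context);

      // A 2x2 f32 constant as a DenseElementsAttr.
      auto type = mlir::RankedTensorType::get({2, 2}, builder.getF32Type());
      llvm::SmallVector<float, 4> values = {1.0f, 2.0f, 3.0f, 4.0f};
      auto attr = mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(values));

      // Flatten into XLA's byte representation: 4 elements * 4 bytes = 16 bytes.
      std::vector<tensorflow::uint8> bytes;
      xla::Status s = xla::CopyDenseElementsDataToXlaFormat(attr, &bytes);

      // Per-element sizes: f32 -> 4, complex<f32> -> 8, i1 -> 1.
      xla::StatusOr<int> f32_bytes = xla::GetElementTypeBytes(builder.getF32Type());
      (void)s;
      (void)f32_bytes;
    }
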
@@ -30,6 +30,11 @@ namespace xla {
 StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
     const LiteralBase& literal, mlir::Builder builder);
 
+Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
+                                        std::vector<uint8>* output);
+
+StatusOr<int> GetElementTypeBytes(mlir::Type type);
+
 // Creates an DenseIntElementsAttr using the elements of the vector and the
 // optional shape.
 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
@@ -279,18 +279,13 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
   TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle));
 
   for (const auto& info : constants_) {
-    const Literal& literal = info.content;
-
     TF_ASSIGN_OR_RETURN(auto global, executor->GetUntypedSymbol(
                                          info.symbol_name, module_handle));
     VLOG(3) << "Resolved global " << info.symbol_name << " to "
             << global.opaque();
 
-    CHECK(literal.shape().IsArray());
-    if (!ShouldEmitLiteralInLlvmIr(literal)) {
-      VLOG(3) << "H2D memcpy for constant with shape "
-              << ShapeUtil::HumanString(literal.shape());
-      stream->ThenMemcpy(&global, literal.untyped_data(), literal.size_bytes());
+    if (!info.content.empty()) {
+      stream->ThenMemcpy(&global, info.content.data(), info.content.size());
     }
 
     if (info.allocation_index != -1) {
@@ -51,7 +51,7 @@ class GpuExecutable : public Executable {
  public:
   struct ConstantInfo {
     std::string symbol_name;
-    xla::Literal content;
+    std::vector<uint8> content;
     int allocation_index = -1;
   };
 
@@ -145,7 +145,11 @@ Status IrEmitter::EmitConstants(const HloComputation& computation,
 
     GpuExecutable::ConstantInfo info;
     info.symbol_name = global_name;
-    info.content = literal.Clone();
+
+    if (!should_emit_initializer) {
+      auto base = static_cast<const uint8*>(literal.untyped_data());
+      info.content.assign(base, base + literal.size_bytes());
+    }
     if (lookup_indices) {
       auto maybe_slice =
           ir_emitter_context_->buffer_assignment().GetUniqueSlice(instr, {});
@@ -769,6 +769,71 @@ Status IrEmitterUnnested::EmitUsingElementalIrEmitter(MlirEmitterInput input) {
   return EmitLoopFusionFromMlir(input, output_shape);
 }
 
+Status IrEmitterUnnested::HandleConstant(HloInstruction* constant) {
+  return Status::OK();
+  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(constant));
+  return EmitConstant(input);
+}
+
+Status IrEmitterUnnested::EmitConstant(MlirEmitterInput mlir_input) {
+  auto get_global = mlir::cast<mlir::GetGlobalMemrefOp>(mlir_input.op);
+  auto module = get_global->getParentOfType<mlir::ModuleOp>();
+  auto global =
+      mlir::cast<mlir::GlobalMemrefOp>(module.lookupSymbol(get_global.name()));
+
+  auto literal = global.initial_value()->dyn_cast<mlir::DenseElementsAttr>();
+  TF_RET_CHECK(literal);
+
+  const bool should_emit_initializer = literal.getType().getNumElements() > 1;
+
+  TF_ASSIGN_OR_RETURN(int element_bytes,
+                      GetElementTypeBytes(literal.getType().getElementType()));
+  llvm::ArrayType* global_type = llvm::ArrayType::get(
+      b_.getInt8Ty(), literal.getType().getNumElements() * element_bytes);
+
+  GpuExecutable::ConstantInfo info;
+  llvm::Constant* initializer;
+  if (should_emit_initializer) {
+    std::vector<uint8> content;
+    TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content));
+    initializer = llvm::ConstantDataArray::get<uint8>(
+        ir_emitter_context_->llvm_module()->getContext(), content);
+  } else {
+    TF_RETURN_IF_ERROR(
+        CopyDenseElementsDataToXlaFormat(literal, &info.content));
+    initializer = llvm::ConstantAggregateZero::get(global_type);
+  }
+
+  // These globals will be looked up by name by GpuExecutable so we need to
+  // give them an external linkage. Not all of their uses are visible in
+  // the LLVM IR (e.g. TupleThunk) so we can't give then a linkage that
+  // merely preserves their names (like available_externally), we also need
+  // to ensure that they stick around even if they're "unused".
+  //
+  // We may have to be more clever here in the future if we notice that we're
+  // keeping around too many globals because of their linkage.
+  unsigned global_address_space =
+      llvm_ir::GetGlobalMemoryAddressSpace(*ir_emitter_context_->llvm_module());
+
+  llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
+      global_type, /*isConstant=*/should_emit_initializer,
+      llvm::GlobalValue::ExternalLinkage,
+      /*Initializer=*/initializer, global.sym_name(),
+      /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
+      /*AddressSpace=*/global_address_space,
+      /*isExternallyInitialized=*/false);
+  global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
+  ir_emitter_context_->llvm_module()->getGlobalList().push_back(
+      global_for_const);
+
+  info.symbol_name.assign(global.sym_name().begin(), global.sym_name().end());
+
+  info.allocation_index =
+      global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
+  ir_emitter_context_->constants().push_back(std::move(info));
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
   TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional));
   AddThunkToThunkSequence(std::move(thunk));
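
EmitConstant above materializes each constant as a named LLVM global with external linkage so that GpuExecutable can later resolve it by symbol name (and, when no initializer was baked in, memcpy the bytes to the device). A minimal standalone sketch of that mechanism with the plain LLVM C++ API; the module name, symbol name, byte contents and alignment below are illustrative only.

    #include <cstdint>
    #include <memory>
    #include <vector>

    #include "llvm/IR/Constants.h"
    #include "llvm/IR/GlobalVariable.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      llvm::LLVMContext ctx;
      auto module = std::make_unique<llvm::Module>("constants_demo", ctx);

      // Raw bytes of the constant (here: 1.0f in little-endian, as an example).
      std::vector<uint8_t> bytes = {0x00, 0x00, 0x80, 0x3f};

      // Bake the bytes into the module as an [N x i8] initializer.
      llvm::Constant* initializer = llvm::ConstantDataArray::get(ctx, bytes);

      // External linkage keeps the symbol around and resolvable by name at
      // load time, mirroring what EmitConstant does for real constants.
      auto* global = new llvm::GlobalVariable(
          *module, initializer->getType(), /*isConstant=*/true,
          llvm::GlobalValue::ExternalLinkage, initializer,
          "buffer_for_constant_0");
      global->setAlignment(llvm::Align(64));

      module->print(llvm::outs(), nullptr);
      return 0;
    }
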
@@ -3597,21 +3662,9 @@ IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Value init_value,
   }
 
   if (const_init) {
-    Shape init_shape = TypeToShape(init_value.getType());
-    CHECK(ShapeUtil::IsScalar(init_shape));
-    int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_shape);
-    bool bool_init;
-    absl::Span<const uint8> literal_bytes(
-        reinterpret_cast<const uint8*>(const_init.getRawData().data()),
-        num_bytes);
-    auto init_type = init_value.getType().dyn_cast<mlir::ShapedType>();
-    if (init_shape.element_type() == PRED) {
-      TF_RET_CHECK(num_bytes == 1);
-      TF_RET_CHECK(init_type.getElementTypeBitWidth() == 1);
-      bool_init = *const_init.getBoolValues().begin();
-      literal_bytes =
-          absl::MakeSpan(reinterpret_cast<const uint8_t*>(&bool_init), 1);
-    }
+    std::vector<uint8> literal_bytes;
+    TF_RETURN_IF_ERROR(
+        CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes));
 
     TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSliceForMlir(dest));
 
@@ -162,6 +162,9 @@ class IrEmitterUnnested : public IrEmitter,
   // IrEmitterUnnested handles the following instructions differently from
   // IrEmitter. It also mixes in some special handling for custom kernels
   // via the ThunkEmitter.
+  Status HandleConstant(HloInstruction* constant) override;
+  Status EmitConstant(MlirEmitterInput mlir_input);
+
   Status HandleCopy(HloInstruction* copy) override;
   Status EmitCopyForMlir(MlirEmitterInput input);
 