[XLA/GPU] Migrate unnested constants to mlir::GetGlobalMemrefOp.
PiperOrigin-RevId: 350839265 Change-Id: I69b1f61b9706446edb2c76b4ca5649f5f6ab4628
This commit is contained in:
parent
52a2da52ad
commit
f7f8dd1f1b
@ -67,6 +67,17 @@ StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
|
||||
makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyDenseElementsBy(mlir::DenseElementsAttr data,
|
||||
std::vector<uint8>* output) {
|
||||
output->resize(data.getNumElements() * sizeof(T));
|
||||
int i = 0;
|
||||
for (T element : data.getValues<T>()) {
|
||||
std::memcpy(&(*output)[i], &element, sizeof(T));
|
||||
i += sizeof(T);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<mlir::MemRefType> ConvertTensorShapeToMemRefType(
|
||||
@ -129,6 +140,75 @@ StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
|
||||
}
|
||||
}
|
||||
|
||||
// Serializes the elements of `data` into `*output`, choosing the element
// representation from the attribute's element type. The first matching type
// below wins; when no type matches an Internal error is returned and `*output`
// is not written.
Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
                                        std::vector<uint8>* output) {
  mlir::Type elem = data.getType().getElementType();

  // TODO(hinsu): Support remaining XLA primitive types.
  bool handled = true;
  if (elem.isInteger(1)) {
    CopyDenseElementsBy<bool>(data, output);
  } else if (elem.isInteger(8)) {
    CopyDenseElementsBy<uint8>(data, output);
  } else if (elem.isInteger(16)) {
    CopyDenseElementsBy<uint16>(data, output);
  } else if (elem.isInteger(32)) {
    CopyDenseElementsBy<uint32>(data, output);
  } else if (elem.isInteger(64)) {
    CopyDenseElementsBy<uint64>(data, output);
  } else if (elem.isBF16()) {
    CopyDenseElementsBy<bfloat16>(data, output);
  } else if (elem.isF16()) {
    CopyDenseElementsBy<half>(data, output);
  } else if (elem.isF32()) {
    CopyDenseElementsBy<float>(data, output);
  } else if (elem.isF64()) {
    CopyDenseElementsBy<double>(data, output);
  } else if (auto complex_type = elem.dyn_cast<mlir::ComplexType>()) {
    // Only complex numbers built from f32 or f64 components are handled.
    mlir::Type component = complex_type.getElementType();
    if (component.isF32()) {
      CopyDenseElementsBy<complex64>(data, output);
    } else if (component.isF64()) {
      CopyDenseElementsBy<complex128>(data, output);
    } else {
      handled = false;
    }
  } else {
    handled = false;
  }

  if (!handled) {
    return tensorflow::errors::Internal(
        "Unsupported type in CopyDenseElementsDataToXlaFormat");
  }
  return Status::OK();
}
|
||||
|
||||
// Returns the number of bytes used to store one element of `type`:
//   * 1 for a 1-bit integer,
//   * twice the component size for a complex type,
//   * bit width / 8 for any other integer or floating point type.
//
// Returns an error if `type` is neither complex nor an integer/float type, or
// if its bit width is not a multiple of 8.
StatusOr<int> GetElementTypeBytes(mlir::Type type) {
  if (type.isInteger(1)) {
    return 1;
  }
  if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
    TF_ASSIGN_OR_RETURN(int bytes,
                        GetElementTypeBytes(complex_type.getElementType()));
    return bytes * 2;
  }
  // getIntOrFloatBitWidth() is only valid for integer and float types, so
  // report anything else as an error instead of tripping its assertion.
  TF_RET_CHECK(type.isIntOrFloat());
  int width = type.getIntOrFloatBitWidth();
  TF_RET_CHECK(width % 8 == 0);
  return width / 8;
}
|
||||
|
||||
mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
|
||||
const llvm::ArrayRef<int64> vector, mlir::Builder builder,
|
||||
llvm::ArrayRef<int64_t> shape) {
|
||||
|
@ -30,6 +30,11 @@ namespace xla {
|
||||
StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
|
||||
const LiteralBase& literal, mlir::Builder builder);
|
||||
|
||||
// Overwrites `*output` with the element data of `data`, copied element by
// element using a representation chosen from the element type (i1, i8, i16,
// i32, i64, bf16, f16, f32, f64, and complex of f32 or f64). Returns an
// Internal error for any other element type.
Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
|
||||
std::vector<uint8>* output);
|
||||
|
||||
// Returns the size in bytes of one element of `type`: 1 for a 1-bit integer,
// twice the component size for a complex type, and bit width / 8 otherwise.
// Fails if the bit width is not a multiple of 8.
StatusOr<int> GetElementTypeBytes(mlir::Type type);
|
||||
|
||||
// Creates a DenseIntElementsAttr using the elements of the vector and the
|
||||
// optional shape.
|
||||
mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
|
||||
|
@ -279,18 +279,13 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
|
||||
TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle));
|
||||
|
||||
for (const auto& info : constants_) {
|
||||
const Literal& literal = info.content;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto global, executor->GetUntypedSymbol(
|
||||
info.symbol_name, module_handle));
|
||||
VLOG(3) << "Resolved global " << info.symbol_name << " to "
|
||||
<< global.opaque();
|
||||
|
||||
CHECK(literal.shape().IsArray());
|
||||
if (!ShouldEmitLiteralInLlvmIr(literal)) {
|
||||
VLOG(3) << "H2D memcpy for constant with shape "
|
||||
<< ShapeUtil::HumanString(literal.shape());
|
||||
stream->ThenMemcpy(&global, literal.untyped_data(), literal.size_bytes());
|
||||
if (!info.content.empty()) {
|
||||
stream->ThenMemcpy(&global, info.content.data(), info.content.size());
|
||||
}
|
||||
|
||||
if (info.allocation_index != -1) {
|
||||
|
@ -51,7 +51,7 @@ class GpuExecutable : public Executable {
|
||||
public:
|
||||
struct ConstantInfo {
|
||||
std::string symbol_name;
|
||||
xla::Literal content;
|
||||
std::vector<uint8> content;
|
||||
int allocation_index = -1;
|
||||
};
|
||||
|
||||
|
@ -145,7 +145,11 @@ Status IrEmitter::EmitConstants(const HloComputation& computation,
|
||||
|
||||
GpuExecutable::ConstantInfo info;
|
||||
info.symbol_name = global_name;
|
||||
info.content = literal.Clone();
|
||||
|
||||
if (!should_emit_initializer) {
|
||||
auto base = static_cast<const uint8*>(literal.untyped_data());
|
||||
info.content.assign(base, base + literal.size_bytes());
|
||||
}
|
||||
if (lookup_indices) {
|
||||
auto maybe_slice =
|
||||
ir_emitter_context_->buffer_assignment().GetUniqueSlice(instr, {});
|
||||
|
@ -769,6 +769,71 @@ Status IrEmitterUnnested::EmitUsingElementalIrEmitter(MlirEmitterInput input) {
|
||||
return EmitLoopFusionFromMlir(input, output_shape);
|
||||
}
|
||||
|
||||
// Visitor hook for HLO constant instructions. As written this is a no-op: it
// returns OK immediately and never reaches the MLIR-based emission below.
Status IrEmitterUnnested::HandleConstant(HloInstruction* constant) {
|
||||
// The unconditional return makes the two statements that follow unreachable.
// NOTE(review): presumably the MLIR-based path is deliberately disabled here
// for now -- confirm before removing this early return.
return Status::OK();
|
||||
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(constant));
|
||||
return EmitConstant(input);
|
||||
}
|
||||
|
||||
// Emits the LLVM global variable backing the constant referenced by
// `mlir_input.op` (which must be a mlir::GetGlobalMemrefOp) and appends a
// matching GpuExecutable::ConstantInfo to the emitter context.
//
// The global is an i8 array sized to hold every element of the constant. The
// constant's bytes are either baked into the global's initializer (in which
// case `info.content` stays empty) or stored in `info.content` with the global
// zero-initialized.
//
// Returns an error if the referenced global cannot be found, has no dense
// initial value, lacks an integer "lmhlo.alloc" attribute, or has an
// unsupported element type.
Status IrEmitterUnnested::EmitConstant(MlirEmitterInput mlir_input) {
  auto get_global = mlir::cast<mlir::GetGlobalMemrefOp>(mlir_input.op);
  auto module = get_global->getParentOfType<mlir::ModuleOp>();
  // The symbol lookup may fail or resolve to a different kind of op; report
  // that as an error rather than asserting inside mlir::cast.
  auto global = mlir::dyn_cast_or_null<mlir::GlobalMemrefOp>(
      module.lookupSymbol(get_global.name()));
  TF_RET_CHECK(global);

  // `initial_value` is optional on the op; make sure it is present before
  // dereferencing it.
  auto initial_value = global.initial_value();
  TF_RET_CHECK(initial_value.hasValue());
  auto literal = initial_value->dyn_cast<mlir::DenseElementsAttr>();
  TF_RET_CHECK(literal);

  // Validate the allocation attribute up front so that no LLVM global is
  // created when the input is malformed.
  auto alloc_attr = global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc");
  TF_RET_CHECK(alloc_attr);

  // NOTE(review): with this condition multi-element constants are embedded in
  // the LLVM IR and single-element ones are carried in `info.content` --
  // confirm this matches the policy used for constants elsewhere in the
  // emitter.
  const bool should_emit_initializer = literal.getType().getNumElements() > 1;

  TF_ASSIGN_OR_RETURN(int element_bytes,
                      GetElementTypeBytes(literal.getType().getElementType()));
  llvm::ArrayType* global_type = llvm::ArrayType::get(
      b_.getInt8Ty(), literal.getType().getNumElements() * element_bytes);

  GpuExecutable::ConstantInfo info;
  llvm::Constant* initializer;
  if (should_emit_initializer) {
    std::vector<uint8> content;
    TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content));
    initializer = llvm::ConstantDataArray::get<uint8>(
        ir_emitter_context_->llvm_module()->getContext(), content);
  } else {
    TF_RETURN_IF_ERROR(
        CopyDenseElementsDataToXlaFormat(literal, &info.content));
    initializer = llvm::ConstantAggregateZero::get(global_type);
  }

  // These globals will be looked up by name by GpuExecutable so we need to
  // give them an external linkage.  Not all of their uses are visible in
  // the LLVM IR (e.g. TupleThunk) so we can't give them a linkage that
  // merely preserves their names (like available_externally), we also need
  // to ensure that they stick around even if they're "unused".
  //
  // We may have to be more clever here in the future if we notice that we're
  // keeping around too many globals because of their linkage.
  unsigned global_address_space =
      llvm_ir::GetGlobalMemoryAddressSpace(*ir_emitter_context_->llvm_module());

  llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
      global_type, /*isConstant=*/should_emit_initializer,
      llvm::GlobalValue::ExternalLinkage,
      /*Initializer=*/initializer, global.sym_name(),
      /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
      /*AddressSpace=*/global_address_space,
      /*isExternallyInitialized=*/false);
  global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
  ir_emitter_context_->llvm_module()->getGlobalList().push_back(
      global_for_const);

  info.symbol_name = global.sym_name().str();
  info.allocation_index = alloc_attr.getInt();
  ir_emitter_context_->constants().push_back(std::move(info));
  return Status::OK();
}
|
||||
|
||||
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
|
||||
TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional));
|
||||
AddThunkToThunkSequence(std::move(thunk));
|
||||
@ -3597,21 +3662,9 @@ IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Value init_value,
|
||||
}
|
||||
|
||||
if (const_init) {
|
||||
Shape init_shape = TypeToShape(init_value.getType());
|
||||
CHECK(ShapeUtil::IsScalar(init_shape));
|
||||
int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_shape);
|
||||
bool bool_init;
|
||||
absl::Span<const uint8> literal_bytes(
|
||||
reinterpret_cast<const uint8*>(const_init.getRawData().data()),
|
||||
num_bytes);
|
||||
auto init_type = init_value.getType().dyn_cast<mlir::ShapedType>();
|
||||
if (init_shape.element_type() == PRED) {
|
||||
TF_RET_CHECK(num_bytes == 1);
|
||||
TF_RET_CHECK(init_type.getElementTypeBitWidth() == 1);
|
||||
bool_init = *const_init.getBoolValues().begin();
|
||||
literal_bytes =
|
||||
absl::MakeSpan(reinterpret_cast<const uint8_t*>(&bool_init), 1);
|
||||
}
|
||||
std::vector<uint8> literal_bytes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSliceForMlir(dest));
|
||||
|
||||
|
@ -162,6 +162,9 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
// IrEmitterUnnested handles the following instructions differently from
|
||||
// IrEmitter. It also mixes in some special handling for custom kernels
|
||||
// via the ThunkEmitter.
|
||||
Status HandleConstant(HloInstruction* constant) override;
|
||||
Status EmitConstant(MlirEmitterInput mlir_input);
|
||||
|
||||
Status HandleCopy(HloInstruction* copy) override;
|
||||
Status EmitCopyForMlir(MlirEmitterInput input);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user