Lower hlo::ConstOp to lhlo::ConstOp.
PiperOrigin-RevId: 285174603 Change-Id: Iaee707b3e2ef768e5574d5f9d9ce40692095e90c
This commit is contained in:
parent
7b311d7703
commit
27f3043b89
@ -118,6 +118,10 @@ Status HloDialectEmitter::HandleParameter(HloInstruction* param) {
|
||||
}
|
||||
|
||||
Status HloDialectEmitter::HandleConstant(HloInstruction* constant) {
|
||||
auto shape = constant->shape();
|
||||
if (!shape.IsArray() || shape.rank() != 0) {
|
||||
return Unimplemented("non-scalar constants are not supported yet");
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto type, ConvertTensorShapeToType<RankedTensorType>(
|
||||
constant->shape(), builder_));
|
||||
|
||||
|
@ -287,6 +287,21 @@ Status LhloDialectEmitter::HandleCompare(HloInstruction* compare) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LhloDialectEmitter::HandleConstant(HloInstruction* constant) {
|
||||
auto shape = constant->shape();
|
||||
if (!shape.IsArray() || shape.rank() != 0) {
|
||||
return Unimplemented("non-scalar constants are not supported yet");
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*constant));
|
||||
OpBuilder func_builder(function.getBody());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto value, CreateDenseElementsAttrFromLiteral(
|
||||
constant->literal(), func_builder));
|
||||
func_builder.create<lhlo::ConstOp>(getLocation(constant), value,
|
||||
*function.args_begin());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LhloDialectEmitter::HandleIota(HloInstruction* iota) {
|
||||
mlir::IntegerAttr iota_dim = builder_.getI64IntegerAttr(
|
||||
static_cast<HloIotaInstruction*>(iota)->iota_dimension());
|
||||
|
@ -55,6 +55,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
|
||||
Status DefaultAction(HloInstruction* instr) override;
|
||||
Status HandleBroadcast(HloInstruction* broadcast) override;
|
||||
Status HandleCompare(HloInstruction* compare) override;
|
||||
Status HandleConstant(HloInstruction* constant) override;
|
||||
Status HandleCustomCall(HloInstruction* custom_call) override;
|
||||
Status HandleFusion(HloInstruction* fusion) override;
|
||||
Status HandleIota(HloInstruction* iota) override;
|
||||
|
Loading…
x
Reference in New Issue
Block a user