Lower hlo::ConstOp to lhlo::ConstOp.

PiperOrigin-RevId: 285174603
Change-Id: Iaee707b3e2ef768e5574d5f9d9ce40692095e90c
This commit is contained in:
Alexander Belyaev 2019-12-12 05:50:04 -08:00 committed by TensorFlower Gardener
parent 7b311d7703
commit 27f3043b89
3 changed files with 20 additions and 0 deletions

View File

@ -118,6 +118,10 @@ Status HloDialectEmitter::HandleParameter(HloInstruction* param) {
}
Status HloDialectEmitter::HandleConstant(HloInstruction* constant) {
auto shape = constant->shape();
if (!shape.IsArray() || shape.rank() != 0) {
return Unimplemented("non-scalar constants are not supported yet");
}
TF_ASSIGN_OR_RETURN(auto type, ConvertTensorShapeToType<RankedTensorType>(
constant->shape(), builder_));

View File

@ -287,6 +287,21 @@ Status LhloDialectEmitter::HandleCompare(HloInstruction* compare) {
return Status::OK();
}
Status LhloDialectEmitter::HandleConstant(HloInstruction* constant) {
auto shape = constant->shape();
if (!shape.IsArray() || shape.rank() != 0) {
return Unimplemented("non-scalar constants are not supported yet");
}
TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*constant));
OpBuilder func_builder(function.getBody());
TF_ASSIGN_OR_RETURN(auto value, CreateDenseElementsAttrFromLiteral(
constant->literal(), func_builder));
func_builder.create<lhlo::ConstOp>(getLocation(constant), value,
*function.args_begin());
return Status::OK();
}
Status LhloDialectEmitter::HandleIota(HloInstruction* iota) {
mlir::IntegerAttr iota_dim = builder_.getI64IntegerAttr(
static_cast<HloIotaInstruction*>(iota)->iota_dimension());

View File

@ -55,6 +55,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
Status DefaultAction(HloInstruction* instr) override;
Status HandleBroadcast(HloInstruction* broadcast) override;
Status HandleCompare(HloInstruction* compare) override;
Status HandleConstant(HloInstruction* constant) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleIota(HloInstruction* iota) override;