Add infeed_config attribute to InfeedOp.
PiperOrigin-RevId: 286224973 Change-Id: If77849b23b0ae49188df7ceb464908a8515b49ce
This commit is contained in:
parent
8ae50866d7
commit
b35e03a73b
@ -287,6 +287,12 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
llvm::ArrayRef<Value*>(operands.begin() + 2, operands.end()))
|
||||
.getOperation();
|
||||
}
|
||||
case HloOpcode::kInfeed: {
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"infeed_config", mlir::StringAttr::get(instruction->infeed_config(),
|
||||
builder_->getContext())));
|
||||
MakeAndReturn(InfeedOp);
|
||||
}
|
||||
case HloOpcode::kPad: {
|
||||
const auto& padding_config = instruction->padding_config();
|
||||
llvm::SmallVector<int64_t, 4> edge_padding_low;
|
||||
@ -448,7 +454,6 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
NoAttributeCase(kExp, ExpOp);
|
||||
NoAttributeCase(kExpm1, Expm1Op);
|
||||
NoAttributeCase(kFloor, FloorOp);
|
||||
NoAttributeCase(kInfeed, InfeedOp);
|
||||
NoAttributeCase(kImag, ImagOp);
|
||||
NoAttributeCase(kLog, LogOp);
|
||||
NoAttributeCase(kLog1p, Log1pOp);
|
||||
|
@ -343,7 +343,10 @@ def HLO_InfeedOp : HLO_Op<"infeed", []> {
|
||||
See https://www.tensorflow.org/xla/operation_semantics#infeed.
|
||||
}];
|
||||
|
||||
let arguments = (ins HLO_Token:$token);
|
||||
let arguments = (ins
|
||||
HLO_Token:$token,
|
||||
DefaultValuedAttr<StrAttr, "">:$infeed_config
|
||||
);
|
||||
let results = (outs HLO_Tuple);
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
@ -558,8 +558,8 @@ LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
|
||||
// The shape argument expected by the xla client API is the type of the first
|
||||
// element in the result tuple.
|
||||
auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
|
||||
value_map[op] = xla::InfeedWithToken(value_map[op.token()],
|
||||
xla::TypeToShape(result_type));
|
||||
value_map[op] = xla::InfeedWithToken(
|
||||
value_map[op.token()], xla::TypeToShape(result_type), op.infeed_config());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -398,13 +398,13 @@ func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg0: !xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token> {
|
||||
%0 = "xla_hlo.infeed"(%arg0) : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
|
||||
return %0 : tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
|
||||
}
|
||||
%0 = "xla_hlo.infeed"(%arg0) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
|
||||
return %0 : tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
|
||||
}
|
||||
|
||||
// CHECK: ENTRY
|
||||
// CHECK: [[ARG:%.*]] = token[] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = ((s32[3], pred[]), token[]) infeed(token[] [[ARG]])
|
||||
// CHECK: ROOT %[[RESULT:.*]] = ((s32[3], pred[]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar"
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -364,6 +364,16 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
|
||||
ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_infeed
|
||||
// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple<tensor<3xi32>, !xla_hlo.token> {
|
||||
%test_infeed (token0: token[]) -> (s32[3], token[]) {
|
||||
%token0 = token[] parameter(0)
|
||||
// CHECK-NEXT: "xla_hlo.infeed"([[TOKEN]])
|
||||
// CHECK-SAME: infeed_config = "foobar"
|
||||
ROOT %infeed = (s32[3], token[]) infeed(token[] %token0), infeed_config="foobar"
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @test_iota_1() -> tensor<4xf32> {
|
||||
%test_iota_1 () -> f32[4] {
|
||||
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
||||
|
Loading…
Reference in New Issue
Block a user