Add infeed_config attribute to InfeedOp.
PiperOrigin-RevId: 286224973 Change-Id: If77849b23b0ae49188df7ceb464908a8515b49ce
This commit is contained in:
parent
8ae50866d7
commit
b35e03a73b
@ -287,6 +287,12 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
|||||||
llvm::ArrayRef<Value*>(operands.begin() + 2, operands.end()))
|
llvm::ArrayRef<Value*>(operands.begin() + 2, operands.end()))
|
||||||
.getOperation();
|
.getOperation();
|
||||||
}
|
}
|
||||||
|
case HloOpcode::kInfeed: {
|
||||||
|
attributes.push_back(builder_->getNamedAttr(
|
||||||
|
"infeed_config", mlir::StringAttr::get(instruction->infeed_config(),
|
||||||
|
builder_->getContext())));
|
||||||
|
MakeAndReturn(InfeedOp);
|
||||||
|
}
|
||||||
case HloOpcode::kPad: {
|
case HloOpcode::kPad: {
|
||||||
const auto& padding_config = instruction->padding_config();
|
const auto& padding_config = instruction->padding_config();
|
||||||
llvm::SmallVector<int64_t, 4> edge_padding_low;
|
llvm::SmallVector<int64_t, 4> edge_padding_low;
|
||||||
@ -448,7 +454,6 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
|||||||
NoAttributeCase(kExp, ExpOp);
|
NoAttributeCase(kExp, ExpOp);
|
||||||
NoAttributeCase(kExpm1, Expm1Op);
|
NoAttributeCase(kExpm1, Expm1Op);
|
||||||
NoAttributeCase(kFloor, FloorOp);
|
NoAttributeCase(kFloor, FloorOp);
|
||||||
NoAttributeCase(kInfeed, InfeedOp);
|
|
||||||
NoAttributeCase(kImag, ImagOp);
|
NoAttributeCase(kImag, ImagOp);
|
||||||
NoAttributeCase(kLog, LogOp);
|
NoAttributeCase(kLog, LogOp);
|
||||||
NoAttributeCase(kLog1p, Log1pOp);
|
NoAttributeCase(kLog1p, Log1pOp);
|
||||||
|
@ -343,7 +343,10 @@ def HLO_InfeedOp : HLO_Op<"infeed", []> {
|
|||||||
See https://www.tensorflow.org/xla/operation_semantics#infeed.
|
See https://www.tensorflow.org/xla/operation_semantics#infeed.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins HLO_Token:$token);
|
let arguments = (ins
|
||||||
|
HLO_Token:$token,
|
||||||
|
DefaultValuedAttr<StrAttr, "">:$infeed_config
|
||||||
|
);
|
||||||
let results = (outs HLO_Tuple);
|
let results = (outs HLO_Tuple);
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
@ -558,8 +558,8 @@ LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
|
|||||||
// The shape argument expected by the xla client API is the type of the first
|
// The shape argument expected by the xla client API is the type of the first
|
||||||
// element in the result tuple.
|
// element in the result tuple.
|
||||||
auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
|
auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
|
||||||
value_map[op] = xla::InfeedWithToken(value_map[op.token()],
|
value_map[op] = xla::InfeedWithToken(
|
||||||
xla::TypeToShape(result_type));
|
value_map[op.token()], xla::TypeToShape(result_type), op.infeed_config());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -398,13 +398,13 @@ func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
|
|||||||
|
|
||||||
// CHECK: HloModule
|
// CHECK: HloModule
|
||||||
func @main(%arg0: !xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token> {
|
func @main(%arg0: !xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token> {
|
||||||
%0 = "xla_hlo.infeed"(%arg0) : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
|
%0 = "xla_hlo.infeed"(%arg0) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
|
||||||
return %0 : tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
|
return %0 : tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: ENTRY
|
// CHECK: ENTRY
|
||||||
// CHECK: [[ARG:%.*]] = token[] parameter(0)
|
// CHECK: [[ARG:%.*]] = token[] parameter(0)
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = ((s32[3], pred[]), token[]) infeed(token[] [[ARG]])
|
// CHECK: ROOT %[[RESULT:.*]] = ((s32[3], pred[]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar"
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
@ -364,6 +364,16 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
|
|||||||
ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1)
|
ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @test_infeed
|
||||||
|
// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple<tensor<3xi32>, !xla_hlo.token> {
|
||||||
|
%test_infeed (token0: token[]) -> (s32[3], token[]) {
|
||||||
|
%token0 = token[] parameter(0)
|
||||||
|
// CHECK-NEXT: "xla_hlo.infeed"([[TOKEN]])
|
||||||
|
// CHECK-SAME: infeed_config = "foobar"
|
||||||
|
ROOT %infeed = (s32[3], token[]) infeed(token[] %token0), infeed_config="foobar"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @test_iota_1() -> tensor<4xf32> {
|
// CHECK-LABEL: func @test_iota_1() -> tensor<4xf32> {
|
||||||
%test_iota_1 () -> f32[4] {
|
%test_iota_1 () -> f32[4] {
|
||||||
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
||||||
|
Loading…
Reference in New Issue
Block a user