Add layout to mhlo::InfeedOp td.
PiperOrigin-RevId: 356286875 Change-Id: I78ecebe20eb4bbbfc50f2cb0be22ef930ed1355d
This commit is contained in:
parent
ed77a63244
commit
4afbaca02c
@ -401,12 +401,18 @@ def HLO_InfeedOp : HLO_Op<"infeed", []> {
|
|||||||
of the data. Multiple Infeed operations are allowed in a computation, but
|
of the data. Multiple Infeed operations are allowed in a computation, but
|
||||||
there must be a total order among the Infeed operations.
|
there must be a total order among the Infeed operations.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
layout: Array attribute. Same shape as the output of the infeed, except
|
||||||
|
that every tensor is replaced by a minor_to_major array for the
|
||||||
|
tensor's layout.
|
||||||
|
|
||||||
See https://www.tensorflow.org/xla/operation_semantics#infeed.
|
See https://www.tensorflow.org/xla/operation_semantics#infeed.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Token:$token,
|
HLO_Token:$token,
|
||||||
DefaultValuedAttr<StrAttr, "">:$infeed_config
|
DefaultValuedAttr<StrAttr, "">:$infeed_config,
|
||||||
|
OptionalAttr<ArrayAttr>:$layout
|
||||||
);
|
);
|
||||||
let results = (outs HLO_Tuple);
|
let results = (outs HLO_Tuple);
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
|
@ -442,7 +442,7 @@ func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -
|
|||||||
|
|
||||||
func @infeed_invalid_number_of_results(%token: !mhlo.token) -> tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>> {
|
func @infeed_invalid_number_of_results(%token: !mhlo.token) -> tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>> {
|
||||||
// expected-error@+1 {{result is expected to be a tuple of size 2, but got 3}}
|
// expected-error@+1 {{result is expected to be a tuple of size 2, but got 3}}
|
||||||
%0 = "mhlo.infeed"(%token) {infeed_config = "foobar"} : (!mhlo.token) -> tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>>
|
%0 = "mhlo.infeed"(%token) {infeed_config = "foobar", layout = [[[0]], unit, [0]]} : (!mhlo.token) -> tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>>
|
||||||
return %0 : tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>>
|
return %0 : tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>>
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -450,7 +450,7 @@ func @infeed_invalid_number_of_results(%token: !mhlo.token) -> tuple<tuple<tenso
|
|||||||
|
|
||||||
func @infeed_non_token_second_result(%token: !mhlo.token) -> tuple<tuple<tensor<i32>>, tensor<i32>> {
|
func @infeed_non_token_second_result(%token: !mhlo.token) -> tuple<tuple<tensor<i32>>, tensor<i32>> {
|
||||||
// expected-error@+1 {{second element of result tuple is expected to be of token type, but got 'tensor<i32>'}}
|
// expected-error@+1 {{second element of result tuple is expected to be of token type, but got 'tensor<i32>'}}
|
||||||
%0 = "mhlo.infeed"(%token) {infeed_config = "foobar"} : (!mhlo.token) -> tuple<tuple<tensor<i32>>, tensor<i32>>
|
%0 = "mhlo.infeed"(%token) {infeed_config = "foobar", layout = [[[0]], [0]]} : (!mhlo.token) -> tuple<tuple<tensor<i32>>, tensor<i32>>
|
||||||
return %0 : tuple<tuple<tensor<i32>>, tensor<i32>>
|
return %0 : tuple<tuple<tensor<i32>>, tensor<i32>>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -383,6 +383,19 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
|
|||||||
"infeed_config",
|
"infeed_config",
|
||||||
mlir::StringAttr::get(builder_->getContext(),
|
mlir::StringAttr::get(builder_->getContext(),
|
||||||
instruction->infeed_config())));
|
instruction->infeed_config())));
|
||||||
|
// TODO(kramm): Support tuples and tokens.
|
||||||
|
if (instruction->shape().IsArray()) {
|
||||||
|
const xla::Layout l = instruction->shape().layout();
|
||||||
|
absl::Span<const int64> minor_to_major = l.minor_to_major();
|
||||||
|
std::vector<mlir::Attribute> v;
|
||||||
|
for (int64 i : minor_to_major) {
|
||||||
|
v.push_back(builder_->getI32IntegerAttr(i));
|
||||||
|
}
|
||||||
|
llvm::ArrayRef<mlir::Attribute> array_ref(v);
|
||||||
|
mlir::ArrayAttr layout = builder_->getArrayAttr(array_ref);
|
||||||
|
attributes.push_back(builder_->getNamedAttr("layout", layout));
|
||||||
|
}
|
||||||
|
|
||||||
MakeAndReturn(InfeedOp);
|
MakeAndReturn(InfeedOp);
|
||||||
}
|
}
|
||||||
case HloOpcode::kOutfeed: {
|
case HloOpcode::kOutfeed: {
|
||||||
|
@ -466,9 +466,11 @@ StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
|
|||||||
TF_ASSIGN_OR_RETURN(mlir::Type result_type,
|
TF_ASSIGN_OR_RETURN(mlir::Type result_type,
|
||||||
ConvertShapeToType<mlir::RankedTensorType>(
|
ConvertShapeToType<mlir::RankedTensorType>(
|
||||||
infeed_instruction_shape, builder_));
|
infeed_instruction_shape, builder_));
|
||||||
|
mlir::ArrayAttr layout;
|
||||||
return MakeXlaOp(
|
return MakeXlaOp(
|
||||||
builder_.create<mlir::mhlo::InfeedOp>(loc_, result_type, GetValue(token),
|
builder_.create<mlir::mhlo::InfeedOp>(loc_, result_type, GetValue(token),
|
||||||
/*infeed_config=*/config));
|
/*infeed_config=*/config,
|
||||||
|
/*layout=*/layout));
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
|
StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
|
||||||
|
@ -4613,9 +4613,15 @@ class ConvertInfeedDequeueTupleOp
|
|||||||
auto data_and_token_type = mlir::TupleType::get(
|
auto data_and_token_type = mlir::TupleType::get(
|
||||||
rewriter.getContext(), {data_tuple_type, token.getType()});
|
rewriter.getContext(), {data_tuple_type, token.getType()});
|
||||||
|
|
||||||
|
ArrayAttr layout =
|
||||||
|
GetLayout(data_and_token_type, rewriter).cast<ArrayAttr>();
|
||||||
auto data_and_token =
|
auto data_and_token =
|
||||||
rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
|
rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
|
||||||
/*infeed_config=*/rewriter.getStringAttr(""));
|
/*infeed_config=*/rewriter.getStringAttr(""),
|
||||||
|
/*layout=*/layout);
|
||||||
|
|
||||||
|
// TODO(b/171212005): Reenable layout.
|
||||||
|
data_and_token.removeAttr("layout");
|
||||||
|
|
||||||
if (op._XlaSharding().hasValue()) {
|
if (op._XlaSharding().hasValue()) {
|
||||||
// _XlaSharding attribute in TF is a serialized string of the OpSharding
|
// _XlaSharding attribute in TF is a serialized string of the OpSharding
|
||||||
|
Loading…
x
Reference in New Issue
Block a user