Build DynamicSlice and DynamicUpdateSlice ops with MlirHloBuilder
Whitelist XlaDynamicSlice and XlaDynamicUpdateSlice for testing

PiperOrigin-RevId: 311642899
Change-Id: Icbf009cf69d3b183d0c83c10925a5fbaa3c49f1f
parent a2ef8b5a06
commit 4662933489
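For orientation: dynamic-slice extracts a window whose size is fixed at compile time but whose start position is a runtime value (one scalar index per dimension, clamped by XLA so the window stays in bounds), and dynamic-update-slice writes an update tensor at such runtime offsets. A minimal client-side sketch of the XlaBuilder API that this commit refactors; the function name and concrete shapes below are illustrative, not part of the commit:

// Sketch only: assumes the standard XLA client headers and build setup.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaComputation BuildDynamicSliceExample() {
  xla::XlaBuilder b("dynamic_slice_example");
  // A 3x4 operand and one scalar start index per dimension.
  xla::XlaOp operand = xla::Parameter(
      &b, 0, xla::ShapeUtil::MakeShape(xla::S32, {3, 4}), "operand");
  xla::XlaOp i = xla::Parameter(
      &b, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "i");
  xla::XlaOp j = xla::Parameter(
      &b, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), "j");
  // Take a 2x2 window starting at runtime indices (i, j).
  xla::DynamicSlice(operand, {i, j}, /*slice_sizes=*/{2, 2});
  return b.Build().ValueOrDie();
}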
@@ -282,6 +282,28 @@ StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
       GetI64ElementsAttr(strides, &builder_)));
 }
 
+StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
+    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
+    absl::Span<const int64> slice_sizes) {
+  TF_ASSIGN_OR_RETURN(
+      mlir::Type result_ty,
+      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicSliceOp>(
+      loc_, result_ty, GetValue(operand), GetValues(start_indices),
+      GetI64ElementsAttr(slice_sizes, &builder_)));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
+    const Shape& shape, XlaOp operand, XlaOp update,
+    absl::Span<const XlaOp> start_indices) {
+  TF_ASSIGN_OR_RETURN(
+      mlir::Type result_ty,
+      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicUpdateSliceOp>(
+      loc_, result_ty, GetValue(operand), GetValue(update),
+      GetValues(start_indices)));
+}
+
 StatusOr<XlaOp> MlirHloBuilder::PadInternal(
     const Shape& shape, XlaOp operand, XlaOp padding_value,
     const PaddingConfig& padding_config) {
@@ -175,6 +175,14 @@ class MlirHloBuilder : public XlaBuilder {
                                 absl::Span<const int64> limit_indices,
                                 absl::Span<const int64> strides) override;
 
+  StatusOr<XlaOp> DynamicSliceInternal(
+      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
+      absl::Span<const int64> slice_sizes) override;
+
+  StatusOr<XlaOp> DynamicUpdateSliceInternal(
+      const Shape& shape, XlaOp operand, XlaOp update,
+      absl::Span<const XlaOp> start_indices) override;
+
   StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
                               XlaOp padding_value,
                               const PaddingConfig& padding_config) override;
@@ -163,6 +163,30 @@ func @truncated_normal() -> tensor<2x2xf32> {
   return %1 : tensor<2x2xf32>
 }
 
+// CHECK-LABEL: dynamic_update_slice
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32>
+func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> {
+
+  // CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]])
+  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>
+  // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64>
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
+  // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
+  // CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
+
+  // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]])
+  // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
+  // CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64>
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
+  // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
+  // CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
+
+  // CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]])
+
+  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
+  return %0: tensor<3x4xi32>
+}
+
 // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
 // available but doesn't support this instance.
 }
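The CHECK lines above pin down how the fallback lowers the packed index operand: tf.XlaDynamicUpdateSlice carries all start indices as a single vector, while xla_hlo.dynamic-update-slice (like HLO) takes one scalar index per dimension, so each index is sliced out and reshaped to a scalar. A rough C++ equivalent of that unpacking, written against the public XlaBuilder API; the helper name and rank parameter are illustrative:

// Sketch only: unpack a rank-length index vector into per-dimension scalars,
// then call DynamicUpdateSlice with one scalar start index per dimension.
#include <cstdint>
#include <vector>

#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp UnpackedDynamicUpdateSlice(xla::XlaOp operand, xla::XlaOp update,
                                      xla::XlaOp indices, int64_t rank) {
  std::vector<xla::XlaOp> scalar_indices;
  scalar_indices.reserve(rank);
  for (int64_t d = 0; d < rank; ++d) {
    // indices[d:d+1] -> reshape the length-1 vector to a rank-0 scalar.
    xla::XlaOp elem = xla::Slice(indices, {d}, {d + 1}, {1});
    scalar_indices.push_back(xla::Reshape(elem, {}));
  }
  return xla::DynamicUpdateSlice(operand, update, scalar_indices);
}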
@@ -168,6 +168,8 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get<TF::XlaBroadcastHelperOp>(),
     TypeID::get<TF::XlaConvOp>(),
     TypeID::get<TF::XlaDotOp>(),
+    TypeID::get<TF::XlaDynamicSliceOp>(),
+    TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
     TypeID::get<TF::XlaPadOp>(),
     TypeID::get<TF::Xlog1pyOp>(),
     TypeID::get<TF::XlogyOp>()
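The whitelist above gates which TF ops the MLIR bridge is allowed to legalize through the tf2xla kernel fallback; adding the two TypeIDs is what lets the tests below run. A toy, self-contained sketch of the mechanism, with all names invented for illustration:

// Sketch only: legalization attempts the fallback solely for ops whose
// unique per-class id is in the set, mirroring the TypeID-keyed whitelist.
#include <set>

using OpId = const void*;
template <typename T>
OpId GetId() { static char tag; return &tag; }  // one address per op class

struct XlaDynamicSliceOp {};
struct XlaDynamicUpdateSliceOp {};
struct SomeUnsupportedOp {};

static bool IsOpWhitelisted(OpId id) {
  static const std::set<OpId> kOps = {
      GetId<XlaDynamicSliceOp>(),
      GetId<XlaDynamicUpdateSliceOp>(),
  };
  return kOps.count(id) > 0;
}

int main() {
  bool a = IsOpWhitelisted(GetId<XlaDynamicSliceOp>());  // true
  bool b = IsOpWhitelisted(GetId<SomeUnsupportedOp>());  // false
  return (a && !b) ? 0 : 1;
}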
@@ -304,7 +304,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
       self._assertOpOutputMatchesExpected(
           lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
 
-  @test_util.disable_mlir_bridge('Not supported yet')
   def testDynamicSlice(self):
     for dtype in self.numeric_types:
       self._assertOpOutputMatchesExpected(
@@ -317,7 +316,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
                         [[673, 674], [683, 684], [693, 694]]]),
           dtype=dtype))
 
-  @test_util.disable_mlir_bridge('Not supported yet')
+  @test_util.disable_mlir_bridge('Error handling')
   def testDynamicSliceWithIncorrectStartIndicesShape(self):
     with self.session() as session:
       with self.test_scope():
@@ -331,7 +330,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
           (r'start_indices must be a vector with length equal to input rank, '
            r'but input rank is 3 and start_indices has shape \[2\].*'))
 
-  @test_util.disable_mlir_bridge('Not supported yet')
+  @test_util.disable_mlir_bridge('Error handling')
   def testDynamicSliceWithIncorrectSizeIndicesShape(self):
     with self.session() as session:
       with self.test_scope():
@@ -864,8 +864,6 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
                                absl::Span<const XlaOp> start_indices,
                                absl::Span<const int64> slice_sizes) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     std::vector<const Shape*> start_indices_shape_ptrs;
     TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
@@ -876,23 +874,28 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
     TF_ASSIGN_OR_RETURN(Shape shape,
                         ShapeInference::InferDynamicSliceShape(
                             *operand_shape, start_indices_shapes, slice_sizes));
-    *instr.mutable_shape() = shape.ToProto();
-
-    for (int64 size : slice_sizes) {
-      instr.add_dynamic_slice_sizes(size);
-    }
-
-    std::vector<XlaOp> operands = {operand};
-    operands.insert(operands.end(), start_indices.begin(), start_indices.end());
-    return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
+    return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
+    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
+    absl::Span<const int64> slice_sizes) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+
+  for (int64 size : slice_sizes) {
+    instr.add_dynamic_slice_sizes(size);
+  }
+
+  std::vector<XlaOp> operands = {operand};
+  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
+  return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
+}
+
 XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                      absl::Span<const XlaOp> start_indices) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
     std::vector<const Shape*> start_indices_shape_ptrs;
@@ -904,15 +907,22 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
     TF_ASSIGN_OR_RETURN(
         Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
                          *operand_shape, *update_shape, start_indices_shapes));
-    *instr.mutable_shape() = shape.ToProto();
-
-    std::vector<XlaOp> operands = {operand, update};
-    operands.insert(operands.end(), start_indices.begin(), start_indices.end());
-    return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
-                          operands);
+    return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
+    const Shape& shape, XlaOp operand, XlaOp update,
+    absl::Span<const XlaOp> start_indices) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+
+  std::vector<XlaOp> operands = {operand, update};
+  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
+  return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
+                        operands);
+}
+
 XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
                               int64 dimension) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
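The two hunks above follow the builder's established extension pattern (compare ConcatInDimInternal in the header hunk below): the public method keeps validation and shape inference, while a new virtual Internal hook owns instruction construction so that subclasses can emit a different IR. A toy, self-contained sketch of that pattern, with class and method names invented for illustration:

// Sketch only: public entry infers the result shape once, then delegates
// construction to an overridable hook, as DynamicSlice now does.
#include <iostream>
#include <string>

struct Shape { int rank; };

class Builder {
 public:
  virtual ~Builder() = default;
  std::string DynamicSlice(const Shape& operand) {
    Shape inferred{operand.rank};           // stand-in for shape inference
    return DynamicSliceInternal(inferred);  // construction is overridable
  }

 protected:
  // Default: build a proto-based HLO instruction, as XlaBuilder does.
  virtual std::string DynamicSliceInternal(const Shape& s) {
    return "hlo.dynamic-slice/rank" + std::to_string(s.rank);
  }
};

// Analogue of MlirHloBuilder: same shape inference, different IR emitted.
class MlirBuilder : public Builder {
 protected:
  std::string DynamicSliceInternal(const Shape& s) override {
    return "xla_hlo.dynamic-slice/rank" + std::to_string(s.rank);
  }
};

int main() {
  MlirBuilder mb;
  std::cout << mb.DynamicSlice({2}) << "\n";  // xla_hlo.dynamic-slice/rank2
}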
@@ -423,9 +423,15 @@ class XlaBuilder {
 
   XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                      absl::Span<const int64> slice_sizes);
+  virtual StatusOr<XlaOp> DynamicSliceInternal(
+      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
+      absl::Span<const int64> slice_sizes);
 
   XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                            absl::Span<const XlaOp> start_indices);
+  virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
+      const Shape& shape, XlaOp operand, XlaOp update,
+      absl::Span<const XlaOp> start_indices);
 
   XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
   virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,