Build DynamicSlice and DynamicUpdateSlice ops with MlirHloBuilder
Whitelist XlaDynamicSlice and XlaDynamicUpdateSlice for testing PiperOrigin-RevId: 311642899 Change-Id: Icbf009cf69d3b183d0c83c10925a5fbaa3c49f1f
This commit is contained in:
parent
a2ef8b5a06
commit
4662933489
|
@ -282,6 +282,28 @@ StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
|
||||||
GetI64ElementsAttr(strides, &builder_)));
|
GetI64ElementsAttr(strides, &builder_)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
|
||||||
|
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
|
||||||
|
absl::Span<const int64> slice_sizes) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
mlir::Type result_ty,
|
||||||
|
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
|
||||||
|
return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicSliceOp>(
|
||||||
|
loc_, result_ty, GetValue(operand), GetValues(start_indices),
|
||||||
|
GetI64ElementsAttr(slice_sizes, &builder_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
|
||||||
|
const Shape& shape, XlaOp operand, XlaOp update,
|
||||||
|
absl::Span<const XlaOp> start_indices) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
mlir::Type result_ty,
|
||||||
|
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
|
||||||
|
return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicUpdateSliceOp>(
|
||||||
|
loc_, result_ty, GetValue(operand), GetValue(update),
|
||||||
|
GetValues(start_indices)));
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<XlaOp> MlirHloBuilder::PadInternal(
|
StatusOr<XlaOp> MlirHloBuilder::PadInternal(
|
||||||
const Shape& shape, XlaOp operand, XlaOp padding_value,
|
const Shape& shape, XlaOp operand, XlaOp padding_value,
|
||||||
const PaddingConfig& padding_config) {
|
const PaddingConfig& padding_config) {
|
||||||
|
|
|
@ -175,6 +175,14 @@ class MlirHloBuilder : public XlaBuilder {
|
||||||
absl::Span<const int64> limit_indices,
|
absl::Span<const int64> limit_indices,
|
||||||
absl::Span<const int64> strides) override;
|
absl::Span<const int64> strides) override;
|
||||||
|
|
||||||
|
StatusOr<XlaOp> DynamicSliceInternal(
|
||||||
|
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
|
||||||
|
absl::Span<const int64> slice_sizes) override;
|
||||||
|
|
||||||
|
StatusOr<XlaOp> DynamicUpdateSliceInternal(
|
||||||
|
const Shape& shape, XlaOp operand, XlaOp update,
|
||||||
|
absl::Span<const XlaOp> start_indices) override;
|
||||||
|
|
||||||
StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
|
StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
|
||||||
XlaOp padding_value,
|
XlaOp padding_value,
|
||||||
const PaddingConfig& padding_config) override;
|
const PaddingConfig& padding_config) override;
|
||||||
|
|
|
@ -163,6 +163,30 @@ func @truncated_normal() -> tensor<2x2xf32> {
|
||||||
return %1 : tensor<2x2xf32>
|
return %1 : tensor<2x2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: dynamic_update_slice
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32>
|
||||||
|
func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> {
|
||||||
|
|
||||||
|
// CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]])
|
||||||
|
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>
|
||||||
|
// CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64>
|
||||||
|
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
|
||||||
|
// CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
|
||||||
|
// CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
|
||||||
|
|
||||||
|
// CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]])
|
||||||
|
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
|
||||||
|
// CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64>
|
||||||
|
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
|
||||||
|
// CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
|
||||||
|
// CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
|
||||||
|
|
||||||
|
// CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]])
|
||||||
|
|
||||||
|
%0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
|
||||||
|
return %0: tensor<3x4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
|
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
|
||||||
// available but doesn't support this instance.
|
// available but doesn't support this instance.
|
||||||
}
|
}
|
||||||
|
|
|
@ -168,6 +168,8 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||||
TypeID::get<TF::XlaBroadcastHelperOp>(),
|
TypeID::get<TF::XlaBroadcastHelperOp>(),
|
||||||
TypeID::get<TF::XlaConvOp>(),
|
TypeID::get<TF::XlaConvOp>(),
|
||||||
TypeID::get<TF::XlaDotOp>(),
|
TypeID::get<TF::XlaDotOp>(),
|
||||||
|
TypeID::get<TF::XlaDynamicSliceOp>(),
|
||||||
|
TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
|
||||||
TypeID::get<TF::XlaPadOp>(),
|
TypeID::get<TF::XlaPadOp>(),
|
||||||
TypeID::get<TF::Xlog1pyOp>(),
|
TypeID::get<TF::Xlog1pyOp>(),
|
||||||
TypeID::get<TF::XlogyOp>()
|
TypeID::get<TF::XlogyOp>()
|
||||||
|
|
|
@ -304,7 +304,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
|
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge('Not supported yet')
|
|
||||||
def testDynamicSlice(self):
|
def testDynamicSlice(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
|
@ -317,7 +316,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
[[673, 674], [683, 684], [693, 694]]]),
|
[[673, 674], [683, 684], [693, 694]]]),
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge('Not supported yet')
|
@test_util.disable_mlir_bridge('Error handling')
|
||||||
def testDynamicSliceWithIncorrectStartIndicesShape(self):
|
def testDynamicSliceWithIncorrectStartIndicesShape(self):
|
||||||
with self.session() as session:
|
with self.session() as session:
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
|
@ -331,7 +330,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
(r'start_indices must be a vector with length equal to input rank, '
|
(r'start_indices must be a vector with length equal to input rank, '
|
||||||
r'but input rank is 3 and start_indices has shape \[2\].*'))
|
r'but input rank is 3 and start_indices has shape \[2\].*'))
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge('Not supported yet')
|
@test_util.disable_mlir_bridge('Error handling')
|
||||||
def testDynamicSliceWithIncorrectSizeIndicesShape(self):
|
def testDynamicSliceWithIncorrectSizeIndicesShape(self):
|
||||||
with self.session() as session:
|
with self.session() as session:
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
|
|
|
@ -864,8 +864,6 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
|
||||||
absl::Span<const XlaOp> start_indices,
|
absl::Span<const XlaOp> start_indices,
|
||||||
absl::Span<const int64> slice_sizes) {
|
absl::Span<const int64> slice_sizes) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||||
std::vector<const Shape*> start_indices_shape_ptrs;
|
std::vector<const Shape*> start_indices_shape_ptrs;
|
||||||
TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
|
TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
|
||||||
|
@ -876,6 +874,14 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
|
||||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||||
ShapeInference::InferDynamicSliceShape(
|
ShapeInference::InferDynamicSliceShape(
|
||||||
*operand_shape, start_indices_shapes, slice_sizes));
|
*operand_shape, start_indices_shapes, slice_sizes));
|
||||||
|
return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
|
||||||
|
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
|
||||||
|
absl::Span<const int64> slice_sizes) {
|
||||||
|
HloInstructionProto instr;
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
|
|
||||||
for (int64 size : slice_sizes) {
|
for (int64 size : slice_sizes) {
|
||||||
|
@ -885,14 +891,11 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
|
||||||
std::vector<XlaOp> operands = {operand};
|
std::vector<XlaOp> operands = {operand};
|
||||||
operands.insert(operands.end(), start_indices.begin(), start_indices.end());
|
operands.insert(operands.end(), start_indices.begin(), start_indices.end());
|
||||||
return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
|
return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
|
XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
|
||||||
absl::Span<const XlaOp> start_indices) {
|
absl::Span<const XlaOp> start_indices) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
|
TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
|
||||||
std::vector<const Shape*> start_indices_shape_ptrs;
|
std::vector<const Shape*> start_indices_shape_ptrs;
|
||||||
|
@ -904,13 +907,20 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
|
Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
|
||||||
*operand_shape, *update_shape, start_indices_shapes));
|
*operand_shape, *update_shape, start_indices_shapes));
|
||||||
|
return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
|
||||||
|
const Shape& shape, XlaOp operand, XlaOp update,
|
||||||
|
absl::Span<const XlaOp> start_indices) {
|
||||||
|
HloInstructionProto instr;
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
|
|
||||||
std::vector<XlaOp> operands = {operand, update};
|
std::vector<XlaOp> operands = {operand, update};
|
||||||
operands.insert(operands.end(), start_indices.begin(), start_indices.end());
|
operands.insert(operands.end(), start_indices.begin(), start_indices.end());
|
||||||
return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
|
return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
|
||||||
operands);
|
operands);
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
|
XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
|
||||||
|
|
|
@ -423,9 +423,15 @@ class XlaBuilder {
|
||||||
|
|
||||||
XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
|
XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
|
||||||
absl::Span<const int64> slice_sizes);
|
absl::Span<const int64> slice_sizes);
|
||||||
|
virtual StatusOr<XlaOp> DynamicSliceInternal(
|
||||||
|
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
|
||||||
|
absl::Span<const int64> slice_sizes);
|
||||||
|
|
||||||
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
|
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
|
||||||
absl::Span<const XlaOp> start_indices);
|
absl::Span<const XlaOp> start_indices);
|
||||||
|
virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
|
||||||
|
const Shape& shape, XlaOp operand, XlaOp update,
|
||||||
|
absl::Span<const XlaOp> start_indices);
|
||||||
|
|
||||||
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
|
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
|
||||||
virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
|
virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
|
||||||
|
|
Loading…
Reference in New Issue