Build DynamicSlice and DynamicUpdateSlice ops with MlirHloBuilder

Whitelist XlaDynamicSlice and XlaDynamicUpdateSlice for testing

PiperOrigin-RevId: 311642899
Change-Id: Icbf009cf69d3b183d0c83c10925a5fbaa3c49f1f
This commit is contained in:
Smit Hinsu 2020-05-14 18:00:36 -07:00 committed by TensorFlower Gardener
parent a2ef8b5a06
commit 4662933489
7 changed files with 93 additions and 22 deletions

View File

@ -282,6 +282,28 @@ StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
GetI64ElementsAttr(strides, &builder_)));
}
StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes) {
TF_ASSIGN_OR_RETURN(
mlir::Type result_ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicSliceOp>(
loc_, result_ty, GetValue(operand), GetValues(start_indices),
GetI64ElementsAttr(slice_sizes, &builder_)));
}
StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
const Shape& shape, XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices) {
TF_ASSIGN_OR_RETURN(
mlir::Type result_ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicUpdateSliceOp>(
loc_, result_ty, GetValue(operand), GetValue(update),
GetValues(start_indices)));
}
StatusOr<XlaOp> MlirHloBuilder::PadInternal(
const Shape& shape, XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config) {

View File

@ -175,6 +175,14 @@ class MlirHloBuilder : public XlaBuilder {
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) override;
StatusOr<XlaOp> DynamicSliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes) override;
StatusOr<XlaOp> DynamicUpdateSliceInternal(
const Shape& shape, XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices) override;
StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
XlaOp padding_value,
const PaddingConfig& padding_config) override;

View File

@ -163,6 +163,30 @@ func @truncated_normal() -> tensor<2x2xf32> {
return %1 : tensor<2x2xf32>
}
// CHECK-LABEL: dynamic_update_slice
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32>
func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> {
// CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]])
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>
// CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
// CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
// CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]])
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
// CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
// CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]])
%0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
return %0: tensor<3x4xi32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}

View File

@ -168,6 +168,8 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::XlaBroadcastHelperOp>(),
TypeID::get<TF::XlaConvOp>(),
TypeID::get<TF::XlaDotOp>(),
TypeID::get<TF::XlaDynamicSliceOp>(),
TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
TypeID::get<TF::XlaPadOp>(),
TypeID::get<TF::Xlog1pyOp>(),
TypeID::get<TF::XlogyOp>()

View File

@ -304,7 +304,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
self._assertOpOutputMatchesExpected(
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
@test_util.disable_mlir_bridge('Not supported yet')
def testDynamicSlice(self):
for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected(
@ -317,7 +316,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
[[673, 674], [683, 684], [693, 694]]]),
dtype=dtype))
@test_util.disable_mlir_bridge('Not supported yet')
@test_util.disable_mlir_bridge('Error handling')
def testDynamicSliceWithIncorrectStartIndicesShape(self):
with self.session() as session:
with self.test_scope():
@ -331,7 +330,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
(r'start_indices must be a vector with length equal to input rank, '
r'but input rank is 3 and start_indices has shape \[2\].*'))
@test_util.disable_mlir_bridge('Not supported yet')
@test_util.disable_mlir_bridge('Error handling')
def testDynamicSliceWithIncorrectSizeIndicesShape(self):
with self.session() as session:
with self.test_scope():

View File

@ -864,8 +864,6 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
std::vector<const Shape*> start_indices_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
@ -876,23 +874,28 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferDynamicSliceShape(
*operand_shape, start_indices_shapes, slice_sizes));
*instr.mutable_shape() = shape.ToProto();
for (int64 size : slice_sizes) {
instr.add_dynamic_slice_sizes(size);
}
std::vector<XlaOp> operands = {operand};
operands.insert(operands.end(), start_indices.begin(), start_indices.end());
return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
});
}
StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
for (int64 size : slice_sizes) {
instr.add_dynamic_slice_sizes(size);
}
std::vector<XlaOp> operands = {operand};
operands.insert(operands.end(), start_indices.begin(), start_indices.end());
return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
}
XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
std::vector<const Shape*> start_indices_shape_ptrs;
@ -904,15 +907,22 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
*operand_shape, *update_shape, start_indices_shapes));
*instr.mutable_shape() = shape.ToProto();
std::vector<XlaOp> operands = {operand, update};
operands.insert(operands.end(), start_indices.begin(), start_indices.end());
return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
operands);
return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
});
}
StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
const Shape& shape, XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
std::vector<XlaOp> operands = {operand, update};
operands.insert(operands.end(), start_indices.begin(), start_indices.end());
return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
operands);
}
XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {

View File

@ -423,9 +423,15 @@ class XlaBuilder {
XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes);
virtual StatusOr<XlaOp> DynamicSliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes);
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices);
virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
const Shape& shape, XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices);
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,