Build DynamicSlice and DynamicUpdateSlice ops with MlirHloBuilder

Whitelist XlaDynamicSlice and XlaDynamicUpdateSlice for testing

PiperOrigin-RevId: 311642899
Change-Id: Icbf009cf69d3b183d0c83c10925a5fbaa3c49f1f
This commit is contained in:
Smit Hinsu 2020-05-14 18:00:36 -07:00 committed by TensorFlower Gardener
parent a2ef8b5a06
commit 4662933489
7 changed files with 93 additions and 22 deletions

View File

@ -282,6 +282,28 @@ StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
GetI64ElementsAttr(strides, &builder_))); GetI64ElementsAttr(strides, &builder_)));
} }
StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes) {
TF_ASSIGN_OR_RETURN(
mlir::Type result_ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicSliceOp>(
loc_, result_ty, GetValue(operand), GetValues(start_indices),
GetI64ElementsAttr(slice_sizes, &builder_)));
}
StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
const Shape& shape, XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices) {
TF_ASSIGN_OR_RETURN(
mlir::Type result_ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicUpdateSliceOp>(
loc_, result_ty, GetValue(operand), GetValue(update),
GetValues(start_indices)));
}
StatusOr<XlaOp> MlirHloBuilder::PadInternal( StatusOr<XlaOp> MlirHloBuilder::PadInternal(
const Shape& shape, XlaOp operand, XlaOp padding_value, const Shape& shape, XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config) { const PaddingConfig& padding_config) {

View File

@ -175,6 +175,14 @@ class MlirHloBuilder : public XlaBuilder {
absl::Span<const int64> limit_indices, absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) override; absl::Span<const int64> strides) override;
StatusOr<XlaOp> DynamicSliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes) override;
StatusOr<XlaOp> DynamicUpdateSliceInternal(
const Shape& shape, XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices) override;
StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand, StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
XlaOp padding_value, XlaOp padding_value,
const PaddingConfig& padding_config) override; const PaddingConfig& padding_config) override;

View File

@ -163,6 +163,30 @@ func @truncated_normal() -> tensor<2x2xf32> {
return %1 : tensor<2x2xf32> return %1 : tensor<2x2xf32>
} }
// CHECK-LABEL: dynamic_update_slice
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32>
func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> {
// CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]])
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>
// CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
// CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
// CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]])
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
// CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
// CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]])
%0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
return %0: tensor<3x4xi32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance. // available but doesn't support this instance.
} }

View File

@ -168,6 +168,8 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::XlaBroadcastHelperOp>(), TypeID::get<TF::XlaBroadcastHelperOp>(),
TypeID::get<TF::XlaConvOp>(), TypeID::get<TF::XlaConvOp>(),
TypeID::get<TF::XlaDotOp>(), TypeID::get<TF::XlaDotOp>(),
TypeID::get<TF::XlaDynamicSliceOp>(),
TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
TypeID::get<TF::XlaPadOp>(), TypeID::get<TF::XlaPadOp>(),
TypeID::get<TF::Xlog1pyOp>(), TypeID::get<TF::Xlog1pyOp>(),
TypeID::get<TF::XlogyOp>() TypeID::get<TF::XlogyOp>()

View File

@ -304,7 +304,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
self._assertOpOutputMatchesExpected( self._assertOpOutputMatchesExpected(
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
@test_util.disable_mlir_bridge('Not supported yet')
def testDynamicSlice(self): def testDynamicSlice(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected( self._assertOpOutputMatchesExpected(
@ -317,7 +316,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
[[673, 674], [683, 684], [693, 694]]]), [[673, 674], [683, 684], [693, 694]]]),
dtype=dtype)) dtype=dtype))
@test_util.disable_mlir_bridge('Not supported yet') @test_util.disable_mlir_bridge('Error handling')
def testDynamicSliceWithIncorrectStartIndicesShape(self): def testDynamicSliceWithIncorrectStartIndicesShape(self):
with self.session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
@ -331,7 +330,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
(r'start_indices must be a vector with length equal to input rank, ' (r'start_indices must be a vector with length equal to input rank, '
r'but input rank is 3 and start_indices has shape \[2\].*')) r'but input rank is 3 and start_indices has shape \[2\].*'))
@test_util.disable_mlir_bridge('Not supported yet') @test_util.disable_mlir_bridge('Error handling')
def testDynamicSliceWithIncorrectSizeIndicesShape(self): def testDynamicSliceWithIncorrectSizeIndicesShape(self):
with self.session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():

View File

@ -864,8 +864,6 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
absl::Span<const XlaOp> start_indices, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes) { absl::Span<const int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
std::vector<const Shape*> start_indices_shape_ptrs; std::vector<const Shape*> start_indices_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
@ -876,6 +874,14 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
TF_ASSIGN_OR_RETURN(Shape shape, TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferDynamicSliceShape( ShapeInference::InferDynamicSliceShape(
*operand_shape, start_indices_shapes, slice_sizes)); *operand_shape, start_indices_shapes, slice_sizes));
return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
});
}
StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto(); *instr.mutable_shape() = shape.ToProto();
for (int64 size : slice_sizes) { for (int64 size : slice_sizes) {
@ -885,14 +891,11 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
std::vector<XlaOp> operands = {operand}; std::vector<XlaOp> operands = {operand};
operands.insert(operands.end(), start_indices.begin(), start_indices.end()); operands.insert(operands.end(), start_indices.begin(), start_indices.end());
return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
});
} }
XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices) { absl::Span<const XlaOp> start_indices) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update)); TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
std::vector<const Shape*> start_indices_shape_ptrs; std::vector<const Shape*> start_indices_shape_ptrs;
@ -904,13 +907,20 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferDynamicUpdateSliceShape( Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
*operand_shape, *update_shape, start_indices_shapes)); *operand_shape, *update_shape, start_indices_shapes));
return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
});
}
StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
const Shape& shape, XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto(); *instr.mutable_shape() = shape.ToProto();
std::vector<XlaOp> operands = {operand, update}; std::vector<XlaOp> operands = {operand, update};
operands.insert(operands.end(), start_indices.begin(), start_indices.end()); operands.insert(operands.end(), start_indices.begin(), start_indices.end());
return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
operands); operands);
});
} }
XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands, XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,

View File

@ -423,9 +423,15 @@ class XlaBuilder {
XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices, XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes); absl::Span<const int64> slice_sizes);
virtual StatusOr<XlaOp> DynamicSliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes);
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices); absl::Span<const XlaOp> start_indices);
virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
const Shape& shape, XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices);
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension); XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape, virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,