diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 0ce90f0a445..86e865a1657 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -177,6 +177,31 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) { result.addAttribute("value", value); } +//===----------------------------------------------------------------------===// +// DotGeneralOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DotGeneralOp op) { + auto dot_dimension_numbers = op.dot_dimension_numbers(); + int64_t lhs_batching_dimensions_size = llvm::size( + dot_dimension_numbers.lhs_batching_dimensions().getValues()); + int64_t rhs_batching_dimensions_size = llvm::size( + dot_dimension_numbers.rhs_batching_dimensions().getValues()); + if (lhs_batching_dimensions_size != rhs_batching_dimensions_size) { + return op.emitError() + << "lhs and rhs should have the same number of batching dimensions"; + } + int64_t lhs_contracting_dimensions_size = llvm::size( + dot_dimension_numbers.lhs_contracting_dimensions().getValues()); + int64_t rhs_contracting_dimensions_size = llvm::size( + dot_dimension_numbers.rhs_contracting_dimensions().getValues()); + if (lhs_contracting_dimensions_size != rhs_contracting_dimensions_size) { + return op.emitError() << "lhs and rhs should have the same number of " + "contracting dimensions"; + } + return success(); +} + //===----------------------------------------------------------------------===// // IotaOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 1c38d3ae3e1..00b43198c55 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -949,6 +949,7 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral ); let results = (outs HLO_Tensor); + let verifier = [{ return Verify(*this); }]; } // Define Base Einsum op within the HLO dialect as these are client ops and diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index 037eded9ba6..a1cddab54c9 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -932,3 +932,29 @@ func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> { %0 = "xla_hlo.reshape"(%operand) : (tensor<2x4xf32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> } + +// ----- + +func @dot_general(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}} + %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<0> : tensor<1xi64>, + lhs_contracting_dimensions = dense<2> : tensor<1xi64>, + rhs_batching_dimensions = dense<[]> : tensor<0xi64>, + rhs_contracting_dimensions = dense<1> : tensor<1xi64> + }} : (tensor, tensor) -> tensor + return +} + +// ----- + +func @dot_general(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{lhs and rhs should have the same number of contracting dimensions}} + %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<0> : tensor<1xi64>, + lhs_contracting_dimensions = dense<[]> : tensor<0xi64>, + rhs_batching_dimensions = dense<0> : tensor<1xi64>, + rhs_contracting_dimensions = dense<1> : tensor<1xi64> + }} : (tensor, tensor) -> tensor + return +}