Add verifier for xla_hlo.dot_general op.
PiperOrigin-RevId: 303179319 Change-Id: Ic506390b93656e0e8b71f985d6e1fc01ee1756c2
This commit is contained in:
parent
6e920629c5
commit
c5b4b6dc1f
@ -177,6 +177,31 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) {
|
||||
result.addAttribute("value", value);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DotGeneralOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(DotGeneralOp op) {
|
||||
auto dot_dimension_numbers = op.dot_dimension_numbers();
|
||||
int64_t lhs_batching_dimensions_size = llvm::size(
|
||||
dot_dimension_numbers.lhs_batching_dimensions().getValues<int64_t>());
|
||||
int64_t rhs_batching_dimensions_size = llvm::size(
|
||||
dot_dimension_numbers.rhs_batching_dimensions().getValues<int64_t>());
|
||||
if (lhs_batching_dimensions_size != rhs_batching_dimensions_size) {
|
||||
return op.emitError()
|
||||
<< "lhs and rhs should have the same number of batching dimensions";
|
||||
}
|
||||
int64_t lhs_contracting_dimensions_size = llvm::size(
|
||||
dot_dimension_numbers.lhs_contracting_dimensions().getValues<int64_t>());
|
||||
int64_t rhs_contracting_dimensions_size = llvm::size(
|
||||
dot_dimension_numbers.rhs_contracting_dimensions().getValues<int64_t>());
|
||||
if (lhs_contracting_dimensions_size != rhs_contracting_dimensions_size) {
|
||||
return op.emitError() << "lhs and rhs should have the same number of "
|
||||
"contracting dimensions";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IotaOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -949,6 +949,7 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral
|
||||
);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
// Define Base Einsum op within the HLO dialect as these are client ops and
|
||||
|
@ -932,3 +932,29 @@ func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> {
|
||||
%0 = "xla_hlo.reshape"(%operand) : (tensor<2x4xf32>) -> tensor<3x3xf32>
|
||||
return %0 : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) {
|
||||
// expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}}
|
||||
%0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = {
|
||||
lhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
|
||||
rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) {
|
||||
// expected-error @+1 {{lhs and rhs should have the same number of contracting dimensions}}
|
||||
%0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = {
|
||||
lhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
lhs_contracting_dimensions = dense<[]> : tensor<0xi64>,
|
||||
rhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user