Add verifier for xla_hlo.dot_general op.

PiperOrigin-RevId: 303179319
Change-Id: Ic506390b93656e0e8b71f985d6e1fc01ee1756c2
This commit is contained in:
Sean Silva 2020-03-26 13:11:47 -07:00 committed by TensorFlower Gardener
parent 6e920629c5
commit c5b4b6dc1f
3 changed files with 52 additions and 0 deletions

View File

@ -177,6 +177,31 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) {
result.addAttribute("value", value);
}
//===----------------------------------------------------------------------===//
// DotGeneralOp
//===----------------------------------------------------------------------===//

// Verifies a dot_general op's dot_dimension_numbers attribute: lhs and rhs
// must declare the same number of batching dimensions and the same number of
// contracting dimensions. On mismatch, emits a diagnostic on the op and
// returns failure(); otherwise returns success().
static LogicalResult Verify(DotGeneralOp op) {
  auto dot_dimension_numbers = op.dot_dimension_numbers();

  // Counts the dimension indices held by one of the dimension-numbers
  // elements attributes. Generic lambda so it works regardless of the exact
  // attribute type the accessors return.
  auto num_dimensions = [](auto dimensions) -> int64_t {
    return llvm::size(dimensions.template getValues<int64_t>());
  };

  if (num_dimensions(dot_dimension_numbers.lhs_batching_dimensions()) !=
      num_dimensions(dot_dimension_numbers.rhs_batching_dimensions())) {
    return op.emitError()
           << "lhs and rhs should have the same number of batching dimensions";
  }

  if (num_dimensions(dot_dimension_numbers.lhs_contracting_dimensions()) !=
      num_dimensions(dot_dimension_numbers.rhs_contracting_dimensions())) {
    return op.emitError() << "lhs and rhs should have the same number of "
                             "contracting dimensions";
  }

  return success();
}
//===----------------------------------------------------------------------===//
// IotaOp
//===----------------------------------------------------------------------===//

View File

@ -949,6 +949,7 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral
);
let results = (outs HLO_Tensor);
let verifier = [{ return Verify(*this); }];
}
// Define Base Einsum op within the HLO dialect as these are client ops and

View File

@ -932,3 +932,29 @@ func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> {
%0 = "xla_hlo.reshape"(%operand) : (tensor<2x4xf32>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
}
// -----
// Regression test for the DotGeneralOp verifier: lhs declares one batching
// dimension (dim 0) while rhs declares none, so verification must fail with
// the batching-dimensions mismatch diagnostic checked below.
func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) {
// expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}}
%0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = {
lhs_batching_dimensions = dense<0> : tensor<1xi64>,
lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
rhs_contracting_dimensions = dense<1> : tensor<1xi64>
}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return
}
// -----
// Regression test for the DotGeneralOp verifier: lhs declares no contracting
// dimensions while rhs declares one (dim 1), so verification must fail with
// the contracting-dimensions mismatch diagnostic checked below.
func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) {
// expected-error @+1 {{lhs and rhs should have the same number of contracting dimensions}}
%0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = {
lhs_batching_dimensions = dense<0> : tensor<1xi64>,
lhs_contracting_dimensions = dense<[]> : tensor<0xi64>,
rhs_batching_dimensions = dense<0> : tensor<1xi64>,
rhs_contracting_dimensions = dense<1> : tensor<1xi64>
}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return
}