Implement test pass for checking behavior of InferShapedTypeOpInterface and add tests for CHLO.
* Exposed bugs in ops that have a different return vs operand element type. * Existing logic is producing shape components that are more broad than necessary. Left a TODO. PiperOrigin-RevId: 308338596 Change-Id: Iadde9da1052ec368f0dfef5a64c31712fe18f9a6
This commit is contained in:
parent
d095009c33
commit
f5d2eabea0
@ -511,6 +511,7 @@ cc_library(
|
||||
srcs = [
|
||||
"transforms/chlo_legalize_to_hlo_pass.cc",
|
||||
"transforms/materialize_broadcasts_pass.cc",
|
||||
"transforms/test_infer_shaped_type_pass.cc",
|
||||
"transforms/unfuse_batch_norm_pass.cc",
|
||||
],
|
||||
deps = [
|
||||
@ -519,6 +520,7 @@ cc_library(
|
||||
":xla_materialize_broadcasts", # build-cleaner: keep
|
||||
":xla_unfuse_batch_norm", # build-cleaner: keep
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:InferTypeOpInterface",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
|
@ -96,7 +96,7 @@ static Type GetBroadcastType(Type x, Type y, Type element_type,
|
||||
|
||||
LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
ArrayRef<NamedAttribute> attributes, Type element_type,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||
// Find broadcast_dimensions.
|
||||
DenseIntElementsAttr broadcast_dimensions;
|
||||
@ -113,7 +113,7 @@ LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
|
||||
lhs_type.getElementType() != rhs_type.getElementType()) {
|
||||
return emitOptionalError(location, "mismatched operand types");
|
||||
}
|
||||
Type element_type = lhs_type.getElementType();
|
||||
if (!element_type) element_type = lhs_type.getElementType();
|
||||
Type result_type =
|
||||
GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions);
|
||||
|
||||
@ -159,16 +159,62 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
|
||||
reifiedReturnShapes.push_back(computed_shape);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BroadcastComplexOp (has custom type inference due to different result type).
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||
ArrayRef<NamedAttribute> attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||
ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
|
||||
if (!lhs_type) {
|
||||
return emitOptionalError(location, "expected ShapedType");
|
||||
}
|
||||
Type element_type = ComplexType::get(lhs_type.getElementType());
|
||||
return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
|
||||
attributes, element_type,
|
||||
inferedReturnShapes);
|
||||
}
|
||||
LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
|
||||
reifiedReturnShapes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BroadcastCompareOp (has custom type inference due to different result type).
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||
ArrayRef<NamedAttribute> attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||
Type element_type = IntegerType::get(1, context);
|
||||
return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
|
||||
attributes, element_type,
|
||||
inferedReturnShapes);
|
||||
}
|
||||
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
|
||||
reifiedReturnShapes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Macros for method definitions that are common to most broadcasting ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \
|
||||
LogicalResult Op::inferReturnTypeComponents( \
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands, \
|
||||
ArrayRef<NamedAttribute> attributes, RegionRange regions, \
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { \
|
||||
return InferBroadcastBinaryOpReturnTypeComponents( \
|
||||
context, location, operands, attributes, inferedReturnShapes); \
|
||||
context, location, operands, attributes, /*element_type=*/nullptr, \
|
||||
inferedReturnShapes); \
|
||||
} \
|
||||
LogicalResult Op::reifyReturnTypeShapes( \
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { \
|
||||
@ -203,10 +249,6 @@ BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
|
||||
|
||||
// These only have the common shape inference defs but non-standard builders.
|
||||
BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(BroadcastCompareOp);
|
||||
BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(BroadcastComplexOp);
|
||||
|
||||
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
|
||||
#undef BROADCAST_BINARY_OP_DEFS
|
||||
|
||||
|
@ -0,0 +1,56 @@
|
||||
// RUN: xla-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck --dump-input=fail %s
|
||||
|
||||
// CHECK-LABEL: @broadcast_add
|
||||
// Note that all broadcast_ops are expanded from the same template, so
|
||||
// only test reification on an examplar op.
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>,
|
||||
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xf32>
|
||||
func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xindex> {
|
||||
// CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]])
|
||||
// CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]])
|
||||
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||
// CHECK: %[[EXTENTS:.+]] = "shape.to_extent_tensor"(%[[BCAST_S]])
|
||||
// CHECK: return %[[EXTENTS]]
|
||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
|
||||
return %1 : tensor<1xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @complex_ranked_components
|
||||
func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
|
||||
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
||||
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
|
||||
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
|
||||
return %1 : tensor<?x?xcomplex<f32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @compare_ranked_components
|
||||
func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
|
||||
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
|
||||
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
|
||||
return %0 : tensor<?x?xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @broadcast_add_ranked_components_r1
|
||||
func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
|
||||
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %1 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @broadcast_add_ranked_components_r1x2
|
||||
func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?x3xf32>) -> tensor<?x3xf32> {
|
||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
|
||||
// TODO: Overly broad shapes are being returned. Tighten the calculation
|
||||
// and update/extend these tests.
|
||||
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
|
||||
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32>
|
||||
return %1 : tensor<?x3xf32>
|
||||
}
|
||||
|
@ -0,0 +1,99 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
struct InferReturnTypeComponentsPattern : public RewritePattern {
|
||||
InferReturnTypeComponentsPattern(MLIRContext *context)
|
||||
: RewritePattern("xla_test.get_return_type_components", 1, context) {}
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (op->getNumOperands() != 1) return failure();
|
||||
auto defining_op = op->getOperand(0).getDefiningOp();
|
||||
auto defining_op_int =
|
||||
llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(defining_op);
|
||||
if (!defining_op_int) return failure();
|
||||
SmallVector<ShapedTypeComponents, 4> components;
|
||||
if (failed(defining_op_int.inferReturnTypeComponents(
|
||||
op->getContext(), op->getLoc(), defining_op->getOperands(),
|
||||
defining_op->getAttrs(), defining_op->getRegions(), components))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Replace the op with another pass-through op with attributes added.
|
||||
OperationState state(op->getLoc(), "xla_test.return_type_components",
|
||||
op->getOperands(), op->getResultTypes(),
|
||||
op->getAttrs());
|
||||
auto new_op = rewriter.createOperation(state);
|
||||
for (auto it : llvm::enumerate(components)) {
|
||||
if (it.value().hasRank()) {
|
||||
new_op->setAttr((StringRef("dims") + Twine(it.index())).str(),
|
||||
rewriter.getI64ArrayAttr(it.value().getDims()));
|
||||
}
|
||||
if (it.value().getElementType()) {
|
||||
new_op->setAttr((Twine("element_type") + Twine(it.index())).str(),
|
||||
TypeAttr::get(it.value().getElementType()));
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, {new_op->getResults()});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReifyReturnTypeShapesPattern : public RewritePattern {
|
||||
ReifyReturnTypeShapesPattern(MLIRContext *context)
|
||||
: RewritePattern("xla_test.reify_return_type_shapes", 1, context) {}
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (op->getNumOperands() != 1) return failure();
|
||||
auto defining_op = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
|
||||
op->getOperand(0).getDefiningOp());
|
||||
if (!defining_op) return failure();
|
||||
SmallVector<Value, 4> return_shapes;
|
||||
if (failed(defining_op.reifyReturnTypeShapes(rewriter, return_shapes))) {
|
||||
return failure();
|
||||
}
|
||||
rewriter.replaceOp(op, return_shapes);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TestInferShapedTypeMethodsPass
|
||||
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());
|
||||
patterns.insert<InferReturnTypeComponentsPattern>(&getContext());
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::xla::TestInferShapedTypeMethodsPass> pass(
|
||||
"test-xla-infer-shaped-type-methods",
|
||||
"Uses test ops to invoke InferShapedTypeOpInterface methods");
|
Loading…
x
Reference in New Issue
Block a user