Implement test pass for checking behavior of InferShapedTypeOpInterface and add tests for CHLO.

* Exposed bugs in ops that have a different return vs operand element type.
* Existing logic is producing shape components that are broader than necessary. Left a TODO.

PiperOrigin-RevId: 308338596
Change-Id: Iadde9da1052ec368f0dfef5a64c31712fe18f9a6
This commit is contained in:
Stella Laurenzo 2020-04-24 15:29:47 -07:00 committed by TensorFlower Gardener
parent d095009c33
commit f5d2eabea0
4 changed files with 207 additions and 8 deletions

View File

@ -511,6 +511,7 @@ cc_library(
srcs = [
"transforms/chlo_legalize_to_hlo_pass.cc",
"transforms/materialize_broadcasts_pass.cc",
"transforms/test_infer_shaped_type_pass.cc",
"transforms/unfuse_batch_norm_pass.cc",
],
deps = [
@ -519,6 +520,7 @@ cc_library(
":xla_materialize_broadcasts", # build-cleaner: keep
":xla_unfuse_batch_norm", # build-cleaner: keep
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",

View File

@ -96,7 +96,7 @@ static Type GetBroadcastType(Type x, Type y, Type element_type,
LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes,
ArrayRef<NamedAttribute> attributes, Type element_type,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
// Find broadcast_dimensions.
DenseIntElementsAttr broadcast_dimensions;
@ -113,7 +113,7 @@ LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
lhs_type.getElementType() != rhs_type.getElementType()) {
return emitOptionalError(location, "mismatched operand types");
}
Type element_type = lhs_type.getElementType();
if (!element_type) element_type = lhs_type.getElementType();
Type result_type =
GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions);
@ -159,16 +159,62 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
reifiedReturnShapes.push_back(computed_shape);
return success();
}
} // namespace
//===----------------------------------------------------------------------===//
// BroadcastComplexOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//
LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    ArrayRef<NamedAttribute> attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  // The result element type differs from the operands: it is complex<T>
  // where T is the lhs element type, so compute it here and delegate the
  // broadcast shape logic to the shared helper.
  auto shaped_lhs = operands[0].getType().dyn_cast<ShapedType>();
  if (!shaped_lhs) {
    return emitOptionalError(location, "expected ShapedType");
  }
  Type complex_element_type = ComplexType::get(shaped_lhs.getElementType());
  return InferBroadcastBinaryOpReturnTypeComponents(
      context, location, operands, attributes, complex_element_type,
      inferedReturnShapes);
}
LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
  // Shape reification is identical to the other broadcasting binary ops;
  // forward to the shared helper.
  Operation* op = getOperation();
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, op,
                                                reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// BroadcastCompareOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//
LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    ArrayRef<NamedAttribute> attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  // Comparisons always produce an i1 result regardless of the operand
  // element type; only the shape is inferred by the shared helper.
  Type pred_type = IntegerType::get(1, context);
  return InferBroadcastBinaryOpReturnTypeComponents(
      context, location, operands, attributes, pred_type,
      inferedReturnShapes);
}
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
  // Shape reification is identical to the other broadcasting binary ops;
  // forward to the shared helper.
  Operation* op = getOperation();
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, op,
                                                reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// Macros for method definitions that are common to most broadcasting ops.
//===----------------------------------------------------------------------===//
#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \
LogicalResult Op::inferReturnTypeComponents( \
MLIRContext* context, Optional<Location> location, ValueRange operands, \
ArrayRef<NamedAttribute> attributes, RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { \
return InferBroadcastBinaryOpReturnTypeComponents( \
context, location, operands, attributes, inferedReturnShapes); \
context, location, operands, attributes, /*element_type=*/nullptr, \
inferedReturnShapes); \
} \
LogicalResult Op::reifyReturnTypeShapes( \
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { \
@ -203,10 +249,6 @@ BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
// These only have the common shape inference defs but non-standard builders.
BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(BroadcastCompareOp);
BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(BroadcastComplexOp);
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
#undef BROADCAST_BINARY_OP_DEFS

View File

@ -0,0 +1,56 @@
// RUN: xla-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck --dump-input=fail %s
// CHECK-LABEL: @broadcast_add
// Note that all broadcast_ops are expanded from the same template, so
// only test reification on an exemplar op.
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xf32>
func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xindex> {
// Reifying the result shape should produce shape dialect ops: shape_of of
// each operand, a broadcast of the two shapes, and a conversion to an
// extent tensor that replaces the xla_test.reify_return_type_shapes op.
// CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]])
// CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]])
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[EXTENTS:.+]] = "shape.to_extent_tensor"(%[[BCAST_S]])
// CHECK: return %[[EXTENTS]]
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
return %1 : tensor<1xindex>
}
// -----
// CHECK-LABEL: @complex_ranked_components
func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// The inferred element type must be complex<f32> even though the operands
// are plain f32 (broadcast_complex has a custom element-type inference).
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
return %1 : tensor<?x?xcomplex<f32>>
}
// -----
// CHECK-LABEL: @compare_ranked_components
func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// The inferred element type must be i1 even though the operands are f32
// (broadcast_compare has a custom element-type inference).
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
// Return %1 (not %0) so the pass-through op is not dead, consistent with
// the sibling tests in this file.
return %1 : tensor<?x?xi1>
}
// -----
// CHECK-LABEL: @broadcast_add_ranked_components_r1
func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// Rank-1 x rank-1 broadcast: one dynamic dim, element type preserved.
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
return %1 : tensor<?xf32>
}
// -----
// CHECK-LABEL: @broadcast_add_ranked_components_r1x2
func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?x3xf32>) -> tensor<?x3xf32> {
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
// TODO: Overly broad shapes are being returned. Tighten the calculation
// and update/extend these tests.
// NOTE: dims0 = [-1, 3] would be the tight answer here; the current
// inference loses the static extent 3.
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32>
return %1 : tensor<?x3xf32>
}

View File

@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {
namespace xla {
namespace {
// Rewrites "xla_test.get_return_type_components"(%v) into a pass-through
// "xla_test.return_type_components" op annotated with the shaped-type
// components inferred from %v's defining op, so that tests can FileCheck
// the results of InferShapedTypeOpInterface::inferReturnTypeComponents.
struct InferReturnTypeComponentsPattern : public RewritePattern {
  InferReturnTypeComponentsPattern(MLIRContext *context)
      : RewritePattern("xla_test.get_return_type_components", 1, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1) return failure();
    // Only fires when the operand's producer implements the interface.
    auto defining_op = op->getOperand(0).getDefiningOp();
    auto defining_op_int =
        llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(defining_op);
    if (!defining_op_int) return failure();
    SmallVector<ShapedTypeComponents, 4> components;
    if (failed(defining_op_int.inferReturnTypeComponents(
            op->getContext(), op->getLoc(), defining_op->getOperands(),
            defining_op->getAttrs(), defining_op->getRegions(), components))) {
      return failure();
    }

    // Replace the op with another pass-through op with attributes added.
    OperationState state(op->getLoc(), "xla_test.return_type_components",
                         op->getOperands(), op->getResultTypes(),
                         op->getAttrs());
    auto new_op = rewriter.createOperation(state);
    for (auto it : llvm::enumerate(components)) {
      // Attribute names are suffixed with the result index: dims0,
      // element_type0, ... Use Twine consistently for both names.
      if (it.value().hasRank()) {
        new_op->setAttr((Twine("dims") + Twine(it.index())).str(),
                        rewriter.getI64ArrayAttr(it.value().getDims()));
      }
      if (it.value().getElementType()) {
        new_op->setAttr((Twine("element_type") + Twine(it.index())).str(),
                        TypeAttr::get(it.value().getElementType()));
      }
    }
    rewriter.replaceOp(op, {new_op->getResults()});
    return success();
  }
};
// Replaces "xla_test.reify_return_type_shapes"(%v) with the shape values
// produced by %v's defining op via
// InferShapedTypeOpInterface::reifyReturnTypeShapes.
struct ReifyReturnTypeShapesPattern : public RewritePattern {
  ReifyReturnTypeShapesPattern(MLIRContext *context)
      : RewritePattern("xla_test.reify_return_type_shapes", 1, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1) return failure();
    // The operand's producer must implement the interface (dyn_cast_or_null
    // also guards against block arguments, whose defining op is null).
    Operation *producer = op->getOperand(0).getDefiningOp();
    auto iface = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(producer);
    if (!iface) return failure();
    SmallVector<Value, 4> return_shapes;
    if (failed(iface.reifyReturnTypeShapes(rewriter, return_shapes)))
      return failure();
    rewriter.replaceOp(op, return_shapes);
    return success();
  }
};
// Test pass that greedily applies the two test patterns above over a
// function, turning xla_test probe ops into FileCheck-able results.
struct TestInferShapedTypeMethodsPass
    : public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
  void runOnFunction() override {
    MLIRContext *context = &getContext();
    OwningRewritePatternList patterns;
    patterns.insert<ReifyReturnTypeShapesPattern,
                    InferReturnTypeComponentsPattern>(context);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};
} // namespace
} // namespace xla
} // namespace mlir
static mlir::PassRegistration<mlir::xla::TestInferShapedTypeMethodsPass> pass(
"test-xla-infer-shaped-type-methods",
"Uses test ops to invoke InferShapedTypeOpInterface methods");