From f5d2eabea0eeeec1e80814e8529a676c49c08132 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 24 Apr 2020 15:29:47 -0700 Subject: [PATCH] Implement test pass for checking behavior of InferShapedTypeOpInterface and add tests for CHLO. * Exposed bugs in ops that have a different return vs operand element type. * Existing logic is producing shape components that are more broad than necessary. Left a TODO. PiperOrigin-RevId: 308338596 Change-Id: Iadde9da1052ec368f0dfef5a64c31712fe18f9a6 --- tensorflow/compiler/mlir/xla/BUILD | 2 + tensorflow/compiler/mlir/xla/ir/chlo_ops.cc | 58 +++++++++-- .../tests/chlo_infer_shape_type_methods.mlir | 56 +++++++++++ .../transforms/test_infer_shaped_type_pass.cc | 99 +++++++++++++++++++ 4 files changed, 207 insertions(+), 8 deletions(-) create mode 100644 tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir create mode 100644 tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index a1a70718f27..abb1c338b42 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -511,6 +511,7 @@ cc_library( srcs = [ "transforms/chlo_legalize_to_hlo_pass.cc", "transforms/materialize_broadcasts_pass.cc", + "transforms/test_infer_shaped_type_pass.cc", "transforms/unfuse_batch_norm_pass.cc", ], deps = [ @@ -519,6 +520,7 @@ cc_library( ":xla_materialize_broadcasts", # build-cleaner: keep ":xla_unfuse_batch_norm", # build-cleaner: keep "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc index ba72cd17240..629fdc55ef0 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc @@ -96,7 +96,7 @@ static Type GetBroadcastType(Type x, Type y, Type element_type, LogicalResult InferBroadcastBinaryOpReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, + ArrayRef attributes, Type element_type, SmallVectorImpl& inferedReturnShapes) { // Find broadcast_dimensions. DenseIntElementsAttr broadcast_dimensions; @@ -113,7 +113,7 @@ LogicalResult InferBroadcastBinaryOpReturnTypeComponents( lhs_type.getElementType() != rhs_type.getElementType()) { return emitOptionalError(location, "mismatched operand types"); } - Type element_type = lhs_type.getElementType(); + if (!element_type) element_type = lhs_type.getElementType(); Type result_type = GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions); @@ -159,16 +159,62 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( reifiedReturnShapes.push_back(computed_shape); return success(); } - } // namespace +//===----------------------------------------------------------------------===// +// BroadcastComplexOp (has custom type inference due to different result type). +//===----------------------------------------------------------------------===// + +LogicalResult BroadcastComplexOp::inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + ShapedType lhs_type = operands[0].getType().dyn_cast(); + if (!lhs_type) { + return emitOptionalError(location, "expected ShapedType"); + } + Type element_type = ComplexType::get(lhs_type.getElementType()); + return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, + attributes, element_type, + inferedReturnShapes); +} +LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), + reifiedReturnShapes); +} + +//===----------------------------------------------------------------------===// +// BroadcastCompareOp (has custom type inference due to different result type). +//===----------------------------------------------------------------------===// + +LogicalResult BroadcastCompareOp::inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + Type element_type = IntegerType::get(1, context); + return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, + attributes, element_type, + inferedReturnShapes); +} +LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), + reifiedReturnShapes); +} + +//===----------------------------------------------------------------------===// +// Macros for method definitions that are common to most broadcasting ops. +//===----------------------------------------------------------------------===// + #define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \ LogicalResult Op::inferReturnTypeComponents( \ MLIRContext* context, Optional location, ValueRange operands, \ ArrayRef attributes, RegionRange regions, \ SmallVectorImpl& inferedReturnShapes) { \ return InferBroadcastBinaryOpReturnTypeComponents( \ - context, location, operands, attributes, inferedReturnShapes); \ + context, location, operands, attributes, /*element_type=*/nullptr, \ + inferedReturnShapes); \ } \ LogicalResult Op::reifyReturnTypeShapes( \ OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { \ @@ -203,10 +249,6 @@ BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp); BROADCAST_BINARY_OP_DEFS(BroadcastSubOp); BROADCAST_BINARY_OP_DEFS(BroadcastXorOp); -// These only have the common shape inference defs but non-standard builders. -BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(BroadcastCompareOp); -BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(BroadcastComplexOp); - #undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS #undef BROADCAST_BINARY_OP_DEFS diff --git a/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir new file mode 100644 index 00000000000..ce0243e416c --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir @@ -0,0 +1,56 @@ +// RUN: xla-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck --dump-input=fail %s + +// CHECK-LABEL: @broadcast_add +// Note that all broadcast_ops are expanded from the same template, so +// only test reification on an examplar op. +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor +func @broadcast_add(%arg0: tensor, %arg1: tensor) -> tensor<1xindex> { + // CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]]) + // CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]]) + // CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK: %[[EXTENTS:.+]] = "shape.to_extent_tensor"(%[[BCAST_S]]) + // CHECK: return %[[EXTENTS]] + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + %1 = "xla_test.reify_return_type_shapes"(%0) : (tensor) -> tensor<1xindex> + return %1 : tensor<1xindex> +} + +// ----- +// CHECK-LABEL: @complex_ranked_components +func @complex_ranked_components(%arg0: tensor, %arg1: tensor) -> tensor> { + %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> + // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex} + %1 = "xla_test.get_return_type_components"(%0) : (tensor>) -> tensor> + return %1 : tensor> +} + +// ----- +// CHECK-LABEL: @compare_ranked_components +func @compare_ranked_components(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1} + %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + return %0 : tensor +} + +// ----- +// CHECK-LABEL: @broadcast_add_ranked_components_r1 +func @broadcast_add_ranked_components_r1(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32} + %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// ----- +// CHECK-LABEL: @broadcast_add_ranked_components_r1x2 +func @broadcast_add_ranked_components_r1x2(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + // TODO: Overly broad shapes are being returned. Tighten the calculation + // and update/extend these tests. + // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32} + %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + return %1 : tensor +} + diff --git a/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc b/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc new file mode 100644 index 00000000000..8976bd5b7d2 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace xla { +namespace { + +struct InferReturnTypeComponentsPattern : public RewritePattern { + InferReturnTypeComponentsPattern(MLIRContext *context) + : RewritePattern("xla_test.get_return_type_components", 1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1) return failure(); + auto defining_op = op->getOperand(0).getDefiningOp(); + auto defining_op_int = + llvm::dyn_cast_or_null(defining_op); + if (!defining_op_int) return failure(); + SmallVector components; + if (failed(defining_op_int.inferReturnTypeComponents( + op->getContext(), op->getLoc(), defining_op->getOperands(), + defining_op->getAttrs(), defining_op->getRegions(), components))) { + return failure(); + } + + // Replace the op with another pass-through op with attributes added. + OperationState state(op->getLoc(), "xla_test.return_type_components", + op->getOperands(), op->getResultTypes(), + op->getAttrs()); + auto new_op = rewriter.createOperation(state); + for (auto it : llvm::enumerate(components)) { + if (it.value().hasRank()) { + new_op->setAttr((StringRef("dims") + Twine(it.index())).str(), + rewriter.getI64ArrayAttr(it.value().getDims())); + } + if (it.value().getElementType()) { + new_op->setAttr((Twine("element_type") + Twine(it.index())).str(), + TypeAttr::get(it.value().getElementType())); + } + } + rewriter.replaceOp(op, {new_op->getResults()}); + return success(); + } +}; + +struct ReifyReturnTypeShapesPattern : public RewritePattern { + ReifyReturnTypeShapesPattern(MLIRContext *context) + : RewritePattern("xla_test.reify_return_type_shapes", 1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1) return failure(); + auto defining_op = llvm::dyn_cast_or_null( + op->getOperand(0).getDefiningOp()); + if (!defining_op) return failure(); + SmallVector return_shapes; + if (failed(defining_op.reifyReturnTypeShapes(rewriter, return_shapes))) { + return failure(); + } + rewriter.replaceOp(op, return_shapes); + return success(); + } +}; + +struct TestInferShapedTypeMethodsPass + : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + patterns.insert(&getContext()); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // namespace +} // namespace xla +} // namespace mlir + +static mlir::PassRegistration pass( + "test-xla-infer-shaped-type-methods", + "Uses test ops to invoke InferShapedTypeOpInterface methods");