From b6d39383aa1a95f377270dea7900b72cd2f98ada Mon Sep 17 00:00:00 2001 From: Yujing Zhang Date: Thu, 18 Feb 2021 10:59:55 -0800 Subject: [PATCH] Added pass xla-legalize-tf-types. This pass converts quantized types to non-quantized (e.g. qint8 to i8). PiperOrigin-RevId: 358217198 Change-Id: I0815555f5d0f7c00121825fd4bc75e5016b3ce17 --- .../tensorflow/utils/compile_mlir_util.cc | 1 - tensorflow/compiler/mlir/xla/BUILD | 4 - .../mlir/xla/tests/legalize-tf-types.mlir | 54 ----- .../mlir/xla/transforms/legalize_tf_types.cc | 185 ------------------ .../compiler/mlir/xla/transforms/passes.h | 4 - .../mlir/xla/transforms/xla_passes.td | 18 -- 6 files changed, 266 deletions(-) delete mode 100644 tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir delete mode 100644 tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index f10aca20b47..59c647e598e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -307,7 +307,6 @@ void CreateConvertMlirToXlaHloPipeline( // inside PromoteResourcesToArgs. pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass()); - pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass()); pm.addNestedPass(mlir::mhlo::createLegalizeTFPass( /*allow_partial_conversion=*/true, /*legalize_chlo=*/true, /*tf2xla_fallback_device_type=*/device_type)); diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 63be2fa8d60..b2d1e15b53c 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -72,7 +72,6 @@ gentbl( cc_library( name = "xla_passes", srcs = [ - "transforms/legalize_tf_types.cc", "transforms/passes_detail.h", "transforms/prepare_for_export.cc", ], @@ -82,12 +81,10 @@ cc_library( deps = [ ":xla_passes_inc_gen", "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", ], alwayslink = 1, ) @@ -112,7 +109,6 @@ cc_library( "//tensorflow/compiler/mlir/hlo:convert_op_folder", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:padding", diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir deleted file mode 100644 index 56d903be892..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir +++ /dev/null @@ -1,54 +0,0 @@ -// RUN: tf-opt "-xla-legalize-tf-types" %s | FILECHECK_OPTS="" FileCheck %s - -func @relu_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> { - // CHECK: func @relu_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> { - // CHECK-NEXT: %[[X:.*]] = "tf.Relu"(%arg0) : (tensor<1xi8>) -> tensor<1xi8> - %0 = "tf.Relu"(%arg0) : (tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> - return %0: tensor<1x!tf.qint8> -} - -func @if_qint8(%arg0: tensor, %arg1: tensor<1x!tf.qint8>, %arg2: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> { - // CHECK: func @if_qint8(%arg0: tensor, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1xi8> - // CHECK-NEXT: %0 = "tf.IfRegion"(%arg0) ( { - // CHECK-NEXT: "tf.Yield"(%arg1) : (tensor<1xi8>) -> () - // CHECK-NEXT: }, { - // CHECK-NEXT: "tf.Yield"(%arg2) : (tensor<1xi8>) -> () - // CHECK-NEXT: }) {is_stateless = false} : (tensor) -> tensor<1xi8> - // CHECK-NEXT: return %0 : tensor<1xi8> - %0 = "tf.IfRegion"(%arg0) ( { - "tf.Yield"(%arg1) : (tensor<1x!tf.qint8>) -> () - }, { - "tf.Yield"(%arg2) : (tensor<1x!tf.qint8>) -> () - }) {is_stateless = false} : (tensor) -> tensor<1x!tf.qint8> - return %0 : tensor<1x!tf.qint8> -} - -func @id_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> { - // CHECK: func @id_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> { - // CHECK-NEXT: return %arg0 : tensor<1xi8> - return %arg0: tensor<1x!tf.qint8> -} - -func @id_qint16(%arg0: tensor<1x!tf.qint16>) -> tensor<1x!tf.qint16> { - // CHECK: func @id_qint16(%arg0: tensor<1xi16>) -> tensor<1xi16> { - // CHECK-NEXT: return %arg0 : tensor<1xi16> - return %arg0: tensor<1x!tf.qint16> -} - -func @id_qint32(%arg0: tensor<1x!tf.qint32>) -> tensor<1x!tf.qint32> { - // CHECK: func @id_qint32(%arg0: tensor<1xi32>) -> tensor<1xi32> { - // CHECK-NEXT: return %arg0 : tensor<1xi32> - return %arg0: tensor<1x!tf.qint32> -} - -func @id_quint8(%arg0: tensor<1x!tf.quint8>) -> tensor<1x!tf.quint8> { - // CHECK: func @id_quint8(%arg0: tensor<1xui8>) -> tensor<1xui8> { - // CHECK-NEXT: return %arg0 : tensor<1xui8> - return %arg0: tensor<1x!tf.quint8> -} - -func @id_quint16(%arg0: tensor<1x!tf.quint16>) -> tensor<1x!tf.quint16> { - // CHECK: func @id_quint16(%arg0: tensor<1xui16>) -> tensor<1xui16> { - // CHECK-NEXT: return %arg0 : tensor<1xui16> - return %arg0: tensor<1x!tf.quint16> -} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc deleted file mode 100644 index c1ce7d4aab9..00000000000 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// The TF dialect uses some TF types that are illegal in the MHLO dialect and -// some generic types that are legal in MHLO. This pass legalizes TF types into -// types that are legal in MHLO. For example, TF::Qint8Type is converted to i8. -// Rewrites here should run before TF to MHLO op legalizations are run. -// TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass -// rather than its own pass. - -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/TypeSwitch.h" -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes_detail.h" - -#define DEBUG_TYPE "xla-legalize-tf-types" - -namespace mlir { -namespace mhlo { -namespace { - -bool isIllegalElementType(Type type) { - return type - .isa(); -} - -Type replaceElementType(Type type) { - return TypeSwitch(type) - .Case([&type](Type) { - return mlir::IntegerType::get(type.getContext(), 8); - }) - .Case([&type](Type) { - return mlir::IntegerType::get(type.getContext(), 16); - }) - .Case([&type](Type) { - return mlir::IntegerType::get(type.getContext(), 32); - }) - .Case([&type](Type) { - return mlir::IntegerType::get( - type.getContext(), 8, - mlir::IntegerType::SignednessSemantics::Unsigned); - }) - .Case([&type](Type) { - return mlir::IntegerType::get( - type.getContext(), 16, - mlir::IntegerType::SignednessSemantics::Unsigned); - }) - .Default([&type](Type) { return type; }); -} - -// TODO(b/180234863): What's below this line is generic so convert it to a -// utility. - -bool isIllegalType(Type type) { - if (isIllegalElementType(type)) return true; - if (auto shaped = type.dyn_cast()) - return isIllegalType(shaped.getElementType()); - return false; -} - -Type replaceType(Type type) { - if (isIllegalElementType(type)) return replaceElementType(type); - if (auto shaped = type.dyn_cast()) { - Type elem = shaped.getElementType(); - if (isIllegalType(elem)) return shaped.clone(replaceType(elem)); - } - return type; -} - -// An Op is illegal iff it contains an illegalType. -class TfTypeConversionTarget : public ConversionTarget { - public: - explicit TfTypeConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { - markUnknownOpDynamicallyLegal(); - } - - protected: - bool isDynamicallyLegal(Operation *op) const override { - // The FuncOp type can contain types that the op's operand and result types - // do not contain. - if (auto func = dyn_cast(op)) { - if (llvm::any_of(func.getType().getInputs(), isIllegalType) || - llvm::any_of(func.getType().getResults(), isIllegalType)) - return false; - } - if (llvm::any_of(op->getOperandTypes(), isIllegalType) || - llvm::any_of(op->getResultTypes(), isIllegalType)) - return false; - return true; - } -}; - -class TfTypeConverter : public TypeConverter { - public: - TfTypeConverter() { - addConversion([](Type type) -> Type { - if (isIllegalType(type)) - return replaceType(type); - else - return type; - }); - } -}; - -class TfTypePattern : public ConversionPattern { - public: - TfTypePattern(MLIRContext *ctx, TypeConverter &converter) - : ConversionPattern(1, converter, MatchAnyOpTypeTag()) {} - - // The dialect conversion framework will call this matchAndRewrite on each - // Operation in the IR tree. This call matchAndRewrite needs to update the - // Operation's results and child regions. - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - // Update the results. - llvm::SmallVector new_results; - if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), - new_results))) - return failure(); - - // Update the regions. The dialect conversion framework wants new regions to - // be created and updated, rather than updating the old op. Thus we use an - // OperationState so we can add regions to the new up. - OperationState state(op->getLoc(), op->getName().getStringRef(), operands, - new_results, op->getAttrs(), op->getSuccessors()); - for (Region ®ion : op->getRegions()) { - Region &new_region = *state.addRegion(); - rewriter.inlineRegionBefore(region, new_region, new_region.begin()); - if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter()))) - return failure(); - } - rewriter.replaceOp(op, rewriter.createOperation(state)->getResults()); - - return success(); - } -}; - -struct LegalizeTfTypesPass - : public LegalizeTfTypesPassBase { - void runOnOperation() override; -}; - -void LegalizeTfTypesPass::runOnOperation() { - TfTypeConverter converter; - OwningRewritePatternList patterns; - patterns.insert(&getContext(), converter); - populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); - TfTypeConversionTarget target(getContext()); - if (failed(applyFullConversion(getOperation(), target, std::move(patterns)))) - return signalPassFailure(); -} - -static PassRegistration registration( - "xla-legalize-tf-types", - "Replace TensorFlow types with types that are legal in the MHLO dialect"); - -} // namespace - -std::unique_ptr> CreateLegalizeTfTypesPass() { - return std::make_unique(); -} - -} // namespace mhlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 77ba879e000..b5398f15089 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -49,10 +49,6 @@ std::unique_ptr> createLegalizeTFPass( std::unique_ptr> createLegalizeTfWithTf2XlaPass( llvm::StringRef device_type); -/// Replaces types that do not exist in MHLO with equivalent types that do -/// exist. -std::unique_ptr> CreateLegalizeTfTypesPass(); - /// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list. void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type, OwningRewritePatternList& patterns); diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_passes.td b/tensorflow/compiler/mlir/xla/transforms/xla_passes.td index e93634b63b9..602740cbfb9 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_passes.td +++ b/tensorflow/compiler/mlir/xla/transforms/xla_passes.td @@ -15,24 +15,6 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def LegalizeTfTypesPass : Pass<"xla-legalize-tf-types"> { - let summary = "Replace TensorFlow types with types that are legal in the MHLO dialect"; - - let description = [{ -The TF dialect uses some TF types that are illegal in the MHLO dialect and -some generic types that are legal in MHLO. This pass legalizes TF types into -types that are legal in MHLO. Rewrites here should run before TF to MHLO op -legalizations are run. - -Specifically, this pass replaces each quantized integer type with the -corresponding ordinary types. For example, `TF::Qint8Type` is replaced with `i8` -everywhere it occurs. Types that are replaced are `TF::Qint8Type`, -`TF::Qint16Type`, `TF::Qint32Type`, `TF::Quint8Type`, and `TF::Quint16Type`. - }]; - - let constructor = "::mlir::mhlo::CreateLegalizeTfTypesPass()"; -} - def PrepareForExportPass : FunctionPass<"xla-prepare-for-export"> { let summary = "Prepare for XLA export";