Added pass xla-legalize-tf-types. This pass converts quantized types to non-quantized (e.g. qint8 to i8).

PiperOrigin-RevId: 358217198
Change-Id: I0815555f5d0f7c00121825fd4bc75e5016b3ce17
This commit is contained in:
Yujing Zhang 2021-02-18 10:59:55 -08:00 committed by TensorFlower Gardener
parent bd4af92247
commit b6d39383aa
6 changed files with 0 additions and 266 deletions

View File

@ -307,7 +307,6 @@ void CreateConvertMlirToXlaHloPipeline(
// inside PromoteResourcesToArgs.
pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
/*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
/*tf2xla_fallback_device_type=*/device_type));

View File

@ -72,7 +72,6 @@ gentbl(
cc_library(
name = "xla_passes",
srcs = [
"transforms/legalize_tf_types.cc",
"transforms/passes_detail.h",
"transforms/prepare_for_export.cc",
],
@ -82,12 +81,10 @@ cc_library(
deps = [
":xla_passes_inc_gen",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
@ -112,7 +109,6 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:convert_op_folder",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:padding",

View File

@ -1,54 +0,0 @@
// RUN: tf-opt "-xla-legalize-tf-types" %s | FILECHECK_OPTS="" FileCheck %s
// A quantized element type in an ordinary op's operand/result types must be
// rewritten to the matching builtin integer (qint8 -> i8).
func @relu_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
// CHECK: func @relu_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
// CHECK-NEXT: %[[X:.*]] = "tf.Relu"(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
%0 = "tf.Relu"(%arg0) : (tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8>
return %0: tensor<1x!tf.qint8>
}
// Quantized types must also be rewritten inside control-flow regions: the
// tf.IfRegion result and the types yielded from both branches change to i8.
func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1x!tf.qint8>, %arg2: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
// CHECK: func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1xi8>
// CHECK-NEXT: %0 = "tf.IfRegion"(%arg0) ( {
// CHECK-NEXT: "tf.Yield"(%arg1) : (tensor<1xi8>) -> ()
// CHECK-NEXT: }, {
// CHECK-NEXT: "tf.Yield"(%arg2) : (tensor<1xi8>) -> ()
// CHECK-NEXT: }) {is_stateless = false} : (tensor<i1>) -> tensor<1xi8>
// CHECK-NEXT: return %0 : tensor<1xi8>
%0 = "tf.IfRegion"(%arg0) ( {
"tf.Yield"(%arg1) : (tensor<1x!tf.qint8>) -> ()
}, {
"tf.Yield"(%arg2) : (tensor<1x!tf.qint8>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<1x!tf.qint8>
return %0 : tensor<1x!tf.qint8>
}
// Identity functions exercising the full mapping of every quantized type:
// qint8 -> i8, qint16 -> i16, qint32 -> i32, quint8 -> ui8, quint16 -> ui16.
// These cover function-signature rewriting even when no op uses the value.
func @id_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
// CHECK: func @id_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
// CHECK-NEXT: return %arg0 : tensor<1xi8>
return %arg0: tensor<1x!tf.qint8>
}
func @id_qint16(%arg0: tensor<1x!tf.qint16>) -> tensor<1x!tf.qint16> {
// CHECK: func @id_qint16(%arg0: tensor<1xi16>) -> tensor<1xi16> {
// CHECK-NEXT: return %arg0 : tensor<1xi16>
return %arg0: tensor<1x!tf.qint16>
}
func @id_qint32(%arg0: tensor<1x!tf.qint32>) -> tensor<1x!tf.qint32> {
// CHECK: func @id_qint32(%arg0: tensor<1xi32>) -> tensor<1xi32> {
// CHECK-NEXT: return %arg0 : tensor<1xi32>
return %arg0: tensor<1x!tf.qint32>
}
func @id_quint8(%arg0: tensor<1x!tf.quint8>) -> tensor<1x!tf.quint8> {
// CHECK: func @id_quint8(%arg0: tensor<1xui8>) -> tensor<1xui8> {
// CHECK-NEXT: return %arg0 : tensor<1xui8>
return %arg0: tensor<1x!tf.quint8>
}
func @id_quint16(%arg0: tensor<1x!tf.quint16>) -> tensor<1x!tf.quint16> {
// CHECK: func @id_quint16(%arg0: tensor<1xui16>) -> tensor<1xui16> {
// CHECK-NEXT: return %arg0 : tensor<1xui16>
return %arg0: tensor<1x!tf.quint16>
}

View File

@ -1,185 +0,0 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// The TF dialect uses some TF types that are illegal in the MHLO dialect and
// some generic types that are legal in MHLO. This pass legalizes TF types into
// types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
// Rewrites here should run before TF to MHLO op legalizations are run.
// TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
// rather than its own pass.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes_detail.h"
#define DEBUG_TYPE "xla-legalize-tf-types"
namespace mlir {
namespace mhlo {
namespace {
// Returns true if `type` is one of the TF quantized-integer element types
// (qint8/16/32, quint8/16) that the MHLO dialect cannot represent.
bool isIllegalElementType(Type type) {
  return type.isa<mlir::TF::Qint8Type>() || type.isa<mlir::TF::Qint16Type>() ||
         type.isa<mlir::TF::Qint32Type>() || type.isa<mlir::TF::Quint8Type>() ||
         type.isa<mlir::TF::Quint16Type>();
}
// Maps each TF quantized-integer element type to the builtin integer type of
// the same width: qintN -> iN (signless) and quintN -> uiN (unsigned). Any
// other type is returned unchanged.
Type replaceElementType(Type type) {
  MLIRContext *context = type.getContext();
  if (type.isa<mlir::TF::Qint8Type>())
    return mlir::IntegerType::get(context, 8);
  if (type.isa<mlir::TF::Qint16Type>())
    return mlir::IntegerType::get(context, 16);
  if (type.isa<mlir::TF::Qint32Type>())
    return mlir::IntegerType::get(context, 32);
  if (type.isa<mlir::TF::Quint8Type>())
    return mlir::IntegerType::get(
        context, 8, mlir::IntegerType::SignednessSemantics::Unsigned);
  if (type.isa<mlir::TF::Quint16Type>())
    return mlir::IntegerType::get(
        context, 16, mlir::IntegerType::SignednessSemantics::Unsigned);
  return type;
}
// TODO(b/180234863): What's below this line is generic so convert it to a
// utility.

// Returns true if `type` is an illegal quantized type itself, or is a shaped
// type (e.g. a tensor) whose element type is illegal.
bool isIllegalType(Type type) {
  if (auto shaped = type.dyn_cast<ShapedType>())
    return isIllegalType(shaped.getElementType());
  return isIllegalElementType(type);
}
// Rewrites `type` into its MHLO-legal counterpart: an illegal element type is
// swapped for the matching builtin integer, and a shaped type is cloned with
// its element type rewritten. Legal types pass through unchanged.
Type replaceType(Type type) {
  if (auto shaped = type.dyn_cast<ShapedType>()) {
    Type element = shaped.getElementType();
    if (!isIllegalType(element)) return type;
    return shaped.clone(replaceType(element));
  }
  return isIllegalElementType(type) ? replaceElementType(type) : type;
}
// Conversion target that treats an op as illegal iff any type it mentions --
// its operands, its results, or (for functions) its signature -- is an
// illegal TF quantized type.
class TfTypeConversionTarget : public ConversionTarget {
 public:
  explicit TfTypeConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
    markUnknownOpDynamicallyLegal();
  }

 protected:
  bool isDynamicallyLegal(Operation *op) const override {
    // A FuncOp's signature can mention types that appear in neither its
    // operands nor its results, so inspect it explicitly.
    if (auto func = dyn_cast<FuncOp>(op)) {
      if (llvm::any_of(func.getType().getInputs(), isIllegalType))
        return false;
      if (llvm::any_of(func.getType().getResults(), isIllegalType))
        return false;
    }
    if (llvm::any_of(op->getOperandTypes(), isIllegalType)) return false;
    return !llvm::any_of(op->getResultTypes(), isIllegalType);
  }
};
// Type converter that lowers TF quantized types (standalone or as tensor
// element types) to builtin integer types and leaves all other types alone.
class TfTypeConverter : public TypeConverter {
 public:
  TfTypeConverter() {
    addConversion([](Type type) -> Type {
      return isIllegalType(type) ? replaceType(type) : type;
    });
  }
};
// Generic conversion pattern: matches every operation (benefit 1 with
// MatchAnyOpTypeTag) and rebuilds it with result and region types run through
// the type converter. Operand types are handled by the framework via the
// already-converted `operands`.
class TfTypePattern : public ConversionPattern {
public:
TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
: ConversionPattern(1, converter, MatchAnyOpTypeTag()) {}
// The dialect conversion framework will call this matchAndRewrite on each
// Operation in the IR tree. This call matchAndRewrite needs to update the
// Operation's results and child regions.
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Update the results.
llvm::SmallVector<Type, 4> new_results;
if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
new_results)))
return failure();
// Update the regions. The dialect conversion framework wants new regions to
// be created and updated, rather than updating the old op. Thus we use an
// OperationState so we can add regions to the new up.
OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
new_results, op->getAttrs(), op->getSuccessors());
for (Region &region : op->getRegions()) {
// Move the old region's blocks into a fresh region on the new op, then
// convert the block argument types in place.
Region &new_region = *state.addRegion();
rewriter.inlineRegionBefore(region, new_region, new_region.begin());
if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
return failure();
}
// Create the replacement op and forward its results to the old op's users.
rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());
return success();
}
};
// Pass that rewrites TF quantized types into MHLO-legal builtin integer
// types across the whole module (base class is TableGen-generated from the
// LegalizeTfTypesPass definition in the .td file).
struct LegalizeTfTypesPass
: public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
void runOnOperation() override;
};
void LegalizeTfTypesPass::runOnOperation() {
  // Any op whose operand, result, or (for functions) signature types mention
  // a TF quantized type is illegal and must be rewritten.
  TfTypeConversionTarget target(getContext());

  TfTypeConverter converter;
  OwningRewritePatternList patterns;
  patterns.insert<TfTypePattern>(&getContext(), converter);
  // FuncOp signatures live in a type attribute rather than in results, so add
  // the stock function-type conversion pattern alongside the generic one.
  populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);

  if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
}
// Registers the pass under -xla-legalize-tf-types so it can be invoked from
// tools such as tf-opt.
static PassRegistration<LegalizeTfTypesPass> registration(
"xla-legalize-tf-types",
"Replace TensorFlow types with types that are legal in the MHLO dialect");
} // namespace
// Public factory for the pass, used when building MLIR-to-XLA pipelines.
std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
return std::make_unique<LegalizeTfTypesPass>();
}
} // namespace mhlo
} // namespace mlir

View File

@ -49,10 +49,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
llvm::StringRef device_type);
/// Replaces types that do not exist in MHLO with equivalent types that do
/// exist.
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTfTypesPass();
/// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list.
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
OwningRewritePatternList& patterns);

View File

@ -15,24 +15,6 @@ limitations under the License.
include "mlir/Pass/PassBase.td"
// TableGen declaration of -xla-legalize-tf-types; generates the pass base
// class whose constructor is implemented in C++ by CreateLegalizeTfTypesPass.
def LegalizeTfTypesPass : Pass<"xla-legalize-tf-types"> {
let summary = "Replace TensorFlow types with types that are legal in the MHLO dialect";
let description = [{
The TF dialect uses some TF types that are illegal in the MHLO dialect and
some generic types that are legal in MHLO. This pass legalizes TF types into
types that are legal in MHLO. Rewrites here should run before TF to MHLO op
legalizations are run.
Specifically, this pass replaces each quantized integer type with the
corresponding ordinary types. For example, `TF::Qint8Type` is replaced with `i8`
everywhere it occurs. Types that are replaced are `TF::Qint8Type`,
`TF::Qint16Type`, `TF::Qint32Type`, `TF::Quint8Type`, and `TF::Quint16Type`.
}];
let constructor = "::mlir::mhlo::CreateLegalizeTfTypesPass()";
}
def PrepareForExportPass : FunctionPass<"xla-prepare-for-export"> {
let summary = "Prepare for XLA export";