Added pass xla-legalize-tf-types. This pass converts quantized types to non-quantized (e.g. qint8 to i8).
PiperOrigin-RevId: 358217198 Change-Id: I0815555f5d0f7c00121825fd4bc75e5016b3ce17
This commit is contained in:
parent
bd4af92247
commit
b6d39383aa
@ -307,7 +307,6 @@ void CreateConvertMlirToXlaHloPipeline(
|
||||
// inside PromoteResourcesToArgs.
|
||||
pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
|
||||
|
||||
pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
|
||||
/*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
|
||||
/*tf2xla_fallback_device_type=*/device_type));
|
||||
|
@ -72,7 +72,6 @@ gentbl(
|
||||
cc_library(
|
||||
name = "xla_passes",
|
||||
srcs = [
|
||||
"transforms/legalize_tf_types.cc",
|
||||
"transforms/passes_detail.h",
|
||||
"transforms/prepare_for_export.cc",
|
||||
],
|
||||
@ -82,12 +81,10 @@ cc_library(
|
||||
deps = [
|
||||
":xla_passes_inc_gen",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -112,7 +109,6 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/hlo:convert_op_folder",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:padding",
|
||||
|
@ -1,54 +0,0 @@
|
||||
// RUN: tf-opt "-xla-legalize-tf-types" %s | FILECHECK_OPTS="" FileCheck %s
|
||||
|
||||
func @relu_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
|
||||
// CHECK: func @relu_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
|
||||
// CHECK-NEXT: %[[X:.*]] = "tf.Relu"(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
|
||||
%0 = "tf.Relu"(%arg0) : (tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8>
|
||||
return %0: tensor<1x!tf.qint8>
|
||||
}
|
||||
|
||||
func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1x!tf.qint8>, %arg2: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
|
||||
// CHECK: func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1xi8>
|
||||
// CHECK-NEXT: %0 = "tf.IfRegion"(%arg0) ( {
|
||||
// CHECK-NEXT: "tf.Yield"(%arg1) : (tensor<1xi8>) -> ()
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: "tf.Yield"(%arg2) : (tensor<1xi8>) -> ()
|
||||
// CHECK-NEXT: }) {is_stateless = false} : (tensor<i1>) -> tensor<1xi8>
|
||||
// CHECK-NEXT: return %0 : tensor<1xi8>
|
||||
%0 = "tf.IfRegion"(%arg0) ( {
|
||||
"tf.Yield"(%arg1) : (tensor<1x!tf.qint8>) -> ()
|
||||
}, {
|
||||
"tf.Yield"(%arg2) : (tensor<1x!tf.qint8>) -> ()
|
||||
}) {is_stateless = false} : (tensor<i1>) -> tensor<1x!tf.qint8>
|
||||
return %0 : tensor<1x!tf.qint8>
|
||||
}
|
||||
|
||||
func @id_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
|
||||
// CHECK: func @id_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<1xi8>
|
||||
return %arg0: tensor<1x!tf.qint8>
|
||||
}
|
||||
|
||||
func @id_qint16(%arg0: tensor<1x!tf.qint16>) -> tensor<1x!tf.qint16> {
|
||||
// CHECK: func @id_qint16(%arg0: tensor<1xi16>) -> tensor<1xi16> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<1xi16>
|
||||
return %arg0: tensor<1x!tf.qint16>
|
||||
}
|
||||
|
||||
func @id_qint32(%arg0: tensor<1x!tf.qint32>) -> tensor<1x!tf.qint32> {
|
||||
// CHECK: func @id_qint32(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<1xi32>
|
||||
return %arg0: tensor<1x!tf.qint32>
|
||||
}
|
||||
|
||||
func @id_quint8(%arg0: tensor<1x!tf.quint8>) -> tensor<1x!tf.quint8> {
|
||||
// CHECK: func @id_quint8(%arg0: tensor<1xui8>) -> tensor<1xui8> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<1xui8>
|
||||
return %arg0: tensor<1x!tf.quint8>
|
||||
}
|
||||
|
||||
func @id_quint16(%arg0: tensor<1x!tf.quint16>) -> tensor<1x!tf.quint16> {
|
||||
// CHECK: func @id_quint16(%arg0: tensor<1xui16>) -> tensor<1xui16> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<1xui16>
|
||||
return %arg0: tensor<1x!tf.quint16>
|
||||
}
|
@ -1,185 +0,0 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// The TF dialect uses some TF types that are illegal in the MHLO dialect and
|
||||
// some generic types that are legal in MHLO. This pass legalizes TF types into
|
||||
// types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
|
||||
// Rewrites here should run before TF to MHLO op legalizations are run.
|
||||
// TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
|
||||
// rather than its own pass.
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes_detail.h"
|
||||
|
||||
#define DEBUG_TYPE "xla-legalize-tf-types"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
bool isIllegalElementType(Type type) {
|
||||
return type
|
||||
.isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
|
||||
mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
|
||||
}
|
||||
|
||||
Type replaceElementType(Type type) {
|
||||
return TypeSwitch<Type, Type>(type)
|
||||
.Case<mlir::TF::Qint8Type>([&type](Type) {
|
||||
return mlir::IntegerType::get(type.getContext(), 8);
|
||||
})
|
||||
.Case<mlir::TF::Qint16Type>([&type](Type) {
|
||||
return mlir::IntegerType::get(type.getContext(), 16);
|
||||
})
|
||||
.Case<mlir::TF::Qint32Type>([&type](Type) {
|
||||
return mlir::IntegerType::get(type.getContext(), 32);
|
||||
})
|
||||
.Case<mlir::TF::Quint8Type>([&type](Type) {
|
||||
return mlir::IntegerType::get(
|
||||
type.getContext(), 8,
|
||||
mlir::IntegerType::SignednessSemantics::Unsigned);
|
||||
})
|
||||
.Case<mlir::TF::Quint16Type>([&type](Type) {
|
||||
return mlir::IntegerType::get(
|
||||
type.getContext(), 16,
|
||||
mlir::IntegerType::SignednessSemantics::Unsigned);
|
||||
})
|
||||
.Default([&type](Type) { return type; });
|
||||
}
|
||||
|
||||
// TODO(b/180234863): What's below this line is generic so convert it to a
|
||||
// utility.
|
||||
|
||||
bool isIllegalType(Type type) {
|
||||
if (isIllegalElementType(type)) return true;
|
||||
if (auto shaped = type.dyn_cast<ShapedType>())
|
||||
return isIllegalType(shaped.getElementType());
|
||||
return false;
|
||||
}
|
||||
|
||||
Type replaceType(Type type) {
|
||||
if (isIllegalElementType(type)) return replaceElementType(type);
|
||||
if (auto shaped = type.dyn_cast<ShapedType>()) {
|
||||
Type elem = shaped.getElementType();
|
||||
if (isIllegalType(elem)) return shaped.clone(replaceType(elem));
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
// An Op is illegal iff it contains an illegalType.
|
||||
class TfTypeConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TfTypeConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
|
||||
markUnknownOpDynamicallyLegal();
|
||||
}
|
||||
|
||||
protected:
|
||||
bool isDynamicallyLegal(Operation *op) const override {
|
||||
// The FuncOp type can contain types that the op's operand and result types
|
||||
// do not contain.
|
||||
if (auto func = dyn_cast<FuncOp>(op)) {
|
||||
if (llvm::any_of(func.getType().getInputs(), isIllegalType) ||
|
||||
llvm::any_of(func.getType().getResults(), isIllegalType))
|
||||
return false;
|
||||
}
|
||||
if (llvm::any_of(op->getOperandTypes(), isIllegalType) ||
|
||||
llvm::any_of(op->getResultTypes(), isIllegalType))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class TfTypeConverter : public TypeConverter {
|
||||
public:
|
||||
TfTypeConverter() {
|
||||
addConversion([](Type type) -> Type {
|
||||
if (isIllegalType(type))
|
||||
return replaceType(type);
|
||||
else
|
||||
return type;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
class TfTypePattern : public ConversionPattern {
|
||||
public:
|
||||
TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
|
||||
: ConversionPattern(1, converter, MatchAnyOpTypeTag()) {}
|
||||
|
||||
// The dialect conversion framework will call this matchAndRewrite on each
|
||||
// Operation in the IR tree. This call matchAndRewrite needs to update the
|
||||
// Operation's results and child regions.
|
||||
LogicalResult matchAndRewrite(
|
||||
Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Update the results.
|
||||
llvm::SmallVector<Type, 4> new_results;
|
||||
if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
|
||||
new_results)))
|
||||
return failure();
|
||||
|
||||
// Update the regions. The dialect conversion framework wants new regions to
|
||||
// be created and updated, rather than updating the old op. Thus we use an
|
||||
// OperationState so we can add regions to the new up.
|
||||
OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
|
||||
new_results, op->getAttrs(), op->getSuccessors());
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
Region &new_region = *state.addRegion();
|
||||
rewriter.inlineRegionBefore(region, new_region, new_region.begin());
|
||||
if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
|
||||
return failure();
|
||||
}
|
||||
rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct LegalizeTfTypesPass
|
||||
: public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void LegalizeTfTypesPass::runOnOperation() {
|
||||
TfTypeConverter converter;
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<TfTypePattern>(&getContext(), converter);
|
||||
populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
|
||||
TfTypeConversionTarget target(getContext());
|
||||
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeTfTypesPass> registration(
|
||||
"xla-legalize-tf-types",
|
||||
"Replace TensorFlow types with types that are legal in the MHLO dialect");
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
|
||||
return std::make_unique<LegalizeTfTypesPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
@ -49,10 +49,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
|
||||
llvm::StringRef device_type);
|
||||
|
||||
/// Replaces types that do not exist in MHLO with equivalent types that do
|
||||
/// exist.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTfTypesPass();
|
||||
|
||||
/// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list.
|
||||
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
|
||||
OwningRewritePatternList& patterns);
|
||||
|
@ -15,24 +15,6 @@ limitations under the License.
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def LegalizeTfTypesPass : Pass<"xla-legalize-tf-types"> {
|
||||
let summary = "Replace TensorFlow types with types that are legal in the MHLO dialect";
|
||||
|
||||
let description = [{
|
||||
The TF dialect uses some TF types that are illegal in the MHLO dialect and
|
||||
some generic types that are legal in MHLO. This pass legalizes TF types into
|
||||
types that are legal in MHLO. Rewrites here should run before TF to MHLO op
|
||||
legalizations are run.
|
||||
|
||||
Specifically, this pass replaces each quantized integer type with the
|
||||
corresponding ordinary types. For example, `TF::Qint8Type` is replaced with `i8`
|
||||
everywhere it occurs. Types that are replaced are `TF::Qint8Type`,
|
||||
`TF::Qint16Type`, `TF::Qint32Type`, `TF::Quint8Type`, and `TF::Quint16Type`.
|
||||
}];
|
||||
|
||||
let constructor = "::mlir::mhlo::CreateLegalizeTfTypesPass()";
|
||||
}
|
||||
|
||||
def PrepareForExportPass : FunctionPass<"xla-prepare-for-export"> {
|
||||
let summary = "Prepare for XLA export";
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user