Merge branch 'master' into amankishore/language-fix

This commit is contained in:
Aman Kishore 2020-12-02 18:40:29 -05:00 committed by GitHub
commit 2e40a04526
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
612 changed files with 29101 additions and 6621 deletions
README.md
tensorflow
BUILD
compiler/mlir
BUILD
hlo
lite
mlir_graph_optimization_pass.h
python
mlir.cc
mlir_wrapper
tensorflow

View File

@ -155,6 +155,7 @@ Container Type | Status | Art
* [DeepLearning.AI TensorFlow Developer Professional Certificate](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
* [Intro to TensorFlow for A.I, M.L, and D.L from Coursera](https://www.coursera.org/learn/introduction-tensorflow)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)

View File

@ -728,8 +728,8 @@ tf_cc_shared_object(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
"//tensorflow/c:kernels_hdrs",
"//tensorflow/c/experimental/stream_executor:stream_executor",
"//tensorflow/c:kernels",
"//tensorflow/c:logging",
"//tensorflow/c:ops_hdrs",
"//tensorflow/cc/saved_model:loader_lite_impl",

View File

@ -112,6 +112,8 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
"//tensorflow/compiler/mlir/tosa:tf_tosa_passes",
"//tensorflow/compiler/mlir/tosa:tfl_tosa_passes",
],
)

View File

@ -45,6 +45,7 @@ filegroup(
"include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
@ -170,6 +171,21 @@ gentbl(
td_srcs = [":hlo_ops_td_files"],
)
gentbl(
name = "hlo_ops_base_enums_inc_gen",
compatible_with = get_compatible_with_cloud(),
tbl_outs = [
("-gen-enum-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc"),
("-gen-enum-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.cc.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td",
td_relative_includes = [
"include",
],
td_srcs = [":hlo_ops_td_files"],
)
gentbl(
name = "hlo_ops_pattern_gen",
compatible_with = get_compatible_with_cloud(),
@ -230,6 +246,22 @@ gentbl(
],
)
gentbl(
name = "lhlo_gpu_ops_enums_inc_gen",
compatible_with = get_compatible_with_cloud(),
strip_include_prefix = "include",
tbl_outs = [
("-gen-enum-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h.inc"),
("-gen-enum-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td",
td_relative_includes = [
"include",
],
td_srcs = [":hlo_ops_td_files"],
)
cc_library(
name = "lhlo_gpu_ops_structs",
srcs = [
@ -248,6 +280,23 @@ cc_library(
],
)
cc_library(
name = "lhlo_gpu_ops_enums",
srcs = [
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc.inc",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h.inc",
"lib/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc",
],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h",
],
includes = ["include"],
deps = [
":lhlo_gpu_ops_enums_inc_gen",
"@llvm-project//llvm:Support",
],
)
gentbl(
name = "lhlo_gpu_ops_inc_gen",
compatible_with = get_compatible_with_cloud(),
@ -265,6 +314,7 @@ gentbl(
":hlo_ops_td_files",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td",
],
)
@ -342,6 +392,22 @@ cc_library(
],
)
cc_library(
name = "hlo_ops_base_enums",
srcs = [
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc",
"lib/Dialect/mhlo/IR/hlo_ops_base_enums.cc",
],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h",
],
includes = ["include"],
deps = [
":hlo_ops_base_enums_inc_gen",
"@llvm-project//llvm:Support",
],
)
cc_library(
name = "convert_op_folder",
srcs = ["lib/utils/convert_op_folder.cc"],
@ -374,6 +440,7 @@ cc_library(
":canonicalize_inc_gen",
":chlo_ops_inc_gen",
":convert_op_folder",
":hlo_ops_base_enums",
":hlo_ops_base_inc_gen",
":hlo_ops_base_structs",
":hlo_ops_inc_gen",
@ -405,6 +472,7 @@ cc_library(
],
includes = ["include"],
deps = [
":hlo_ops_base_enums",
":hlo_ops_base_inc_gen",
":hlo_ops_base_structs",
":lhlo_ops_inc_gen",
@ -436,8 +504,10 @@ cc_library(
includes = ["include"],
deps = [
":hlo",
":hlo_ops_base_enums",
":hlo_ops_base_structs",
":infer_fusibility_op_interface",
":lhlo_gpu_ops_enums",
":lhlo_gpu_ops_inc_gen",
":lhlo_gpu_ops_structs",
"@llvm-project//llvm:Support",

View File

@ -32,6 +32,8 @@ mlir_tablegen(hlo_ops.h.inc -gen-op-decls)
mlir_tablegen(hlo_ops.cc.inc -gen-op-defs)
mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs)
mlir_tablegen(hlo_ops_base_enums.h.inc -gen-enum-decls)
mlir_tablegen(hlo_ops_base_enums.cc.inc -gen-enum-defs)
add_public_tablegen_target(MLIRhlo_opsIncGen)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td)
@ -40,6 +42,9 @@ mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td)
mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_enums.td)
mlir_tablegen(lhlo_gpu_ops_enums.h.inc -gen-enum-decls)
mlir_tablegen(lhlo_gpu_ops_enums.cc.inc -gen-enum-defs)
add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen)
add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen)

View File

@ -34,6 +34,7 @@ limitations under the License.
// clang-format off
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
// clang-format on

View File

@ -41,7 +41,9 @@ def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">;
def HLO_CUSTOM_FUSION : StrEnumAttrCase<"kCustom">;
def HLO_FusionKindAttr : StrEnumAttr<"FusionKind", "fusion kind", [
HLO_LOOP_FUSION, HLO_INPUT_FUSION, HLO_OUTPUT_FUSION, HLO_CUSTOM_FUSION
]>;
]> {
let cppNamespace = "::mlir::mhlo";
}
//===----------------------------------------------------------------------===//
// MHLO nullary op definitions.
@ -896,7 +898,7 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
(ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs),
ConvolutionAttributes<HLO_Dialect>.attributes);
ConvolutionAttributes.attributes);
let results = (outs HLO_Tensor);
}

View File

@ -23,6 +23,7 @@ def HLO_Dialect : Dialect {
let cppNamespace = "::mlir::mhlo";
}
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
@ -692,77 +693,7 @@ class BASE_HLO_TupleOp {
}];
}
//===----------------------------------------------------------------------===//
// Precision Config enum definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA PrecisionConfig proto enum.
def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">;
def HLO_PRECISION_HIGH : StrEnumAttrCase<"HIGH">;
def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">;
def HLO_PrecisionAttr : StrEnumAttr<"Precision",
"XLA precision for an operand. Has backend specific meaning.",
[HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]>;
// TODO(b/129153247) See if it's possible to also validate the size.
def HLO_PrecisionConfigAttr:
OptionalAttr<
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
//===----------------------------------------------------------------------===//
// Fast Fourier Transform Type enum definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA FftType proto enum.
def HLO_FFT_TYPE_FFT : StrEnumAttrCase<"FFT">;
def HLO_FFT_TYPE_IFFT : StrEnumAttrCase<"IFFT">;
def HLO_FFT_TYPE_RFFT : StrEnumAttrCase<"RFFT">;
def HLO_FFT_TYPE_IRFFT : StrEnumAttrCase<"IRFFT">;
def HLO_FftTypeAttr : StrEnumAttr<"FftType",
"XLA fast fourier transform type.",
[HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT,
HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]>;
//===----------------------------------------------------------------------===//
// Comparison op definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA ComparisonDirection enum.
def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">;
def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">;
def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">;
def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">;
def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">;
def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">;
def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
"Which comparison operation to perform.",
[
HLO_COMPARISON_DIRECTION_EQ,
HLO_COMPARISON_DIRECTION_NE,
HLO_COMPARISON_DIRECTION_GE,
HLO_COMPARISON_DIRECTION_GT,
HLO_COMPARISON_DIRECTION_LE,
HLO_COMPARISON_DIRECTION_LT
]>;
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
"Which comparison type to use.",
[
HLO_COMPARISON_TYPE_FLOAT,
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
HLO_COMPARISON_TYPE_SIGNED,
HLO_COMPARISON_TYPE_UNSIGNED
]>;
class BASE_HLO_CompareOp {
@ -783,13 +714,6 @@ class BASE_HLO_CompareOp {
// Quantize op definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA ComparisonDirection enum.
def HLO_MIN_COMBINED : StrEnumAttrCase<"MIN_COMBINED">;
def HLO_DequantizeModeAttr : StrEnumAttr<"DequantizeMode",
"Dequantization mode. Only MIN_COMBINED is supported.",
[HLO_MIN_COMBINED]>;
class BASE_HLO_DequantizeOp {
string summary = "Dequantize operator";
@ -1029,7 +953,12 @@ class BASE_HLO_ConcatenateOp {
// Common convolution attributes
//===----------------------------------------------------------------------===//
class ConvolutionAttributes<Dialect dialect> {
// TODO(b/129153247) See if it's possible to also validate the size.
def HLO_PrecisionConfigAttr:
OptionalAttr<
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
def ConvolutionAttributes {
dag attributes = (ins
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
@ -1270,21 +1199,6 @@ class BASE_HLO_TransposeOp {
}];
}
// These mirror the XLA Transpose enum in Triangular Solve options.
def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">;
def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">;
def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">;
def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">;
def HLO_TransposeAttr : StrEnumAttr<"Transpose",
"Transpose options",
[
HLO_TRANSPOSE_INVALID,
HLO_NO_TRANSPOSE,
HLO_TRANSPOSE,
HLO_ADJOINT
]>;
class BASE_HLO_TriangularSolveOp {
string summary = "TriangularSolve operator";

View File

@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines enums used in MHLO and LMHLO.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
// Order matters, this .inc header is not self-contained, and relies on the
// #includes above.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_

View File

@ -0,0 +1,119 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef HLO_OPS_BASE_ENUMS
#define HLO_OPS_BASE_ENUMS
//===----------------------------------------------------------------------===//
// Precision Config enum definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA PrecisionConfig proto enum.
def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">;
def HLO_PRECISION_HIGH : StrEnumAttrCase<"HIGH">;
def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">;
def HLO_PrecisionAttr : StrEnumAttr<"Precision",
"XLA precision for an operand. Has backend specific meaning.",
[HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]> {
let cppNamespace = "::mlir::mhlo";
}
//===----------------------------------------------------------------------===//
// Fast Fourier Transform Type enum definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA FftType proto enum.
def HLO_FFT_TYPE_FFT : StrEnumAttrCase<"FFT">;
def HLO_FFT_TYPE_IFFT : StrEnumAttrCase<"IFFT">;
def HLO_FFT_TYPE_RFFT : StrEnumAttrCase<"RFFT">;
def HLO_FFT_TYPE_IRFFT : StrEnumAttrCase<"IRFFT">;
def HLO_FftTypeAttr : StrEnumAttr<"FftType",
"XLA fast fourier transform type.",
[HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT,
HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]> {
let cppNamespace = "::mlir::mhlo";
}
//===----------------------------------------------------------------------===//
// Comparison op definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA ComparisonDirection enum.
def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">;
def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">;
def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">;
def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">;
def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">;
def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">;
def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
"Which comparison operation to perform.",
[
HLO_COMPARISON_DIRECTION_EQ,
HLO_COMPARISON_DIRECTION_NE,
HLO_COMPARISON_DIRECTION_GE,
HLO_COMPARISON_DIRECTION_GT,
HLO_COMPARISON_DIRECTION_LE,
HLO_COMPARISON_DIRECTION_LT
]> {
let cppNamespace = "::mlir::mhlo";
}
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
"Which comparison type to use.",
[
HLO_COMPARISON_TYPE_FLOAT,
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
HLO_COMPARISON_TYPE_SIGNED,
HLO_COMPARISON_TYPE_UNSIGNED
]> {
let cppNamespace = "::mlir::mhlo";
}
// These mirror the XLA Dequantize mode string enum.
def HLO_MIN_COMBINED : StrEnumAttrCase<"MIN_COMBINED">;
def HLO_DequantizeModeAttr : StrEnumAttr<"DequantizeMode",
"Dequantization mode. Only MIN_COMBINED is supported.",
[HLO_MIN_COMBINED]> {
let cppNamespace = "::mlir::mhlo";
}
// These mirror the XLA Transpose enum in Triangular Solve options.
def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">;
def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">;
def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">;
def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">;
def HLO_TransposeAttr : StrEnumAttr<"Transpose",
"Transpose options",
[
HLO_TRANSPOSE_INVALID,
HLO_NO_TRANSPOSE,
HLO_TRANSPOSE,
HLO_ADJOINT
]> {
let cppNamespace = "::mlir::mhlo";
}
#endif // HLO_OPS_BASE_ENUMS

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"

View File

@ -23,9 +23,9 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td"
class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> :
Op<LHLO_GPU_Dialect, mnemonic,
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
@ -92,30 +92,16 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
// LMHLO ops representing convolution library functions.
//===----------------------------------------------------------------------===//
def ActivationModeNone : StrEnumAttrCase<"None">;
def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">;
def ActivationModeTanh : StrEnumAttrCase<"Relu">;
def ActivationModeRelu : StrEnumAttrCase<"Relu">;
def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">;
def ActivationModeReluX : StrEnumAttrCase<"ReluX">;
def ActivationModeBandPass : StrEnumAttrCase<"BandPass">;
def ActivationAttr : StrEnumAttr<"Activation",
"Activation applied with fused convolution",
[ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh,
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
ActivationModeBandPass]>;
def GpuConvolutionAttributes {
dag attributes = !con(
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
ConvolutionAttributes.attributes,
(ins F64Attr:$result_scale),
(ins ConvolutionBackendConfigAttr:$backend_config));
}
def GpuFusedConvolutionAttributes {
dag attributes = !con(
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
ConvolutionAttributes.attributes,
(ins F64Attr:$result_scale,
ActivationAttr:$activation_mode,
F64Attr:$side_input_scale),
@ -179,9 +165,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemRead]>:$output,
DotDimensionNumbers:$dot_dimension_numbers,
F64Attr:$alpha,
F64Attr:$alpha_real,
F64Attr:$alpha_imag,
I64Attr:$batch_size,
I64Attr:$algorithm);
OptionalAttr<I64Attr>:$algorithm);
}
// output = alpha(lhs * rhs) + beta * bias
@ -192,10 +179,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
Arg<LHLO_Buffer, "", [MemRead]>:$output,
DotDimensionNumbers:$dot_dimension_numbers,
F64Attr:$alpha,
F64Attr:$alpha_real,
F64Attr:$alpha_imag,
F64Attr:$beta,
I64Attr:$batch_size,
I64Attr:$algorithm);
OptionalAttr<I64Attr>:$algorithm);
}
def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {

View File

@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ==============================================================================*/
// This file defines enums used in the LMHLO_GPU dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
// Order matters, this .inc header is not self-contained, and relies on the
// #includes above.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LHLO_GPU_OPS_ENUMS
#define LHLO_GPU_OPS_ENUMS
include "mlir/IR/OpBase.td"
def ActivationModeNone : StrEnumAttrCase<"None">;
def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">;
def ActivationModeTanh : StrEnumAttrCase<"Tanh">;
def ActivationModeRelu : StrEnumAttrCase<"Relu">;
def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">;
def ActivationModeReluX : StrEnumAttrCase<"ReluX">;
def ActivationModeBandPass : StrEnumAttrCase<"BandPass">;
def ActivationAttr : StrEnumAttr<"Activation",
"Activation applied with fused convolution",
[ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh,
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
ActivationModeBandPass]> {
let cppNamespace = "::mlir::lmhlo_gpu";
}
#endif // LHLO_GPU_OPS_ENUMS

View File

@ -1,4 +1,3 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -415,7 +415,7 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output),
ConvolutionAttributes<LHLO_Dialect>.attributes);
ConvolutionAttributes.attributes);
}
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
@ -654,7 +654,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
let skipDefaultBuilders = 1;
let builders = [
OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
OpBuilderDAG<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
let extraClassDeclaration = [{

View File

@ -44,6 +44,7 @@ add_mlir_library(MhloInferFusibilityOpInterface
add_mlir_dialect_library(MhloDialect
hlo_ops.cc
hlo_ops_base_structs.cc
hlo_ops_base_enums.cc
DEPENDS
MLIRhlo_opsIncGen
@ -70,6 +71,7 @@ target_link_libraries(LmhloDialect PUBLIC MLIRIR)
add_mlir_dialect_library(LmhloGPUDialect
lhlo_gpu_ops.cc
lhlo_gpu_ops_structs.cc
lhlo_gpu_ops_enums.cc
DEPENDS
MLIRlhlo_gpu_opsIncGen

View File

@ -0,0 +1,18 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.cc.inc"

View File

@ -0,0 +1,18 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc.inc"

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"

View File

@ -16,7 +16,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

View File

@ -29,7 +29,7 @@ limitations under the License.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

View File

@ -19,7 +19,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

View File

@ -29,7 +29,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@ -400,7 +400,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
// Generate a list of nested if/else statements to handle rank
// specializations from 1-6.
// specializations from 1 to `kMaxRankSpecialization`.
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
rewriter, op, greater_rank, 1);
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
@ -419,13 +419,13 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
}
// Fire an assertion if none of the rank specializations applied (one of
// the ranks was greater than 6).
// the ranks was greater than `kMaxRankSpecialization`).
else_builder.create<AssertOp>(
loc,
GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
kMaxRankSpecialization),
"Input for dynamic binary op lowering was of a rank greater than "
"6");
"Input for dynamic binary op lowering was of a rank greater than " +
std::to_string(kMaxRankSpecialization));
// Add the rank 6 specialization to the innermost else block.
createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs,
kMaxRankSpecialization);

View File

@ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
alpha = 0.5,
alpha_real = 0.5,
alpha_imag = 0.0,
batch_size = 1,
algorithm = 0}
: (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
@ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
alpha = 0.5,
alpha_real = 0.5,
alpha_imag = 0.0,
beta = 1.0,
batch_size = 1,
algorithm = 0}

View File

@ -490,6 +490,7 @@ cc_library(
],
hdrs = [
"transforms/passes.h",
"transforms/prepare_quantize_lstm.h",
],
deps = [
"convert_type",

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <cstdarg>
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/lite/core/api/error_reporter.h"
namespace tflite {

View File

@ -46,10 +46,9 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <string>
#include <unordered_set>
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
namespace tflite {

View File

@ -50,11 +50,10 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
@ -510,6 +509,12 @@ Operation* BuildVariableOp(const tflite::TensorT& tensor,
return op.getOperation();
}
auto op = builder.create<tfl::ConstOp>(loc, value);
if (!tensor.quantization->min.empty()) {
if (auto stats_op =
ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
return stats_op;
}
}
return op.getOperation();
}

View File

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
#include "absl/strings/string_view.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
namespace tflite {
// Converts a TFLite flatbuffer stored in `buffer` to a MLIR module

View File

@ -22,10 +22,9 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
namespace tflite {

View File

@ -494,6 +494,8 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
auto result_shape_type = result_type.cast<ShapedType>();
if (!result_shape_type.hasStaticShape()) return {};
if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
SmallVector<APFloat, 16> new_values;
const int num_elements = result_shape_type.getNumElements();
@ -1740,7 +1742,8 @@ static LogicalResult Verify(LSTMOp op) {
op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
// If this lstm has layer normalization, this input value,
// "forget_layer_norm_coefficients" should be a 1D tensor.
if (forget_layer_norm_coefficients.getRank() != 1 ||
if (!forget_layer_norm_coefficients.hasRank() ||
forget_layer_norm_coefficients.getRank() != 1 ||
forget_layer_norm_coefficients.getDimSize(0) != n_cell)
return op.emitOpError(
"coefficient inputs have more than 2 dimensions or "

View File

@ -31,10 +31,9 @@ limitations under the License.
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"

View File

@ -20,8 +20,8 @@ limitations under the License.
#include "llvm/ADT/None.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project

View File

@ -20,8 +20,8 @@ limitations under the License.
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project

View File

@ -19,8 +19,8 @@ limitations under the License.
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project

View File

@ -20,7 +20,7 @@ limitations under the License.
#include <utility>
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/core/public/session.h"

View File

@ -17,9 +17,9 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"

View File

@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
namespace mlir {
namespace TFL {

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project

View File

@ -29,7 +29,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project

View File

@ -32,7 +32,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {

View File

@ -17,9 +17,9 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"

View File

@ -283,6 +283,28 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
func @testAvoidDilatedConvWithExpand(%arg0: tensor<*xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%cst_1 = constant dense<4> : tensor<2x2xi32>
%cst_2 = constant dense<0> : tensor<2x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<*xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
return %5 : tensor<1x128x128xf32>
// CHECK-LABEL: testAvoidDilatedConvWithExpand
// CHECK: "tf.SpaceToBatchND"
// CHECK: "tf.ExpandDims"
// CHECK: "tf.Conv2D"
// CHECK: "tf.Squeeze"
// CHECK: "tf.BatchToSpaceND"
// CHECK: "tf.BiasAdd"
}
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>

View File

@ -1,10 +1,14 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s
// CHECK: effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>
// CHECK: input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>
// CHECK: input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>
// CHECK: input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01:3.200000e+01>>>
// CHECK: input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>
// CHECK-DAG: %[[input_18:.*]] = "quant.stats"({{.*}}) {layerStats = dense<[-8.000000e-01, 1.600000e+00]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
// CHECK-DAG: %[[input_19:.*]] = "quant.stats"({{.*}}) {layerStats = dense<[-2.000000e+00, 4.000000e+00]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK: "tfl.unidirectional_sequence_lstm"({{.*}}, %[[input_18]], %[[input_19]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}})
// CHECK-SAME: effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>
// CHECK-SAME: input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>
// CHECK-SAME: input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>
// CHECK-SAME: input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01:3.200000e+01>>>
// CHECK-SAME: input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>
{
"version": 3,
@ -110,8 +114,8 @@
"name": "input_activation_state18",
"is_variable": true,
"quantization": {
"min": [-0.9],
"max": [0.9]
"min": [-0.8],
"max": [1.6]
}
},
{
@ -119,8 +123,8 @@
"name": "input_cell_state19",
"is_variable": true,
"quantization": {
"min": [-0.8],
"max": [0.8]
"min": [-2.0],
"max": [4.0]
}
},
{

View File

@ -810,6 +810,16 @@ func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4
return %24 : tensor<1x4xf32>
}
// -----
// Coefficient inputs of LSTM op have unknown rank.
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<*xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
// expected-error @+1 {{'tfl.lstm' op coefficient inputs have more than 2 dimensions or don't match the dimension with input operand `input_to_output_weights`.}}
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<*xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}
// -----

View File

@ -343,6 +343,18 @@ func @fuseAddIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf
// CHECK-NEXT: return %[[fc]] : tensor<4x2xf32>
}
// CHECK-LABEL: @doNotFuseAddIntoFollowingFullyConnected
func @doNotFuseAddIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>, %arg1: tensor<*xf32>) -> tensor<4x2xf32> {
%cst1 = constant dense<1.5> : tensor<f32>
%0 = "tfl.add"(%arg0, %cst1) {fused_activation_function = "NONE"} : (tensor<4x2xf32>, tensor<f32>) -> tensor<4x2xf32>
%cst = constant dense<2.0> : tensor<2xf32>
%1 = "tfl.fully_connected"(%0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<*xf32>, tensor<2xf32>) -> tensor<4x2xf32>
return %1 : tensor<4x2xf32>
// CHECK: "tfl.add"
// CHECK: "tfl.fully_connected"
}
// CHECK-LABEL: @fuseMulIntoFollowingFullyConnected
func @fuseMulIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
%cst2 = constant dense<1.5> : tensor<f32>

View File

@ -760,4 +760,13 @@ func @depthwise_conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x
// CHECK: "tf.DepthwiseConv2dNative"
}
// CHECK-LABEL: strided_slice_unranked_input
func @strided_slice_unranked_input(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
%18 = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
%57 = "tf.Const"() {value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
%534 = "tf.StridedSlice"(%arg0, %57, %57, %18) {begin_mask = 11 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 11 : i64, new_axis_mask = 4 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<*xf32>
return %534 : tensor<*xf32>
// CHECK: "tf.StridedSlice"
}
}

View File

@ -165,3 +165,28 @@ func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x
// CHECK-SAME: [[VAL_91]], [[VAL_52]] : tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>
// CHECK: }
// CHECK: }
// -----
// CHECK-LABEL: func @whileDifferentResultShapes
func @whileDifferentResultShapes(%arg0: tensor<i32>) -> tensor<?xf32>
attributes {tf.entry_function = {outputs = "result"}} {
%cst0 = constant dense<5> : tensor<i32> loc("N")
%cst1 = constant dense<3.0> : tensor<1xf32> loc("val")
%0:2 = "tfl.while"(%cst0, %cst1) ( {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
%cst_0 = constant dense<0> : tensor<i32>
%1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
"tfl.yield"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
%1 = "tfl.sub"(%arg2, %arg0) {fused_activation_function = "NONE"} :
(tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
"tfl.yield"(%1, %2) : (tensor<*xi32>, tensor<*xf32>) -> ()
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<?xf32>) loc("WhileOp")
// CHECK: (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<?xf32>, tensor<i32>)
return %0#1 : tensor<?xf32>
}

View File

@ -17,8 +17,7 @@ limitations under the License.
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project

View File

@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/core/public/session.h"

View File

@ -24,10 +24,9 @@ limitations under the License.
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/AsmState.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project

View File

@ -20,8 +20,8 @@ limitations under the License.
#include "absl/types/span.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"

View File

@ -175,6 +175,9 @@ InspectResult InspectWeight(
} else if (auto cst = dyn_cast<QConstOp>(inst)) {
attr = cst.value();
type = cst.getType().cast<ShapedType>();
} else {
result.can_compress = false;
return result;
}
// Currently we only support compressing weights of ops:
@ -222,6 +225,8 @@ std::vector<T> BuildSparsityParameterAttribute(
} else if (auto cst = dyn_cast<QConstOp>(inst)) {
attr = cst.value();
type = cst.getType().cast<ShapedType>();
} else {
assert(false && "Expected a constant-like op");
}
const int dims_count = type.getRank();
std::vector<int> shape(dims_count);

View File

@ -185,7 +185,12 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
stb_op.block_shape(), bts_op.block_shape(), rewriter);
if (!dilations_attr.hasValue()) return failure();
op.setAttr("dilations", dilations_attr.getValue());
if (expand_op) {
if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
return failure();
}
}
// TODO(b/149936532): Check that the input width & height are multiples of
// dilation rate.
@ -234,6 +239,9 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
}
}
// Set dilations
op.setAttr("dilations", dilations_attr.getValue());
if (expand_op) {
// If there is `expand_op`, we need to rewire the inputs to bypass the
// `SpaceToBatch`, `BatchToSpace` and `Pad` op. E.g, turning

View File

@ -35,10 +35,9 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project

View File

@ -351,10 +351,10 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
// to properly broadcast the scalar to `{num_channels}` shape.
// Get the number of channels if possible.
auto filter_type = filter.getType().cast<ShapedType>();
auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
// Filter must be a `2D` tensor with `{num_channels, num_features}`
// shape. The following check is rejecting unknown rank (-1).
if (filter_type.getRank() != 2) {
if (filter_type == nullptr || filter_type.getRank() != 2) {
return failure();
}
int num_channels = filter_type.getShape()[0];

View File

@ -19,8 +19,8 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project

View File

@ -26,11 +26,10 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project

View File

@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
// This transformation pass applies quantization propagation on TFLite dialect.
#include <cmath>
#include <iterator>
#include <string>
@ -30,7 +29,7 @@ limitations under the License.
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
@ -44,9 +43,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/operator_property.h"
// NOLINTNEXTLINE
static llvm::cl::list<std::string> quantize_allowlist(
@ -61,6 +59,12 @@ static llvm::cl::opt<bool> quantize_signed(
llvm::cl::desc("signed inference type. Only used in tests"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> post_training_quantize(
"tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
llvm::cl::desc("enable post training quantization. Only used in tests"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
"tfl-disable-per-channel", llvm::cl::value_desc("bool"),
@ -91,10 +95,9 @@ class PrepareQuantizePass
// Constructor used by the PassRegistration and enforce uint8 quantization.
// This is only used by test.
explicit PrepareQuantizePass() {
if (quantize_signed)
quant_specs_.inference_type = tensorflow::DT_QINT8;
else
quant_specs_.inference_type = tensorflow::DT_QUINT8;
quant_specs_.inference_type =
quantize_signed ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
quant_specs_.post_training_quantization = post_training_quantize;
}
// Constructor used by manually creating the pass.
@ -318,148 +321,9 @@ bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
using PrepareQuantStats =
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
// Calculates the minimum power of two that is not less than the value.
double power_of_two_bound(double value) {
return std::pow(2, std::ceil(std::log2(value)));
}
// Quantize recurrent input of LSTM with 16 bits.
template <typename SourceOp, typename Q, typename DQ>
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
public:
explicit ConvertLstmStatsToQDQs(MLIRContext* context)
: OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter& rewriter) const override {
tflite::optimize::operator_property::OpVariant lstm_variant;
if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
lstm_variant.op_code = tflite::BuiltinOperator_LSTM;
} else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(
op.getOperation())) {
lstm_variant.op_code =
tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
} else {
op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
return failure();
}
lstm_variant.use_projection =
!op.projection_weights().getType().template isa<NoneType>();
lstm_variant.use_peephole =
!op.cell_to_output_weights().getType().template isa<NoneType>();
lstm_variant.use_peephole =
!op.cell_to_output_weights().getType().template isa<NoneType>();
lstm_variant.use_layer_norm =
!op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
auto lstm_property =
tflite::optimize::operator_property::GetOperatorProperty(lstm_variant);
// Same with the ordering of //tensorflow/compiler/mlir/lite/ir/tfl_ops.td
const std::vector<std::string> intermediate_attributes = {
"input_to_input_intermediate", "input_to_forget_intermediate",
"input_to_cell_intermediate", "input_to_output_intermediate",
"effective_hidden_scale_intermediate"};
for (auto& enumerated_intermediates : lstm_property.intermediates) {
int index = enumerated_intermediates.first;
auto& tensor_property = enumerated_intermediates.second;
// intermediate tensors 0, 1, 2, 3 are only used with layer normalization.
if (!lstm_variant.use_layer_norm && index != 4) {
continue;
}
// intermediate tensor 4 is only used with projection.
if (!lstm_variant.use_projection && index == 4) {
continue;
}
TypeAttr attr =
op.template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
if (!attr) {
op.emitError()
<< op.getOperationName()
<< " requires quantization values for intermediate tensor "
<< intermediate_attributes[index];
return failure();
}
auto quantized_type =
QuantizedType::getQuantizedElementType(attr.getValue());
if (!quantized_type) {
op.emitError() << intermediate_attributes[index]
<< " is not quantized.";
return failure();
}
auto calibrated_type =
quantized_type.dyn_cast<quant::CalibratedQuantizedType>();
if (!calibrated_type) {
int num_storage_bits = quantized_type.getStorageTypeIntegralWidth();
if (tensor_property.number_of_bits != num_storage_bits) {
op.emitError() << intermediate_attributes[index]
<< " is expected to be quantized with "
<< tensor_property.number_of_bits << " bits, but got "
<< num_storage_bits << " bits instead.";
return failure();
}
continue; // skip if it is already quantized.
}
quant::UniformQuantizedType qtype;
if (tensor_property.number_of_bits == 8) {
qtype = quant::fakeQuantAttrsToType(
op.getLoc(), tensor_property.number_of_bits,
calibrated_type.getMin(), calibrated_type.getMax(),
/*narrowRange=*/false, calibrated_type.getExpressedType(),
/*isSigned=*/false);
} else if (tensor_property.number_of_bits == 16) {
double max = std::max(std::abs(calibrated_type.getMin()),
std::abs(calibrated_type.getMax()));
qtype = quant::fakeQuantAttrsToType(
op.getLoc(), tensor_property.number_of_bits, -max, max,
/*narrowRange=*/true, calibrated_type.getExpressedType(),
/*isSigned=*/true);
} else {
op.emitError() << "Unsupported quantization bits: "
<< tensor_property.number_of_bits;
return failure();
}
op.setAttr(intermediate_attributes[index],
TypeAttr::get(qtype.castFromExpressedType(
qtype.castToExpressedType(attr.getValue()))));
}
quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
op.input_cell_state().getDefiningOp());
// Recurrent input is be used within an LSTM, and thus should have one use.
if (!stats_op || !stats_op.getResult().hasOneUse()) {
return failure();
}
auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
if (!stats) {
return failure();
}
double max = std::max(
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}))),
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
double bound = power_of_two_bound(max);
Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
// Set flags to 1 for signed type.
quant::QuantizedType quant_type = UniformQuantizedType::getChecked(
quant::QuantizationFlags::Signed,
IntegerType::get(16, expressed.getContext()), expressed,
/*scale=*/bound / 32768.0, /*zeroPoint=*/0, llvm::minIntN(16),
llvm::maxIntN(16), op.getLoc());
rewriter.setInsertionPointAfter(stats_op);
Type result_type = quant_type.castFromExpressedType(stats_op.getType());
auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
return success();
}
};
using PrepareLstmQuantStats =
ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp,
quant::QuantizeCastOp, quant::DequantizeCastOp>;
TFL::ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp,
quant::QuantizeCastOp, quant::DequantizeCastOp>;
void PrepareQuantizePass::runOnFunction() {
FuncOp func = getFunction();
@ -503,7 +367,7 @@ void PrepareQuantizePass::runOnFunction() {
// Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
}
patterns.insert<PrepareLstmQuantStats>(ctx);
patterns.insert<PrepareLstmQuantStats>(ctx, quant_specs_);
applyPatternsAndFoldGreedily(func, std::move(patterns));
SanityCheckAndAdjustment(func);

View File

@ -0,0 +1,199 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Transform pass for LSTMs.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM
#include <algorithm>
#include <cmath>
#include <string>
#include <vector>
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/operator_property.h"
//===----------------------------------------------------------------------===//
// The prepare-quantize Pass for LSTM.
//
namespace mlir {
namespace TFL {
// Returns the smallest power of two that is not less than `value`.
// The result may be a fractional power of two (e.g. 0.5 for inputs in
// (0.25, 0.5]) since the exponent is computed over the reals.
inline double power_of_two_bound(double value) {
  const double exponent = std::ceil(std::log2(value));
  return std::exp2(exponent);
}
namespace operator_property = ::tflite::optimize::operator_property;
// Rewrite pattern that quantizes an LSTM's intermediate tensors and replaces
// the calibration statistics on its recurrent (cell state) input with a
// 16-bit Quantize/Dequantize (QDQ) op pair.
//
// `SourceOp` is expected to be TFL::LSTMOp or
// TFL::UnidirectionalSequenceLSTMOp; `Q`/`DQ` are the quantize/dequantize
// cast ops inserted around the recurrent input.
template <typename SourceOp, typename Q, typename DQ>
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
 public:
  ConvertLstmStatsToQDQs(MLIRContext* context,
                         const QuantizationSpecs& quant_specs)
      // benefit=2 so this LSTM-specific pattern is tried before the generic
      // stats-to-QDQ patterns (which use the default benefit of 1).
      : OpRewritePattern<SourceOp>(context, /*benefit=*/2),
        quant_specs(quant_specs) {}

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter& rewriter) const override {
    // Determine which TFLite LSTM variant this op corresponds to, so the
    // correct operator property table can be consulted below.
    operator_property::OpVariant lstm_variant;
    if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
      lstm_variant.op_code = tflite::BuiltinOperator_LSTM;
    } else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(
                   op.getOperation())) {
      lstm_variant.op_code =
          tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
    } else {
      op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
      return failure();
    }
    // Optional features are detected by whether the corresponding operand is
    // present (non-None type).
    lstm_variant.use_projection =
        !op.projection_weights().getType().template isa<NoneType>();
    lstm_variant.use_peephole =
        !op.cell_to_output_weights().getType().template isa<NoneType>();
    lstm_variant.use_layer_norm =
        !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();

    auto lstm_property = operator_property::GetOperatorProperty(lstm_variant);

    // Same with the ordering of //tensorflow/compiler/mlir/lite/ir/tfl_ops.td
    const std::vector<std::string> intermediate_attributes = {
        "input_to_input_intermediate", "input_to_forget_intermediate",
        "input_to_cell_intermediate", "input_to_output_intermediate",
        "effective_hidden_scale_intermediate"};

    for (auto& enumerated_intermediates : lstm_property.intermediates) {
      int index = enumerated_intermediates.first;
      auto& tensor_property = enumerated_intermediates.second;
      // intermediate tensors 0, 1, 2, 3 are only used with layer
      // normalization.
      if (!lstm_variant.use_layer_norm && index != 4) {
        continue;
      }
      // intermediate tensor 4 is only used with projection.
      if (!lstm_variant.use_projection && index == 4) {
        continue;
      }
      TypeAttr attr =
          op.template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
      if (!attr) {
        op.emitError()
            << op.getOperationName()
            << " requires quantization values for intermediate tensor "
            << intermediate_attributes[index];
        return failure();
      }
      auto quantized_type =
          QuantizedType::getQuantizedElementType(attr.getValue());
      if (!quantized_type) {
        op.emitError() << intermediate_attributes[index]
                       << " is not quantized.";
        return failure();
      }
      auto calibrated_type =
          quantized_type.dyn_cast<quant::CalibratedQuantizedType>();
      if (!calibrated_type) {
        // The intermediate is already a concrete quantized type; only
        // validate that the storage width matches what the operator expects.
        int num_storage_bits = quantized_type.getStorageTypeIntegralWidth();
        if (tensor_property.number_of_bits != num_storage_bits) {
          op.emitError() << intermediate_attributes[index]
                         << " is expected to be quantized with "
                         << tensor_property.number_of_bits << " bits, but got "
                         << num_storage_bits << " bits instead.";
          return failure();
        }
        continue;  // skip if it is already quantized.
      }
      quant::UniformQuantizedType qtype;
      if (tensor_property.number_of_bits == 8) {
        qtype = quant::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits,
            calibrated_type.getMin(), calibrated_type.getMax(),
            /*narrowRange=*/false, calibrated_type.getExpressedType(),
            /*isSigned=*/quant_specs.IsSignedInferenceType());
      } else if (tensor_property.number_of_bits == 16) {
        // 16-bit intermediates are quantized symmetrically around zero using
        // the larger calibration magnitude.
        double max = std::max(std::abs(calibrated_type.getMin()),
                              std::abs(calibrated_type.getMax()));
        qtype = quant::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits, -max, max,
            /*narrowRange=*/true, calibrated_type.getExpressedType(),
            /*isSigned=*/true);
      } else {
        op.emitError() << "Unsupported quantization bits: "
                       << tensor_property.number_of_bits;
        return failure();
      }
      op.setAttr(intermediate_attributes[index],
                 TypeAttr::get(qtype.castFromExpressedType(
                     qtype.castToExpressedType(attr.getValue()))));
    }

    quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
        op.input_cell_state().getDefiningOp());
    // Recurrent input is used within an LSTM, and thus should have one use.
    if (!stats_op || !stats_op.getResult().hasOneUse()) {
      return failure();
    }
    auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
    if (!stats) {
      return failure();
    }

    // Quantize the recurrent input with 16 bits, using a power-of-two scale
    // derived from the larger magnitude of the [min, max] layer statistics.
    double max = std::max(
        std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}))),
        std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
    double bound = power_of_two_bound(max);
    Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
    // Set flags to 1 for signed type.
    quant::QuantizedType quant_type = UniformQuantizedType::getChecked(
        quant::QuantizationFlags::Signed,
        IntegerType::get(16, expressed.getContext()), expressed,
        /*scale=*/bound / 32768.0, /*zeroPoint=*/0, llvm::minIntN(16),
        llvm::maxIntN(16), op.getLoc());

    rewriter.setInsertionPointAfter(stats_op);
    Type result_type = quant_type.castFromExpressedType(stats_op.getType());
    auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
    rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
    return success();
  }

 private:
  QuantizationSpecs quant_specs;
};
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM

View File

@ -44,7 +44,7 @@ limitations under the License.
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
@ -526,9 +526,12 @@ struct ConvertTFStridedSlice : public RewritePattern {
// Insert a new reshape op.
Value original_input = strided_slice_op.input();
// TODO(b/174267775): Make sure that the input type has ranked tensor type.
RankedTensorType original_input_type =
original_input.getType().cast<RankedTensorType>();
original_input.getType().dyn_cast<RankedTensorType>();
if (!original_input_type) {
return failure();
}
const ArrayRef<int64_t> &original_input_shape =
original_input_type.getShape();
SmallVector<int64_t, 4> revised_shape;

View File

@ -22,9 +22,9 @@ limitations under the License.
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project

View File

@ -22,10 +22,9 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project

View File

@ -211,24 +211,25 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
// change, so replace with new while op.
if (extra_operands.empty()) return;
Operation* op = while_op.getOperation();
const int operands_size = while_op.getNumOperands() + extra_operands.size();
SmallVector<Value, 4> operands;
operands.reserve(operands_size);
operands.append(while_op.getOperands().begin(), while_op.getOperands().end());
operands.append(extra_operands.begin(), extra_operands.end());
SmallVector<Type, 4> new_types;
operands.reserve(types.size());
new_types.reserve(operands.size());
auto add_operand = [&](Value v) {
operands.push_back(v);
new_types.push_back(v.getType());
};
for (auto operand : op->getOperands()) add_operand(operand);
for (auto operand : extra_operands) add_operand(operand);
new_types.reserve(operands_size);
new_types.append(while_op.getResultTypes().begin(),
while_op.getResultTypes().end());
for (auto extra_operand : extra_operands)
new_types.push_back(extra_operand.getType());
Operation* new_op = OpBuilder(op).insert(Operation::create(
op->getLoc(), op->getName(), new_types, operands, op->getAttrs(),
/*successors=*/{}, /*numRegions=*/2));
for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
op->replaceAllUsesWith(new_op->getResults().take_front(op->getNumResults()));
op->erase();
auto new_while_op = OpBuilder(while_op).create<WhileOp>(
while_op.getLoc(), new_types, operands, while_op.getAttrs());
new_while_op.cond().takeBody(while_op.cond());
new_while_op.body().takeBody(while_op.body());
while_op.replaceAllUsesWith(
new_while_op.getResults().take_front(while_op.getNumResults()));
while_op.erase();
}
void WhileOutlinePass::runOnOperation() {

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <functional>
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

View File

@ -16,7 +16,7 @@ limitations under the License.
#include <string>
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/InitAllPasses.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

View File

@ -44,6 +44,7 @@ filegroup(
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
],
@ -285,8 +286,10 @@ gentbl(
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tf_device_ops.td",
td_srcs = [
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
"@llvm-project//mlir:SideEffectTdFiles",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
],
test = True,
)
@ -408,6 +411,7 @@ cc_library(
":tensorflow_traits",
":tensorflow_types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
@ -453,6 +457,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
@ -495,6 +500,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
@ -533,6 +539,7 @@ cc_library(
":tensorflow_remaining_ops",
":tensorflow_tfrt_ops",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
@ -640,6 +647,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CallOpInterfacesIncGen",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
@ -815,6 +823,7 @@ cc_library(
deps = [
":tensorflow",
":tensorflow_op_interfaces",
":tensorflow_side_effects",
":tensorflow_types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
@ -856,6 +865,7 @@ cc_library(
"transforms/cluster_outlining.cc",
"transforms/cluster_tf_ops_pass.cc",
"transforms/collection_ops_util.cc",
"transforms/constant_op_device_assignment.cc",
"transforms/contraction_fusion.cc",
"transforms/decompose_resource_ops_pass.cc",
"transforms/device_index_selector.cc",
@ -882,7 +892,6 @@ cc_library(
"transforms/optimize.cc",
"transforms/outside_compiled_to_host_launch.cc",
"transforms/parallel_execute_to_islands.cc",
"transforms/parallelize_embedding_params_ops_pass.cc",
"transforms/promote_resources_to_args.cc",
"transforms/readonly_references_to_resources.cc",
"transforms/region_control_flow_to_functional.cc",
@ -995,6 +1004,7 @@ cc_library(
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Rewrite",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",

View File

@ -21,8 +21,7 @@ limitations under the License.
#include <memory>
#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {

View File

@ -29,7 +29,7 @@ limitations under the License.
#include "mlir/Analysis/CallGraph.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project

View File

@ -30,8 +30,8 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
@ -94,20 +95,37 @@ struct SideEffects {
using ResourceSideEffectsByValue = llvm::SmallDenseMap<Value, SideEffects>;
bool MustExecute(const MemoryEffects::EffectInstance& effect) {
if (llvm::isa<ResourceEffects::TPUEmbedding>(effect.getResource())) {
assert(!effect.getValue() && !effect.getParameters() &&
isa<MemoryEffects::Write>(effect.getEffect()));
return true;
}
return false;
}
// Collects memory side effects for an operation by value (operands and
// results).
ResourceSideEffectsByValue GetResourceInfoForOp(Operation* op) {
ResourceSideEffectsByValue resource_info;
void GetResourceInfoForOp(Operation* op,
ResourceSideEffectsByValue& resource_info,
bool& must_execute) {
auto interface = dyn_cast<MemoryEffectOpInterface>(op);
if (!interface) return resource_info;
if (!interface) return;
llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
interface.getEffects(effects);
for (auto& effect : effects) {
if (MustExecute(effect)) {
must_execute = true;
continue;
}
// TODO(lyandy): Support effects with no value defined.
if (!effect.getValue()) return ResourceSideEffectsByValue();
if (!effect.getValue()) {
resource_info.clear();
must_execute = false;
return;
}
auto it = resource_info.try_emplace(effect.getValue());
auto& side_effect = it.first->getSecond();
auto* resource_effect = effect.getEffect();
@ -120,11 +138,11 @@ ResourceSideEffectsByValue GetResourceInfoForOp(Operation* op) {
} else if (isa<MemoryEffects::Write>(resource_effect)) {
side_effect.write = true;
} else {
return ResourceSideEffectsByValue();
resource_info.clear();
must_execute = false;
return;
}
}
return resource_info;
}
// Checks if a value is a result of `op`.
@ -328,6 +346,7 @@ void SideEffectAnalysisInfo::AnalyzeRegion(
// We explicitly iterates through the regions and blocks, in order to handle
// different nested regions separately.
for (auto& block : *region) {
llvm::SmallPtrSet<Operation*, 8> non_resource_control_predecessors;
for (auto& op : block) {
for (Region& child : op.getRegions()) {
SideEffectAnalysisInfo child_analysis(&child, alias_analysis);
@ -340,10 +359,21 @@ void SideEffectAnalysisInfo::AnalyzeRegion(
// We do not need explicit control edges for declaration ops.
if (OpIsDeclaration(&op, alias_analysis)) continue;
auto resource_op_info = GetResourceInfoForOp(&op);
ResourceSideEffectsByValue resource_op_info;
bool must_execute = false;
GetResourceInfoForOp(&op, resource_op_info, must_execute);
if (resource_op_info.empty() && OpIsKnownToHaveNoSideEffect(&op))
continue;
if (resource_op_info.empty() && must_execute) {
// Add unknown resource ops as predecessors of the op that must execute,
// to guarantee ordering between unknown resource ops.
AddPredecessorsForAccess(kUnknownResourceId, &op, /*read_only=*/false);
non_resource_control_predecessors.insert(&op);
continue;
}
if (IsResourceOpAllocOnly(&op, resource_op_info)) continue;
auto resource_ids_by_value =
@ -399,6 +429,14 @@ void SideEffectAnalysisInfo::AnalyzeRegion(
if (!resource_ids_by_value.hasValue()) {
// Update access info for unknown resource.
TrackAccess(kUnknownResourceId, &op, read_only);
// Add ops that must execute to unknown resource op predecessors.
auto& control_predecessors = control_predecessors_[&op];
control_predecessors.insert(non_resource_control_predecessors.begin(),
non_resource_control_predecessors.end());
// Ops that must execute currently tracked are cleared as transitively
// unknown resource ops will allow for such ops to be transitively
// reachable.
non_resource_control_predecessors.clear();
}
}
}

View File

@ -22,10 +22,9 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project

View File

@ -20,10 +20,12 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
namespace mlir {
namespace tf_device {

View File

@ -19,6 +19,8 @@ limitations under the License.
#define TF_DEVICE_DIALECT
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
// TensorFlow Device Dialect definitions
@ -85,7 +87,7 @@ This op captures all needed live-in values.
let hasCanonicalizer = 1;
}
def TfDevice_ReturnOp : TfDevice_Op<"return", [Terminator]> {
def TfDevice_ReturnOp : TfDevice_Op<"return", [NoSideEffect, ReturnLike, Terminator]> {
let summary = [{
The `tf_device.return` operation terminates and returns values from a
`tf_device` dialect operation.

View File

@ -30,8 +30,8 @@ limitations under the License.
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project

File diff suppressed because it is too large Load Diff

View File

@ -140,6 +140,7 @@ def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;
def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
@ -174,6 +175,8 @@ def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;
def TF_TPUEmbeddingSideEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

View File

@ -82,9 +82,9 @@ struct ResourceHandle {
// Make ResourceHandle hashable.
friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle);
std::string container;
std::string name;
std::string device;
StringRef container;
StringRef name;
StringRef device;
Operation* op = nullptr;
};

View File

@ -41,9 +41,9 @@ limitations under the License.
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project

View File

@ -22,14 +22,14 @@ limitations under the License.
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project

View File

@ -30,8 +30,10 @@ limitations under the License.
include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
@ -352,7 +354,8 @@ else_branch: A function that takes 'inputs' and returns a list of
}
def TF_YieldOp : TF_Op<"Yield",
[Terminator, ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
[NoSideEffect, ReturnLike, Terminator,
ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
let summary = "Yield operation";
let description = [{

View File

@ -43,9 +43,9 @@ limitations under the License.
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project

View File

@ -19,14 +19,14 @@ limitations under the License.
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project

View File

@ -44,9 +44,9 @@ limitations under the License.
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
@ -1947,6 +1947,44 @@ ResourceHandleValueAndId SummaryWriterOp::GetResourceHandleValueAndId(
next_id);
}
//===----------------------------------------------------------------------===//
// TPUExecuteAndUpdateVariablesOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TPUExecuteAndUpdateVariablesOp op) {
int num_resource_args = 0;
for (Type arg_type : op.args().getTypes())
if (arg_type.cast<TensorType>().getElementType().isa<ResourceType>())
++num_resource_args;
auto check_attr = [&](ArrayAttr indices, llvm::StringRef name,
int min) -> LogicalResult {
if (indices.size() != num_resource_args)
return op.emitOpError()
<< "requires '" << name
<< "' to be the same size as number of resource handles in 'args' "
"("
<< num_resource_args << "), but got " << indices.size();
for (auto entry : llvm::enumerate(indices.getValue())) {
auto int_attr = entry.value().cast<IntegerAttr>();
if (int_attr.getInt() < min)
return op.emitOpError()
<< "requires '" << name << "' to contain values of at least "
<< min << ", but got " << int_attr.getInt() << " at index "
<< entry.index();
}
return success();
};
return failure(
failed(check_attr(op.device_var_reads_indices(),
/*name=*/"device_var_reads_indices", /*min=*/0)) ||
failed(check_attr(op.device_var_updates_indices(),
/*name=*/"device_var_updates_indices", /*min=*/-1)));
}
//===----------------------------------------------------------------------===//
// TensorListReserveOp
//===----------------------------------------------------------------------===//
@ -2128,7 +2166,8 @@ class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
// If the input is an unranked tensor, cannpt rewrite.
if (!type) return failure();
// Expected return type of the ToBool operation.
// Expected return type of the ToBool operation. The return type of ToBool
// operation is always 0D tensor of bool type.
auto result_type = op.getResult().getType().cast<RankedTensorType>();
// If input is already a tensor<i1>, it can be folded into an identity.
@ -2533,21 +2572,20 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
TypeRange body_input,
TypeRange body_result) {
// Collect all the type lists for the op so that different pairs of type lists
// can be compared for the compatibility.
constexpr int kNumTypeLists = 5;
const std::array<TypeRangeWithDesc, kNumTypeLists> type_lists = {{
{op->getOperandTypes(), "input"},
TypeRange body_result,
bool shape_invariant) {
const TypeRangeWithDesc input_type = {op->getOperandTypes(), "input"};
const TypeRangeWithDesc result_type = {op->getResultTypes(), "result"};
constexpr int kNumRegionTypeLists = 3;
const std::array<TypeRangeWithDesc, kNumRegionTypeLists> region_types = {{
{body_result, "body result"},
{op->getResultTypes(), "result"},
{cond_input, "condition input"},
{body_input, "body input"},
}};
// A pair of type lists should be cast compatible with each other if one is
// converted to the another for a function call or assignment or there is a
// common source of inputs for both. Therefore, the While op requires the
// common source of inputs for both. Therefore, the While op requires the
// following pairs of type lists to be cast compatible for the tensor_cast
// operation:
//
@ -2556,7 +2594,8 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
// * Operands and body inputs to call the body function for the first
// iteration if the cond functions returns True or equivalent result.
// * Operands and results to assign cond function arguments to op results if
// the cond function returns False or equivalent result.
// the cond function returns False or equivalent result. If the op is shape
// invariant, this does not hold as shapes can differ.
// * All three pairs using cond inputs, body inputs and results as operand is
// a common source for all three.
// * Body result and cond inputs to call the cond function for the subsequent
@ -2565,17 +2604,28 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
//
// Note that the operands and body results need not be compatible as they are
// never converted from one to the another nor there is a common source
// tensors. Compatibility requirement is not transitive.
// tensors. Compatibility requirement is not transitive.
if (!shape_invariant &&
failed(VerifyTypeRangesAreCompatible(op, input_type, result_type)))
return failure();
// Skip the first pair as the While op operands and body function results does
// not need to be compatible with each other.
for (int i = 1; i < kNumRegionTypeLists; ++i)
if (failed(VerifyTypeRangesAreCompatible(op, input_type, region_types[i])))
return failure();
for (int i = 0; i < kNumRegionTypeLists; ++i)
if (failed(VerifyTypeRangesAreCompatible(op, result_type, region_types[i])))
return failure();
for (int i = 0; i < kNumRegionTypeLists; ++i)
for (int j = i + 1; j < kNumRegionTypeLists; ++j)
if (failed(VerifyTypeRangesAreCompatible(op, region_types[i],
region_types[j])))
return failure();
for (int i = 0; i < kNumTypeLists; ++i) {
// Skip the first pair as the While op operands and body function results
// does not need to be compatible with each other.
for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
auto &a = type_lists[i];
auto &b = type_lists[j];
if (failed(VerifyTypeRangesAreCompatible(op, a, b))) return failure();
}
}
return success();
}
@ -2600,7 +2650,8 @@ static LogicalResult Verify(WhileOp op) {
if (failed(VerifyWhileTypes(op, /*cond_input=*/cond_fn_type.getInputs(),
/*body_input=*/body_fn_type.getInputs(),
/*body_result=*/body_fn_type.getResults())))
/*body_result=*/body_fn_type.getResults(),
op.shape_invariant())))
return failure();
return success();
}
@ -2625,7 +2676,8 @@ static LogicalResult Verify(WhileRegionOp op) {
Operation *body_yield = op.body().front().getTerminator();
if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(),
/*body_input=*/op.body().getArgumentTypes(),
/*body_result=*/body_yield->getOperandTypes())))
/*body_result=*/body_yield->getOperandTypes(),
op.shape_invariant())))
return failure();
return success();
}

Some files were not shown because too many files have changed in this diff Show More