Merge branch 'master' into amankishore/language-fix
This commit is contained in:
commit
2e40a04526
README.md
tensorflow
BUILD
compiler/mlir
BUILD
hlo
BUILD
include/mlir-hlo/Dialect/mhlo/IR
CMakeLists.txthlo_ops.hhlo_ops.tdhlo_ops_base.tdhlo_ops_base_enums.hhlo_ops_base_enums.tdlhlo_gpu_ops.hlhlo_gpu_ops.tdlhlo_gpu_ops_enums.hlhlo_gpu_ops_enums.tdlhlo_gpu_ops_structs.tdlhlo_ops.td
lib/Dialect/mhlo
IR
transforms
tests
lite
BUILDemit_error_reporter.hflatbuffer_export.ccflatbuffer_export.hflatbuffer_import.ccflatbuffer_import.hflatbuffer_translate.ccflatbuffer_translate.h
mlir_graph_optimization_pass.hir
mlir_tflite_runner.ccpython
graphdef_to_tfl_flatbuffer.ccsaved_model_to_tfl_flatbuffer.cctf_tfl_flatbuffer_helpers.cctf_tfl_flatbuffer_helpers.h
quantization
lite
quantization_context.ccquantization_context.hquantization_driver.ccquantization_utils.htensorflow
sparsity
tests
tf_tfl_passes.cctf_tfl_passes.htf_tfl_translate.cctf_to_tfl_flatbuffer.cctf_to_tfl_flatbuffer.htransforms
dense_to_sparse.ccdilated_conv.hlower_static_tensor_list.ccoptimize.ccoptimize_functional_ops.ccprepare_composite_functions_tf.ccprepare_quantize.ccprepare_quantize_lstm.hprepare_tf.ccquantize.ccsplit_merged_operands.ccwhile_loop_outline.cc
utils
python
tensorflow
@ -155,6 +155,7 @@ Container Type | Status | Art
|
||||
* [DeepLearning.AI TensorFlow Developer Professional Certificate](https://www.coursera.org/specializations/tensorflow-in-practice)
|
||||
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
|
||||
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
|
||||
* [Intro to TensorFlow for A.I, M.L, and D.L from Coursera](https://www.coursera.org/learn/introduction-tensorflow)
|
||||
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
|
||||
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
|
||||
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
|
||||
|
@ -728,8 +728,8 @@ tf_cc_shared_object(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
"//tensorflow/c:kernels_hdrs",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor",
|
||||
"//tensorflow/c:kernels",
|
||||
"//tensorflow/c:logging",
|
||||
"//tensorflow/c:ops_hdrs",
|
||||
"//tensorflow/cc/saved_model:loader_lite_impl",
|
||||
|
@ -112,6 +112,8 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
|
||||
"//tensorflow/compiler/mlir/tosa:tf_tosa_passes",
|
||||
"//tensorflow/compiler/mlir/tosa:tfl_tosa_passes",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -45,6 +45,7 @@ filegroup(
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
|
||||
@ -170,6 +171,21 @@ gentbl(
|
||||
td_srcs = [":hlo_ops_td_files"],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "hlo_ops_base_enums_inc_gen",
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
tbl_outs = [
|
||||
("-gen-enum-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc"),
|
||||
("-gen-enum-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.cc.inc"),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td",
|
||||
td_relative_includes = [
|
||||
"include",
|
||||
],
|
||||
td_srcs = [":hlo_ops_td_files"],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "hlo_ops_pattern_gen",
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
@ -230,6 +246,22 @@ gentbl(
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "lhlo_gpu_ops_enums_inc_gen",
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
strip_include_prefix = "include",
|
||||
tbl_outs = [
|
||||
("-gen-enum-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h.inc"),
|
||||
("-gen-enum-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc.inc"),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td",
|
||||
td_relative_includes = [
|
||||
"include",
|
||||
],
|
||||
td_srcs = [":hlo_ops_td_files"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lhlo_gpu_ops_structs",
|
||||
srcs = [
|
||||
@ -248,6 +280,23 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lhlo_gpu_ops_enums",
|
||||
srcs = [
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc.inc",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h.inc",
|
||||
"lib/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":lhlo_gpu_ops_enums_inc_gen",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "lhlo_gpu_ops_inc_gen",
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
@ -265,6 +314,7 @@ gentbl(
|
||||
":hlo_ops_td_files",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -342,6 +392,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_ops_base_enums",
|
||||
srcs = [
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc",
|
||||
"lib/Dialect/mhlo/IR/hlo_ops_base_enums.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":hlo_ops_base_enums_inc_gen",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "convert_op_folder",
|
||||
srcs = ["lib/utils/convert_op_folder.cc"],
|
||||
@ -374,6 +440,7 @@ cc_library(
|
||||
":canonicalize_inc_gen",
|
||||
":chlo_ops_inc_gen",
|
||||
":convert_op_folder",
|
||||
":hlo_ops_base_enums",
|
||||
":hlo_ops_base_inc_gen",
|
||||
":hlo_ops_base_structs",
|
||||
":hlo_ops_inc_gen",
|
||||
@ -405,6 +472,7 @@ cc_library(
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":hlo_ops_base_enums",
|
||||
":hlo_ops_base_inc_gen",
|
||||
":hlo_ops_base_structs",
|
||||
":lhlo_ops_inc_gen",
|
||||
@ -436,8 +504,10 @@ cc_library(
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_ops_base_enums",
|
||||
":hlo_ops_base_structs",
|
||||
":infer_fusibility_op_interface",
|
||||
":lhlo_gpu_ops_enums",
|
||||
":lhlo_gpu_ops_inc_gen",
|
||||
":lhlo_gpu_ops_structs",
|
||||
"@llvm-project//llvm:Support",
|
||||
|
@ -32,6 +32,8 @@ mlir_tablegen(hlo_ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(hlo_ops.cc.inc -gen-op-defs)
|
||||
mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls)
|
||||
mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs)
|
||||
mlir_tablegen(hlo_ops_base_enums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(hlo_ops_base_enums.cc.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRhlo_opsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td)
|
||||
@ -40,6 +42,9 @@ mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs)
|
||||
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td)
|
||||
mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls)
|
||||
mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs)
|
||||
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_enums.td)
|
||||
mlir_tablegen(lhlo_gpu_ops_enums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(lhlo_gpu_ops_enums.cc.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen)
|
||||
add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen)
|
||||
|
||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||
|
||||
// clang-format off
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||
// clang-format on
|
||||
|
||||
|
@ -41,7 +41,9 @@ def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">;
|
||||
def HLO_CUSTOM_FUSION : StrEnumAttrCase<"kCustom">;
|
||||
def HLO_FusionKindAttr : StrEnumAttr<"FusionKind", "fusion kind", [
|
||||
HLO_LOOP_FUSION, HLO_INPUT_FUSION, HLO_OUTPUT_FUSION, HLO_CUSTOM_FUSION
|
||||
]>;
|
||||
]> {
|
||||
let cppNamespace = "::mlir::mhlo";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MHLO nullary op definitions.
|
||||
@ -896,7 +898,7 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
|
||||
(ins
|
||||
HLO_Tensor:$lhs,
|
||||
HLO_Tensor:$rhs),
|
||||
ConvolutionAttributes<HLO_Dialect>.attributes);
|
||||
ConvolutionAttributes.attributes);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ def HLO_Dialect : Dialect {
|
||||
let cppNamespace = "::mlir::mhlo";
|
||||
}
|
||||
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"
|
||||
|
||||
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
|
||||
@ -692,77 +693,7 @@ class BASE_HLO_TupleOp {
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Precision Config enum definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// These mirror the XLA PrecisionConfig proto enum.
|
||||
def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">;
|
||||
def HLO_PRECISION_HIGH : StrEnumAttrCase<"HIGH">;
|
||||
def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">;
|
||||
|
||||
def HLO_PrecisionAttr : StrEnumAttr<"Precision",
|
||||
"XLA precision for an operand. Has backend specific meaning.",
|
||||
[HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]>;
|
||||
|
||||
// TODO(b/129153247) See if it's possible to also validate the size.
|
||||
def HLO_PrecisionConfigAttr:
|
||||
OptionalAttr<
|
||||
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fast Fourier Transform Type enum definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// These mirror the XLA FftType proto enum.
|
||||
def HLO_FFT_TYPE_FFT : StrEnumAttrCase<"FFT">;
|
||||
def HLO_FFT_TYPE_IFFT : StrEnumAttrCase<"IFFT">;
|
||||
def HLO_FFT_TYPE_RFFT : StrEnumAttrCase<"RFFT">;
|
||||
def HLO_FFT_TYPE_IRFFT : StrEnumAttrCase<"IRFFT">;
|
||||
|
||||
def HLO_FftTypeAttr : StrEnumAttr<"FftType",
|
||||
"XLA fast fourier transform type.",
|
||||
[HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT,
|
||||
HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Comparison op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// These mirror the XLA ComparisonDirection enum.
|
||||
def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">;
|
||||
def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">;
|
||||
def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">;
|
||||
def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">;
|
||||
def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">;
|
||||
def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">;
|
||||
|
||||
def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
|
||||
"Which comparison operation to perform.",
|
||||
[
|
||||
HLO_COMPARISON_DIRECTION_EQ,
|
||||
HLO_COMPARISON_DIRECTION_NE,
|
||||
HLO_COMPARISON_DIRECTION_GE,
|
||||
HLO_COMPARISON_DIRECTION_GT,
|
||||
HLO_COMPARISON_DIRECTION_LE,
|
||||
HLO_COMPARISON_DIRECTION_LT
|
||||
]>;
|
||||
|
||||
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
|
||||
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
|
||||
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
|
||||
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
|
||||
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
|
||||
|
||||
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
|
||||
"Which comparison type to use.",
|
||||
[
|
||||
HLO_COMPARISON_TYPE_FLOAT,
|
||||
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
|
||||
HLO_COMPARISON_TYPE_SIGNED,
|
||||
HLO_COMPARISON_TYPE_UNSIGNED
|
||||
]>;
|
||||
|
||||
|
||||
class BASE_HLO_CompareOp {
|
||||
@ -783,13 +714,6 @@ class BASE_HLO_CompareOp {
|
||||
// Quantize op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// These mirror the XLA ComparisonDirection enum.
|
||||
def HLO_MIN_COMBINED : StrEnumAttrCase<"MIN_COMBINED">;
|
||||
|
||||
def HLO_DequantizeModeAttr : StrEnumAttr<"DequantizeMode",
|
||||
"Dequantization mode. Only MIN_COMBINED is supported.",
|
||||
[HLO_MIN_COMBINED]>;
|
||||
|
||||
class BASE_HLO_DequantizeOp {
|
||||
string summary = "Dequantize operator";
|
||||
|
||||
@ -1029,7 +953,12 @@ class BASE_HLO_ConcatenateOp {
|
||||
// Common convolution attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class ConvolutionAttributes<Dialect dialect> {
|
||||
// TODO(b/129153247) See if it's possible to also validate the size.
|
||||
def HLO_PrecisionConfigAttr:
|
||||
OptionalAttr<
|
||||
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
|
||||
|
||||
def ConvolutionAttributes {
|
||||
dag attributes = (ins
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$window_strides,
|
||||
@ -1270,21 +1199,6 @@ class BASE_HLO_TransposeOp {
|
||||
}];
|
||||
}
|
||||
|
||||
// These mirror the XLA Transpose enum in Triangular Solve options.
|
||||
def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">;
|
||||
def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">;
|
||||
def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">;
|
||||
def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">;
|
||||
|
||||
def HLO_TransposeAttr : StrEnumAttr<"Transpose",
|
||||
"Transpose options",
|
||||
[
|
||||
HLO_TRANSPOSE_INVALID,
|
||||
HLO_NO_TRANSPOSE,
|
||||
HLO_TRANSPOSE,
|
||||
HLO_ADJOINT
|
||||
]>;
|
||||
|
||||
class BASE_HLO_TriangularSolveOp {
|
||||
string summary = "TriangularSolve operator";
|
||||
|
||||
|
@ -0,0 +1,29 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines enums used in MHLO and LMHLO.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
|
||||
// Order matters, this .inc header is not self-contained, and relies on the
|
||||
// #includes above.
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc"
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_
|
@ -0,0 +1,119 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef HLO_OPS_BASE_ENUMS
|
||||
#define HLO_OPS_BASE_ENUMS
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Precision Config enum definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// These mirror the XLA PrecisionConfig proto enum.
|
||||
def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">;
|
||||
def HLO_PRECISION_HIGH : StrEnumAttrCase<"HIGH">;
|
||||
def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">;
|
||||
|
||||
def HLO_PrecisionAttr : StrEnumAttr<"Precision",
|
||||
"XLA precision for an operand. Has backend specific meaning.",
|
||||
[HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]> {
|
||||
let cppNamespace = "::mlir::mhlo";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fast Fourier Transform Type enum definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// These mirror the XLA FftType proto enum.
|
||||
def HLO_FFT_TYPE_FFT : StrEnumAttrCase<"FFT">;
|
||||
def HLO_FFT_TYPE_IFFT : StrEnumAttrCase<"IFFT">;
|
||||
def HLO_FFT_TYPE_RFFT : StrEnumAttrCase<"RFFT">;
|
||||
def HLO_FFT_TYPE_IRFFT : StrEnumAttrCase<"IRFFT">;
|
||||
|
||||
def HLO_FftTypeAttr : StrEnumAttr<"FftType",
|
||||
"XLA fast fourier transform type.",
|
||||
[HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT,
|
||||
HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]> {
|
||||
let cppNamespace = "::mlir::mhlo";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Comparison op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// These mirror the XLA ComparisonDirection enum.
|
||||
def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">;
|
||||
def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">;
|
||||
def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">;
|
||||
def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">;
|
||||
def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">;
|
||||
def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">;
|
||||
|
||||
def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
|
||||
"Which comparison operation to perform.",
|
||||
[
|
||||
HLO_COMPARISON_DIRECTION_EQ,
|
||||
HLO_COMPARISON_DIRECTION_NE,
|
||||
HLO_COMPARISON_DIRECTION_GE,
|
||||
HLO_COMPARISON_DIRECTION_GT,
|
||||
HLO_COMPARISON_DIRECTION_LE,
|
||||
HLO_COMPARISON_DIRECTION_LT
|
||||
]> {
|
||||
let cppNamespace = "::mlir::mhlo";
|
||||
}
|
||||
|
||||
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
|
||||
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
|
||||
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
|
||||
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
|
||||
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
|
||||
|
||||
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
|
||||
"Which comparison type to use.",
|
||||
[
|
||||
HLO_COMPARISON_TYPE_FLOAT,
|
||||
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
|
||||
HLO_COMPARISON_TYPE_SIGNED,
|
||||
HLO_COMPARISON_TYPE_UNSIGNED
|
||||
]> {
|
||||
let cppNamespace = "::mlir::mhlo";
|
||||
}
|
||||
|
||||
// These mirror the XLA Dequantize mode string enum.
|
||||
def HLO_MIN_COMBINED : StrEnumAttrCase<"MIN_COMBINED">;
|
||||
|
||||
def HLO_DequantizeModeAttr : StrEnumAttr<"DequantizeMode",
|
||||
"Dequantization mode. Only MIN_COMBINED is supported.",
|
||||
[HLO_MIN_COMBINED]> {
|
||||
let cppNamespace = "::mlir::mhlo";
|
||||
}
|
||||
|
||||
// These mirror the XLA Transpose enum in Triangular Solve options.
|
||||
def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">;
|
||||
def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">;
|
||||
def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">;
|
||||
def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">;
|
||||
|
||||
def HLO_TransposeAttr : StrEnumAttr<"Transpose",
|
||||
"Transpose options",
|
||||
[
|
||||
HLO_TRANSPOSE_INVALID,
|
||||
HLO_NO_TRANSPOSE,
|
||||
HLO_TRANSPOSE,
|
||||
HLO_ADJOINT
|
||||
]> {
|
||||
let cppNamespace = "::mlir::mhlo";
|
||||
}
|
||||
|
||||
#endif // HLO_OPS_BASE_ENUMS
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
@ -23,9 +23,9 @@ include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td"
|
||||
|
||||
|
||||
class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<LHLO_GPU_Dialect, mnemonic,
|
||||
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
|
||||
@ -92,30 +92,16 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
|
||||
// LMHLO ops representing convolution library functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ActivationModeNone : StrEnumAttrCase<"None">;
|
||||
def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">;
|
||||
def ActivationModeTanh : StrEnumAttrCase<"Relu">;
|
||||
def ActivationModeRelu : StrEnumAttrCase<"Relu">;
|
||||
def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">;
|
||||
def ActivationModeReluX : StrEnumAttrCase<"ReluX">;
|
||||
def ActivationModeBandPass : StrEnumAttrCase<"BandPass">;
|
||||
|
||||
def ActivationAttr : StrEnumAttr<"Activation",
|
||||
"Activation applied with fused convolution",
|
||||
[ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh,
|
||||
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
|
||||
ActivationModeBandPass]>;
|
||||
|
||||
def GpuConvolutionAttributes {
|
||||
dag attributes = !con(
|
||||
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
|
||||
ConvolutionAttributes.attributes,
|
||||
(ins F64Attr:$result_scale),
|
||||
(ins ConvolutionBackendConfigAttr:$backend_config));
|
||||
}
|
||||
|
||||
def GpuFusedConvolutionAttributes {
|
||||
dag attributes = !con(
|
||||
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
|
||||
ConvolutionAttributes.attributes,
|
||||
(ins F64Attr:$result_scale,
|
||||
ActivationAttr:$activation_mode,
|
||||
F64Attr:$side_input_scale),
|
||||
@ -179,9 +165,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
||||
DotDimensionNumbers:$dot_dimension_numbers,
|
||||
F64Attr:$alpha,
|
||||
F64Attr:$alpha_real,
|
||||
F64Attr:$alpha_imag,
|
||||
I64Attr:$batch_size,
|
||||
I64Attr:$algorithm);
|
||||
OptionalAttr<I64Attr>:$algorithm);
|
||||
}
|
||||
|
||||
// output = alpha(lhs * rhs) + beta * bias
|
||||
@ -192,10 +179,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
||||
DotDimensionNumbers:$dot_dimension_numbers,
|
||||
F64Attr:$alpha,
|
||||
F64Attr:$alpha_real,
|
||||
F64Attr:$alpha_imag,
|
||||
F64Attr:$beta,
|
||||
I64Attr:$batch_size,
|
||||
I64Attr:$algorithm);
|
||||
OptionalAttr<I64Attr>:$algorithm);
|
||||
}
|
||||
|
||||
def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
|
||||
|
@ -0,0 +1,29 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
* ==============================================================================*/
|
||||
|
||||
// This file defines enums used in the LMHLO_GPU dialect.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
|
||||
// Order matters, this .inc header is not self-contained, and relies on the
|
||||
// #includes above.
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h.inc"
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_
|
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef LHLO_GPU_OPS_ENUMS
|
||||
#define LHLO_GPU_OPS_ENUMS
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def ActivationModeNone : StrEnumAttrCase<"None">;
|
||||
def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">;
|
||||
def ActivationModeTanh : StrEnumAttrCase<"Tanh">;
|
||||
def ActivationModeRelu : StrEnumAttrCase<"Relu">;
|
||||
def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">;
|
||||
def ActivationModeReluX : StrEnumAttrCase<"ReluX">;
|
||||
def ActivationModeBandPass : StrEnumAttrCase<"BandPass">;
|
||||
|
||||
def ActivationAttr : StrEnumAttr<"Activation",
|
||||
"Activation applied with fused convolution",
|
||||
[ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh,
|
||||
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
|
||||
ActivationModeBandPass]> {
|
||||
let cppNamespace = "::mlir::lmhlo_gpu";
|
||||
}
|
||||
|
||||
#endif // LHLO_GPU_OPS_ENUMS
|
@ -1,4 +1,3 @@
|
||||
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -415,7 +415,7 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output),
|
||||
ConvolutionAttributes<LHLO_Dialect>.attributes);
|
||||
ConvolutionAttributes.attributes);
|
||||
}
|
||||
|
||||
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
|
||||
@ -654,7 +654,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
|
||||
OpBuilderDAG<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -44,6 +44,7 @@ add_mlir_library(MhloInferFusibilityOpInterface
|
||||
add_mlir_dialect_library(MhloDialect
|
||||
hlo_ops.cc
|
||||
hlo_ops_base_structs.cc
|
||||
hlo_ops_base_enums.cc
|
||||
|
||||
DEPENDS
|
||||
MLIRhlo_opsIncGen
|
||||
@ -70,6 +71,7 @@ target_link_libraries(LmhloDialect PUBLIC MLIRIR)
|
||||
add_mlir_dialect_library(LmhloGPUDialect
|
||||
lhlo_gpu_ops.cc
|
||||
lhlo_gpu_ops_structs.cc
|
||||
lhlo_gpu_ops_enums.cc
|
||||
|
||||
DEPENDS
|
||||
MLIRlhlo_gpu_opsIncGen
|
||||
|
@ -0,0 +1,18 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.cc.inc"
|
@ -0,0 +1,18 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h"
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc.inc"
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
@ -400,7 +400,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
|
||||
|
||||
// Generate a list of nested if/else statements to handle rank
|
||||
// specializations from 1-6.
|
||||
// specializations from 1 to `kMaxRankSpecialization`.
|
||||
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
rewriter, op, greater_rank, 1);
|
||||
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
||||
@ -419,13 +419,13 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
||||
}
|
||||
// Fire an assertion if none of the rank specializations applied (one of
|
||||
// the ranks was greater than 6).
|
||||
// the ranks was greater than `kMaxRankSpecialization`).
|
||||
else_builder.create<AssertOp>(
|
||||
loc,
|
||||
GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
|
||||
kMaxRankSpecialization),
|
||||
"Input for dynamic binary op lowering was of a rank greater than "
|
||||
"6");
|
||||
"Input for dynamic binary op lowering was of a rank greater than " +
|
||||
std::to_string(kMaxRankSpecialization));
|
||||
// Add the rank-`kMaxRankSpecialization` case to the innermost else block.
|
||||
createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs,
|
||||
kMaxRankSpecialization);
|
||||
|
@ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>
|
||||
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
|
||||
alpha = 0.5,
|
||||
alpha_real = 0.5,
|
||||
alpha_imag = 0.0,
|
||||
batch_size = 1,
|
||||
algorithm = 0}
|
||||
: (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
|
||||
@ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
|
||||
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
|
||||
alpha = 0.5,
|
||||
alpha_real = 0.5,
|
||||
alpha_imag = 0.0,
|
||||
beta = 1.0,
|
||||
batch_size = 1,
|
||||
algorithm = 0}
|
||||
|
@ -490,6 +490,7 @@ cc_library(
|
||||
],
|
||||
hdrs = [
|
||||
"transforms/passes.h",
|
||||
"transforms/prepare_quantize_lstm.h",
|
||||
],
|
||||
deps = [
|
||||
"convert_type",
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <cstdarg>
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
|
@ -46,10 +46,9 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||
|
||||
namespace tflite {
|
||||
|
@ -50,11 +50,10 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
@ -510,6 +509,12 @@ Operation* BuildVariableOp(const tflite::TensorT& tensor,
|
||||
return op.getOperation();
|
||||
}
|
||||
auto op = builder.create<tfl::ConstOp>(loc, value);
|
||||
if (!tensor.quantization->min.empty()) {
|
||||
if (auto stats_op =
|
||||
ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
|
||||
return stats_op;
|
||||
}
|
||||
}
|
||||
return op.getOperation();
|
||||
}
|
||||
|
||||
|
@ -17,9 +17,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
|
||||
namespace tflite {
|
||||
// Converts a TFLite flatbuffer stored in `buffer` to a MLIR module
|
||||
|
@ -22,10 +22,9 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||
|
||||
namespace tflite {
|
||||
|
@ -494,6 +494,8 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
|
||||
assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
|
||||
auto result_shape_type = result_type.cast<ShapedType>();
|
||||
|
||||
if (!result_shape_type.hasStaticShape()) return {};
|
||||
|
||||
if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
SmallVector<APFloat, 16> new_values;
|
||||
const int num_elements = result_shape_type.getNumElements();
|
||||
@ -1740,7 +1742,8 @@ static LogicalResult Verify(LSTMOp op) {
|
||||
op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
|
||||
// If this lstm has layer normalization, this input value,
|
||||
// "forget_layer_norm_coefficients" should be a 1D tensor.
|
||||
if (forget_layer_norm_coefficients.getRank() != 1 ||
|
||||
if (!forget_layer_norm_coefficients.hasRank() ||
|
||||
forget_layer_norm_coefficients.getRank() != 1 ||
|
||||
forget_layer_norm_coefficients.getDimSize(0) != n_cell)
|
||||
return op.emitOpError(
|
||||
"coefficient inputs have more than 2 dimensions or "
|
||||
|
@ -31,10 +31,9 @@ limitations under the License.
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
|
||||
|
@ -20,8 +20,8 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
|
||||
|
@ -20,8 +20,8 @@ limitations under the License.
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
@ -19,8 +19,8 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
@ -17,9 +17,9 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
|
||||
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
|
@ -32,7 +32,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
|
@ -17,9 +17,9 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
|
@ -283,6 +283,28 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
||||
}
|
||||
|
||||
func @testAvoidDilatedConvWithExpand(%arg0: tensor<*xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%cst_1 = constant dense<4> : tensor<2x2xi32>
|
||||
%cst_2 = constant dense<0> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<*xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
return %5 : tensor<1x128x128xf32>
|
||||
|
||||
// CHECK-LABEL: testAvoidDilatedConvWithExpand
|
||||
// CHECK: "tf.SpaceToBatchND"
|
||||
// CHECK: "tf.ExpandDims"
|
||||
// CHECK: "tf.Conv2D"
|
||||
// CHECK: "tf.Squeeze"
|
||||
// CHECK: "tf.BatchToSpaceND"
|
||||
// CHECK: "tf.BiasAdd"
|
||||
}
|
||||
|
||||
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
|
@ -1,10 +1,14 @@
|
||||
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s
|
||||
|
||||
// CHECK: effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>
|
||||
// CHECK: input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>
|
||||
// CHECK: input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>
|
||||
// CHECK: input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01:3.200000e+01>>>
|
||||
// CHECK: input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>
|
||||
// CHECK-DAG: %[[input_18:.*]] = "quant.stats"({{.*}}) {layerStats = dense<[-8.000000e-01, 1.600000e+00]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
|
||||
// CHECK-DAG: %[[input_19:.*]] = "quant.stats"({{.*}}) {layerStats = dense<[-2.000000e+00, 4.000000e+00]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
|
||||
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"({{.*}}, %[[input_18]], %[[input_19]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}})
|
||||
// CHECK-SAME: effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>
|
||||
// CHECK-SAME: input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>
|
||||
// CHECK-SAME: input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>
|
||||
// CHECK-SAME: input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01:3.200000e+01>>>
|
||||
// CHECK-SAME: input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>
|
||||
|
||||
{
|
||||
"version": 3,
|
||||
@ -110,8 +114,8 @@
|
||||
"name": "input_activation_state18",
|
||||
"is_variable": true,
|
||||
"quantization": {
|
||||
"min": [-0.9],
|
||||
"max": [0.9]
|
||||
"min": [-0.8],
|
||||
"max": [1.6]
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -119,8 +123,8 @@
|
||||
"name": "input_cell_state19",
|
||||
"is_variable": true,
|
||||
"quantization": {
|
||||
"min": [-0.8],
|
||||
"max": [0.8]
|
||||
"min": [-2.0],
|
||||
"max": [4.0]
|
||||
}
|
||||
},
|
||||
{
|
||||
|
@ -810,6 +810,16 @@ func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4
|
||||
return %24 : tensor<1x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Coefficient inputs of LSTM op have unknown rank.
|
||||
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<*xf32>) -> tensor<1x4xf32> {
|
||||
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
|
||||
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
|
||||
// expected-error @+1 {{'tfl.lstm' op coefficient inputs have more than 2 dimensions or don't match the dimension with input operand `input_to_output_weights`.}}
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<*xf32>) -> tensor<1x4xf32>
|
||||
return %24 : tensor<1x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -343,6 +343,18 @@ func @fuseAddIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf
|
||||
// CHECK-NEXT: return %[[fc]] : tensor<4x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @doNotFuseAddIntoFollowingFullyConnected
|
||||
func @doNotFuseAddIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>, %arg1: tensor<*xf32>) -> tensor<4x2xf32> {
|
||||
%cst1 = constant dense<1.5> : tensor<f32>
|
||||
%0 = "tfl.add"(%arg0, %cst1) {fused_activation_function = "NONE"} : (tensor<4x2xf32>, tensor<f32>) -> tensor<4x2xf32>
|
||||
%cst = constant dense<2.0> : tensor<2xf32>
|
||||
%1 = "tfl.fully_connected"(%0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<*xf32>, tensor<2xf32>) -> tensor<4x2xf32>
|
||||
return %1 : tensor<4x2xf32>
|
||||
|
||||
// CHECK: "tfl.add"
|
||||
// CHECK: "tfl.fully_connected"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseMulIntoFollowingFullyConnected
|
||||
func @fuseMulIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
|
||||
%cst2 = constant dense<1.5> : tensor<f32>
|
||||
|
@ -760,4 +760,13 @@ func @depthwise_conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x
|
||||
// CHECK: "tf.DepthwiseConv2dNative"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: strided_slice_unranked_input
|
||||
func @strided_slice_unranked_input(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
|
||||
%18 = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%57 = "tf.Const"() {value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%534 = "tf.StridedSlice"(%arg0, %57, %57, %18) {begin_mask = 11 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 11 : i64, new_axis_mask = 4 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<*xf32>
|
||||
return %534 : tensor<*xf32>
|
||||
// CHECK: "tf.StridedSlice"
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -165,3 +165,28 @@ func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x
|
||||
// CHECK-SAME: [[VAL_91]], [[VAL_52]] : tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @whileDifferentResultShapes
|
||||
func @whileDifferentResultShapes(%arg0: tensor<i32>) -> tensor<?xf32>
|
||||
attributes {tf.entry_function = {outputs = "result"}} {
|
||||
%cst0 = constant dense<5> : tensor<i32> loc("N")
|
||||
%cst1 = constant dense<3.0> : tensor<1xf32> loc("val")
|
||||
|
||||
%0:2 = "tfl.while"(%cst0, %cst1) ( {
|
||||
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
|
||||
%cst_0 = constant dense<0> : tensor<i32>
|
||||
%1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||
"tfl.yield"(%1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
|
||||
%1 = "tfl.sub"(%arg2, %arg0) {fused_activation_function = "NONE"} :
|
||||
(tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||
"tfl.yield"(%1, %2) : (tensor<*xi32>, tensor<*xf32>) -> ()
|
||||
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<?xf32>) loc("WhileOp")
|
||||
|
||||
// CHECK: (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<?xf32>, tensor<i32>)
|
||||
return %0#1 : tensor<?xf32>
|
||||
}
|
||||
|
@ -17,8 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
@ -24,10 +24,9 @@ limitations under the License.
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/AsmState.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
@ -20,8 +20,8 @@ limitations under the License.
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
|
@ -175,6 +175,9 @@ InspectResult InspectWeight(
|
||||
} else if (auto cst = dyn_cast<QConstOp>(inst)) {
|
||||
attr = cst.value();
|
||||
type = cst.getType().cast<ShapedType>();
|
||||
} else {
|
||||
result.can_compress = false;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Currently we only support compressing weights of ops:
|
||||
@ -222,6 +225,8 @@ std::vector<T> BuildSparsityParameterAttribute(
|
||||
} else if (auto cst = dyn_cast<QConstOp>(inst)) {
|
||||
attr = cst.value();
|
||||
type = cst.getType().cast<ShapedType>();
|
||||
} else {
|
||||
assert(false && "Expected a constant-like op");
|
||||
}
|
||||
const int dims_count = type.getRank();
|
||||
std::vector<int> shape(dims_count);
|
||||
|
@ -185,7 +185,12 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
|
||||
stb_op.block_shape(), bts_op.block_shape(), rewriter);
|
||||
if (!dilations_attr.hasValue()) return failure();
|
||||
op.setAttr("dilations", dilations_attr.getValue());
|
||||
|
||||
if (expand_op) {
|
||||
if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(b/149936532): Check that the input width & height are multiples of
|
||||
// dilation rate.
|
||||
@ -234,6 +239,9 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
}
|
||||
}
|
||||
|
||||
// Set dilations
|
||||
op.setAttr("dilations", dilations_attr.getValue());
|
||||
|
||||
if (expand_op) {
|
||||
// If there is `expand_op`, we need to rewire the inputs to bypass the
|
||||
// `SpaceToBatch`, `BatchToSpace` and `Pad` op. E.g, turning
|
||||
|
@ -35,10 +35,9 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
|
@ -351,10 +351,10 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
|
||||
// to properly broadcast the scalar to `{num_channels}` shape.
|
||||
|
||||
// Get the number of channels if possible.
|
||||
auto filter_type = filter.getType().cast<ShapedType>();
|
||||
auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
|
||||
// Filter must be a `2D` tensor with `{num_channels, num_features}`
|
||||
// shape. The following check is rejecting unknown rank (-1).
|
||||
if (filter_type.getRank() != 2) {
|
||||
if (filter_type == nullptr || filter_type.getRank() != 2) {
|
||||
return failure();
|
||||
}
|
||||
int num_channels = filter_type.getShape()[0];
|
||||
|
@ -19,8 +19,8 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
@ -26,11 +26,10 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // from @llvm-project
|
||||
|
@ -14,7 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass applies quantization propagation on TFLite dialect.
|
||||
#include <cmath>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
|
||||
@ -30,7 +29,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
@ -44,9 +43,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/tools/optimize/operator_property.h"
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::list<std::string> quantize_allowlist(
|
||||
@ -61,6 +59,12 @@ static llvm::cl::opt<bool> quantize_signed(
|
||||
llvm::cl::desc("signed inference type. Only used in tests"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> post_training_quantize(
|
||||
"tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
|
||||
llvm::cl::desc("enable post training quantization. Only used in tests"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> disable_per_channel(
|
||||
"tfl-disable-per-channel", llvm::cl::value_desc("bool"),
|
||||
@ -91,10 +95,9 @@ class PrepareQuantizePass
|
||||
// Constructor used by the PassRegistration and enforce uint8 quantization.
|
||||
// This is only used by test.
|
||||
explicit PrepareQuantizePass() {
|
||||
if (quantize_signed)
|
||||
quant_specs_.inference_type = tensorflow::DT_QINT8;
|
||||
else
|
||||
quant_specs_.inference_type = tensorflow::DT_QUINT8;
|
||||
quant_specs_.inference_type =
|
||||
quantize_signed ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
|
||||
quant_specs_.post_training_quantization = post_training_quantize;
|
||||
}
|
||||
|
||||
// Constructor used by manually creating the pass.
|
||||
@ -318,148 +321,9 @@ bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
|
||||
using PrepareQuantStats =
|
||||
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
|
||||
|
||||
// Calculates the minimum power of two that is not less than the value
// (e.g. 3 -> 4, 0.3 -> 0.5, 8 -> 8).
//
// `value` is expected to be positive. Zero yields 0 and negative or NaN
// inputs yield NaN; neither result is a power of two, so callers that need a
// usable bound must check for those cases themselves.
double power_of_two_bound(double value) {
  // 2^ceil(log2(value)), using the dedicated base-2 exponential instead of
  // the general-purpose std::pow.
  return std::exp2(std::ceil(std::log2(value)));
}
|
||||
|
||||
// Quantize recurrent input of LSTM with 16 bits.
|
||||
template <typename SourceOp, typename Q, typename DQ>
|
||||
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
|
||||
public:
|
||||
explicit ConvertLstmStatsToQDQs(MLIRContext* context)
|
||||
: OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
|
||||
LogicalResult matchAndRewrite(SourceOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
tflite::optimize::operator_property::OpVariant lstm_variant;
|
||||
if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
|
||||
lstm_variant.op_code = tflite::BuiltinOperator_LSTM;
|
||||
} else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(
|
||||
op.getOperation())) {
|
||||
lstm_variant.op_code =
|
||||
tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
|
||||
} else {
|
||||
op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
|
||||
return failure();
|
||||
}
|
||||
lstm_variant.use_projection =
|
||||
!op.projection_weights().getType().template isa<NoneType>();
|
||||
lstm_variant.use_peephole =
|
||||
!op.cell_to_output_weights().getType().template isa<NoneType>();
|
||||
lstm_variant.use_peephole =
|
||||
!op.cell_to_output_weights().getType().template isa<NoneType>();
|
||||
lstm_variant.use_layer_norm =
|
||||
!op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
|
||||
|
||||
auto lstm_property =
|
||||
tflite::optimize::operator_property::GetOperatorProperty(lstm_variant);
|
||||
|
||||
// Same with the ordering of //tensorflow/compiler/mlir/lite/ir/tfl_ops.td
|
||||
const std::vector<std::string> intermediate_attributes = {
|
||||
"input_to_input_intermediate", "input_to_forget_intermediate",
|
||||
"input_to_cell_intermediate", "input_to_output_intermediate",
|
||||
"effective_hidden_scale_intermediate"};
|
||||
|
||||
for (auto& enumerated_intermediates : lstm_property.intermediates) {
|
||||
int index = enumerated_intermediates.first;
|
||||
auto& tensor_property = enumerated_intermediates.second;
|
||||
// intermediate tensors 0, 1, 2, 3 are only used with layer normalization.
|
||||
if (!lstm_variant.use_layer_norm && index != 4) {
|
||||
continue;
|
||||
}
|
||||
// intermediate tensor 4 is only used with projection.
|
||||
if (!lstm_variant.use_projection && index == 4) {
|
||||
continue;
|
||||
}
|
||||
TypeAttr attr =
|
||||
op.template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
|
||||
|
||||
if (!attr) {
|
||||
op.emitError()
|
||||
<< op.getOperationName()
|
||||
<< " requires quantization values for intermediate tensor "
|
||||
<< intermediate_attributes[index];
|
||||
return failure();
|
||||
}
|
||||
auto quantized_type =
|
||||
QuantizedType::getQuantizedElementType(attr.getValue());
|
||||
if (!quantized_type) {
|
||||
op.emitError() << intermediate_attributes[index]
|
||||
<< " is not quantized.";
|
||||
return failure();
|
||||
}
|
||||
auto calibrated_type =
|
||||
quantized_type.dyn_cast<quant::CalibratedQuantizedType>();
|
||||
if (!calibrated_type) {
|
||||
int num_storage_bits = quantized_type.getStorageTypeIntegralWidth();
|
||||
if (tensor_property.number_of_bits != num_storage_bits) {
|
||||
op.emitError() << intermediate_attributes[index]
|
||||
<< " is expected to be quantized with "
|
||||
<< tensor_property.number_of_bits << " bits, but got "
|
||||
<< num_storage_bits << " bits instead.";
|
||||
return failure();
|
||||
}
|
||||
continue; // skip if it is already quantized.
|
||||
}
|
||||
quant::UniformQuantizedType qtype;
|
||||
if (tensor_property.number_of_bits == 8) {
|
||||
qtype = quant::fakeQuantAttrsToType(
|
||||
op.getLoc(), tensor_property.number_of_bits,
|
||||
calibrated_type.getMin(), calibrated_type.getMax(),
|
||||
/*narrowRange=*/false, calibrated_type.getExpressedType(),
|
||||
/*isSigned=*/false);
|
||||
} else if (tensor_property.number_of_bits == 16) {
|
||||
double max = std::max(std::abs(calibrated_type.getMin()),
|
||||
std::abs(calibrated_type.getMax()));
|
||||
qtype = quant::fakeQuantAttrsToType(
|
||||
op.getLoc(), tensor_property.number_of_bits, -max, max,
|
||||
/*narrowRange=*/true, calibrated_type.getExpressedType(),
|
||||
/*isSigned=*/true);
|
||||
} else {
|
||||
op.emitError() << "Unsupported quantization bits: "
|
||||
<< tensor_property.number_of_bits;
|
||||
return failure();
|
||||
}
|
||||
|
||||
op.setAttr(intermediate_attributes[index],
|
||||
TypeAttr::get(qtype.castFromExpressedType(
|
||||
qtype.castToExpressedType(attr.getValue()))));
|
||||
}
|
||||
|
||||
quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
|
||||
op.input_cell_state().getDefiningOp());
|
||||
// Recurrent input is be used within an LSTM, and thus should have one use.
|
||||
if (!stats_op || !stats_op.getResult().hasOneUse()) {
|
||||
return failure();
|
||||
}
|
||||
auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
|
||||
if (!stats) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
double max = std::max(
|
||||
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}))),
|
||||
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
|
||||
double bound = power_of_two_bound(max);
|
||||
Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
|
||||
// Set flags to 1 for signed type.
|
||||
quant::QuantizedType quant_type = UniformQuantizedType::getChecked(
|
||||
quant::QuantizationFlags::Signed,
|
||||
IntegerType::get(16, expressed.getContext()), expressed,
|
||||
/*scale=*/bound / 32768.0, /*zeroPoint=*/0, llvm::minIntN(16),
|
||||
llvm::maxIntN(16), op.getLoc());
|
||||
|
||||
rewriter.setInsertionPointAfter(stats_op);
|
||||
Type result_type = quant_type.castFromExpressedType(stats_op.getType());
|
||||
auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
|
||||
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
using PrepareLstmQuantStats =
|
||||
ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp,
|
||||
quant::QuantizeCastOp, quant::DequantizeCastOp>;
|
||||
TFL::ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp,
|
||||
quant::QuantizeCastOp, quant::DequantizeCastOp>;
|
||||
|
||||
void PrepareQuantizePass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
@ -503,7 +367,7 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
// Currently, only activation stats are imported, so narrow_range = false.
|
||||
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
|
||||
}
|
||||
patterns.insert<PrepareLstmQuantStats>(ctx);
|
||||
patterns.insert<PrepareLstmQuantStats>(ctx, quant_specs_);
|
||||
applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
|
||||
SanityCheckAndAdjustment(func);
|
||||
|
199
tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
Normal file
199
tensorflow/compiler/mlir/lite/transforms/prepare_quantize_lstm.h
Normal file
@ -0,0 +1,199 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Transform pass for LSTMs.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/tools/optimize/operator_property.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The prepare-quantize Pass for LSTM.
|
||||
//
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
// Calculates the minimum power of two that is not less than the value
// (e.g. 3 -> 4, 0.3 -> 0.5, 8 -> 8).
//
// `value` is expected to be positive. Zero yields 0 and negative or NaN
// inputs yield NaN; neither result is a power of two, so callers that need a
// usable bound must check for those cases themselves.
inline double power_of_two_bound(double value) {
  // 2^ceil(log2(value)), using the dedicated base-2 exponential instead of
  // the general-purpose std::pow.
  return std::exp2(std::ceil(std::log2(value)));
}
|
||||
|
||||
namespace operator_property = ::tflite::optimize::operator_property;
|
||||
|
||||
// Quantize recurrent input of LSTM with 16 bits.
|
||||
template <typename SourceOp, typename Q, typename DQ>
|
||||
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
|
||||
public:
|
||||
ConvertLstmStatsToQDQs(MLIRContext* context,
|
||||
const QuantizationSpecs& quant_specs)
|
||||
|
||||
: OpRewritePattern<SourceOp>(context, /*benefit=*/2),
|
||||
quant_specs(quant_specs) {}
|
||||
LogicalResult matchAndRewrite(SourceOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
operator_property::OpVariant lstm_variant;
|
||||
if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
|
||||
lstm_variant.op_code = tflite::BuiltinOperator_LSTM;
|
||||
} else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(
|
||||
op.getOperation())) {
|
||||
lstm_variant.op_code =
|
||||
tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
|
||||
} else {
|
||||
op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
|
||||
return failure();
|
||||
}
|
||||
lstm_variant.use_projection =
|
||||
!op.projection_weights().getType().template isa<NoneType>();
|
||||
lstm_variant.use_peephole =
|
||||
!op.cell_to_output_weights().getType().template isa<NoneType>();
|
||||
lstm_variant.use_peephole =
|
||||
!op.cell_to_output_weights().getType().template isa<NoneType>();
|
||||
lstm_variant.use_layer_norm =
|
||||
!op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
|
||||
|
||||
auto lstm_property = operator_property::GetOperatorProperty(lstm_variant);
|
||||
|
||||
// Same with the ordering of //tensorflow/compiler/mlir/lite/ir/tfl_ops.td
|
||||
const std::vector<std::string> intermediate_attributes = {
|
||||
"input_to_input_intermediate", "input_to_forget_intermediate",
|
||||
"input_to_cell_intermediate", "input_to_output_intermediate",
|
||||
"effective_hidden_scale_intermediate"};
|
||||
|
||||
for (auto& enumerated_intermediates : lstm_property.intermediates) {
|
||||
int index = enumerated_intermediates.first;
|
||||
auto& tensor_property = enumerated_intermediates.second;
|
||||
// intermediate tensors 0, 1, 2, 3 are only used with layer normalization.
|
||||
if (!lstm_variant.use_layer_norm && index != 4) {
|
||||
continue;
|
||||
}
|
||||
// intermediate tensor 4 is only used with projection.
|
||||
if (!lstm_variant.use_projection && index == 4) {
|
||||
continue;
|
||||
}
|
||||
TypeAttr attr =
|
||||
op.template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
|
||||
|
||||
if (!attr) {
|
||||
op.emitError()
|
||||
<< op.getOperationName()
|
||||
<< " requires quantization values for intermediate tensor "
|
||||
<< intermediate_attributes[index];
|
||||
return failure();
|
||||
}
|
||||
auto quantized_type =
|
||||
QuantizedType::getQuantizedElementType(attr.getValue());
|
||||
if (!quantized_type) {
|
||||
op.emitError() << intermediate_attributes[index]
|
||||
<< " is not quantized.";
|
||||
return failure();
|
||||
}
|
||||
auto calibrated_type =
|
||||
quantized_type.dyn_cast<quant::CalibratedQuantizedType>();
|
||||
if (!calibrated_type) {
|
||||
int num_storage_bits = quantized_type.getStorageTypeIntegralWidth();
|
||||
if (tensor_property.number_of_bits != num_storage_bits) {
|
||||
op.emitError() << intermediate_attributes[index]
|
||||
<< " is expected to be quantized with "
|
||||
<< tensor_property.number_of_bits << " bits, but got "
|
||||
<< num_storage_bits << " bits instead.";
|
||||
return failure();
|
||||
}
|
||||
continue; // skip if it is already quantized.
|
||||
}
|
||||
quant::UniformQuantizedType qtype;
|
||||
if (tensor_property.number_of_bits == 8) {
|
||||
qtype = quant::fakeQuantAttrsToType(
|
||||
op.getLoc(), tensor_property.number_of_bits,
|
||||
calibrated_type.getMin(), calibrated_type.getMax(),
|
||||
/*narrowRange=*/false, calibrated_type.getExpressedType(),
|
||||
/*isSigned=*/quant_specs.IsSignedInferenceType());
|
||||
} else if (tensor_property.number_of_bits == 16) {
|
||||
double max = std::max(std::abs(calibrated_type.getMin()),
|
||||
std::abs(calibrated_type.getMax()));
|
||||
qtype = quant::fakeQuantAttrsToType(
|
||||
op.getLoc(), tensor_property.number_of_bits, -max, max,
|
||||
/*narrowRange=*/true, calibrated_type.getExpressedType(),
|
||||
/*isSigned=*/true);
|
||||
} else {
|
||||
op.emitError() << "Unsupported quantization bits: "
|
||||
<< tensor_property.number_of_bits;
|
||||
return failure();
|
||||
}
|
||||
|
||||
op.setAttr(intermediate_attributes[index],
|
||||
TypeAttr::get(qtype.castFromExpressedType(
|
||||
qtype.castToExpressedType(attr.getValue()))));
|
||||
}
|
||||
|
||||
quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
|
||||
op.input_cell_state().getDefiningOp());
|
||||
// Recurrent input is be used within an LSTM, and thus should have one use.
|
||||
if (!stats_op || !stats_op.getResult().hasOneUse()) {
|
||||
return failure();
|
||||
}
|
||||
auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
|
||||
if (!stats) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
double max = std::max(
|
||||
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}))),
|
||||
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
|
||||
double bound = power_of_two_bound(max);
|
||||
Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
|
||||
// Set flags to 1 for signed type.
|
||||
quant::QuantizedType quant_type = UniformQuantizedType::getChecked(
|
||||
quant::QuantizationFlags::Signed,
|
||||
IntegerType::get(16, expressed.getContext()), expressed,
|
||||
/*scale=*/bound / 32768.0, /*zeroPoint=*/0, llvm::minIntN(16),
|
||||
llvm::maxIntN(16), op.getLoc());
|
||||
|
||||
rewriter.setInsertionPointAfter(stats_op);
|
||||
Type result_type = quant_type.castFromExpressedType(stats_op.getType());
|
||||
auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
|
||||
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
QuantizationSpecs quant_specs;
|
||||
};
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM
|
@ -44,7 +44,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
@ -526,9 +526,12 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
|
||||
// Insert a new reshape op.
|
||||
Value original_input = strided_slice_op.input();
|
||||
// TODO(b/174267775): Make sure that the input type has ranked tensor type.
|
||||
RankedTensorType original_input_type =
|
||||
original_input.getType().cast<RankedTensorType>();
|
||||
original_input.getType().dyn_cast<RankedTensorType>();
|
||||
if (!original_input_type) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
const ArrayRef<int64_t> &original_input_shape =
|
||||
original_input_type.getShape();
|
||||
SmallVector<int64_t, 4> revised_shape;
|
||||
|
@ -22,9 +22,9 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
|
@ -22,10 +22,9 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
|
@ -211,24 +211,25 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
// change, so replace with new while op.
|
||||
if (extra_operands.empty()) return;
|
||||
|
||||
Operation* op = while_op.getOperation();
|
||||
const int operands_size = while_op.getNumOperands() + extra_operands.size();
|
||||
SmallVector<Value, 4> operands;
|
||||
operands.reserve(operands_size);
|
||||
operands.append(while_op.getOperands().begin(), while_op.getOperands().end());
|
||||
operands.append(extra_operands.begin(), extra_operands.end());
|
||||
SmallVector<Type, 4> new_types;
|
||||
operands.reserve(types.size());
|
||||
new_types.reserve(operands.size());
|
||||
auto add_operand = [&](Value v) {
|
||||
operands.push_back(v);
|
||||
new_types.push_back(v.getType());
|
||||
};
|
||||
for (auto operand : op->getOperands()) add_operand(operand);
|
||||
for (auto operand : extra_operands) add_operand(operand);
|
||||
new_types.reserve(operands_size);
|
||||
new_types.append(while_op.getResultTypes().begin(),
|
||||
while_op.getResultTypes().end());
|
||||
for (auto extra_operand : extra_operands)
|
||||
new_types.push_back(extra_operand.getType());
|
||||
|
||||
Operation* new_op = OpBuilder(op).insert(Operation::create(
|
||||
op->getLoc(), op->getName(), new_types, operands, op->getAttrs(),
|
||||
/*successors=*/{}, /*numRegions=*/2));
|
||||
for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
|
||||
op->replaceAllUsesWith(new_op->getResults().take_front(op->getNumResults()));
|
||||
op->erase();
|
||||
auto new_while_op = OpBuilder(while_op).create<WhileOp>(
|
||||
while_op.getLoc(), new_types, operands, while_op.getAttrs());
|
||||
new_while_op.cond().takeBody(while_op.cond());
|
||||
new_while_op.body().takeBody(while_op.body());
|
||||
while_op.replaceAllUsesWith(
|
||||
new_while_op.getResults().take_front(while_op.getNumResults()));
|
||||
while_op.erase();
|
||||
}
|
||||
|
||||
void WhileOutlinePass::runOnOperation() {
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
|
||||
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <functional>
|
||||
|
||||
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/InitAllPasses.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
@ -44,6 +44,7 @@ filegroup(
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:SideEffectTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
],
|
||||
@ -285,8 +286,10 @@ gentbl(
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "ir/tf_device_ops.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"@llvm-project//mlir:SideEffectTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
|
||||
],
|
||||
test = True,
|
||||
)
|
||||
@ -408,6 +411,7 @@ cc_library(
|
||||
":tensorflow_traits",
|
||||
":tensorflow_types",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:ControlFlowInterfaces",
|
||||
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
@ -453,6 +457,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:ControlFlowInterfaces",
|
||||
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
@ -495,6 +500,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:ControlFlowInterfaces",
|
||||
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
@ -533,6 +539,7 @@ cc_library(
|
||||
":tensorflow_remaining_ops",
|
||||
":tensorflow_tfrt_ops",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:ControlFlowInterfaces",
|
||||
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
@ -640,6 +647,7 @@ cc_library(
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:CallOpInterfacesIncGen",
|
||||
"@llvm-project//mlir:ControlFlowInterfaces",
|
||||
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
@ -815,6 +823,7 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":tensorflow_op_interfaces",
|
||||
":tensorflow_side_effects",
|
||||
":tensorflow_types",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
@ -856,6 +865,7 @@ cc_library(
|
||||
"transforms/cluster_outlining.cc",
|
||||
"transforms/cluster_tf_ops_pass.cc",
|
||||
"transforms/collection_ops_util.cc",
|
||||
"transforms/constant_op_device_assignment.cc",
|
||||
"transforms/contraction_fusion.cc",
|
||||
"transforms/decompose_resource_ops_pass.cc",
|
||||
"transforms/device_index_selector.cc",
|
||||
@ -882,7 +892,6 @@ cc_library(
|
||||
"transforms/optimize.cc",
|
||||
"transforms/outside_compiled_to_host_launch.cc",
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
"transforms/parallelize_embedding_params_ops_pass.cc",
|
||||
"transforms/promote_resources_to_args.cc",
|
||||
"transforms/readonly_references_to_resources.cc",
|
||||
"transforms/region_control_flow_to_functional.cc",
|
||||
@ -995,6 +1004,7 @@ cc_library(
|
||||
"@llvm-project//mlir:InferTypeOpInterface",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Rewrite",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
|
@ -21,8 +21,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#include "mlir/Analysis/CallGraph.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
|
@ -30,8 +30,8 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
@ -42,6 +42,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -94,20 +95,37 @@ struct SideEffects {
|
||||
|
||||
using ResourceSideEffectsByValue = llvm::SmallDenseMap<Value, SideEffects>;
|
||||
|
||||
// Returns true if `effect` targets the TPUEmbedding resource, and false for
// every other resource. In builds with assertions enabled, a TPUEmbedding
// effect is additionally checked to carry no value, no parameters, and to be
// a Write effect.
bool MustExecute(const MemoryEffects::EffectInstance& effect) {
|
||||
if (llvm::isa<ResourceEffects::TPUEmbedding>(effect.getResource())) {
|
||||
assert(!effect.getValue() && !effect.getParameters() &&
|
||||
isa<MemoryEffects::Write>(effect.getEffect()));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Collects memory side effects for an operation by value (operands and
|
||||
// results).
|
||||
ResourceSideEffectsByValue GetResourceInfoForOp(Operation* op) {
|
||||
ResourceSideEffectsByValue resource_info;
|
||||
|
||||
void GetResourceInfoForOp(Operation* op,
|
||||
ResourceSideEffectsByValue& resource_info,
|
||||
bool& must_execute) {
|
||||
auto interface = dyn_cast<MemoryEffectOpInterface>(op);
|
||||
if (!interface) return resource_info;
|
||||
if (!interface) return;
|
||||
|
||||
llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
|
||||
interface.getEffects(effects);
|
||||
|
||||
for (auto& effect : effects) {
|
||||
if (MustExecute(effect)) {
|
||||
must_execute = true;
|
||||
continue;
|
||||
}
|
||||
// TODO(lyandy): Support effects with no value defined.
|
||||
if (!effect.getValue()) return ResourceSideEffectsByValue();
|
||||
if (!effect.getValue()) {
|
||||
resource_info.clear();
|
||||
must_execute = false;
|
||||
return;
|
||||
}
|
||||
auto it = resource_info.try_emplace(effect.getValue());
|
||||
auto& side_effect = it.first->getSecond();
|
||||
auto* resource_effect = effect.getEffect();
|
||||
@ -120,11 +138,11 @@ ResourceSideEffectsByValue GetResourceInfoForOp(Operation* op) {
|
||||
} else if (isa<MemoryEffects::Write>(resource_effect)) {
|
||||
side_effect.write = true;
|
||||
} else {
|
||||
return ResourceSideEffectsByValue();
|
||||
resource_info.clear();
|
||||
must_execute = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
return resource_info;
|
||||
}
|
||||
|
||||
// Checks if a value is a result of `op`.
|
||||
@ -328,6 +346,7 @@ void SideEffectAnalysisInfo::AnalyzeRegion(
|
||||
// We explicitly iterate through the regions and blocks, in order to handle
|
||||
// different nested regions separately.
|
||||
for (auto& block : *region) {
|
||||
llvm::SmallPtrSet<Operation*, 8> non_resource_control_predecessors;
|
||||
for (auto& op : block) {
|
||||
for (Region& child : op.getRegions()) {
|
||||
SideEffectAnalysisInfo child_analysis(&child, alias_analysis);
|
||||
@ -340,10 +359,21 @@ void SideEffectAnalysisInfo::AnalyzeRegion(
|
||||
// We do not need explicit control edges for declaration ops.
|
||||
if (OpIsDeclaration(&op, alias_analysis)) continue;
|
||||
|
||||
auto resource_op_info = GetResourceInfoForOp(&op);
|
||||
ResourceSideEffectsByValue resource_op_info;
|
||||
bool must_execute = false;
|
||||
GetResourceInfoForOp(&op, resource_op_info, must_execute);
|
||||
|
||||
if (resource_op_info.empty() && OpIsKnownToHaveNoSideEffect(&op))
|
||||
continue;
|
||||
|
||||
if (resource_op_info.empty() && must_execute) {
|
||||
// Add unknown resource ops as predecessors of the op that must execute,
|
||||
// to guarantee ordering between unknown resource ops.
|
||||
AddPredecessorsForAccess(kUnknownResourceId, &op, /*read_only=*/false);
|
||||
non_resource_control_predecessors.insert(&op);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (IsResourceOpAllocOnly(&op, resource_op_info)) continue;
|
||||
|
||||
auto resource_ids_by_value =
|
||||
@ -399,6 +429,14 @@ void SideEffectAnalysisInfo::AnalyzeRegion(
|
||||
if (!resource_ids_by_value.hasValue()) {
|
||||
// Update access info for unknown resource.
|
||||
TrackAccess(kUnknownResourceId, &op, read_only);
|
||||
// Add ops that must execute to unknown resource op predecessors.
|
||||
auto& control_predecessors = control_predecessors_[&op];
|
||||
control_predecessors.insert(non_resource_control_predecessors.begin(),
|
||||
non_resource_control_predecessors.end());
|
||||
// Ops that must execute currently tracked are cleared as transitively
|
||||
// unknown resource ops will allow for such ops to be transitively
|
||||
// reachable.
|
||||
non_resource_control_predecessors.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -22,10 +22,9 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
|
@ -20,10 +20,12 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_
|
||||
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_device {
|
||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||
#define TF_DEVICE_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow Device Dialect definitions
|
||||
@ -85,7 +87,7 @@ This op captures all needed live-in values.
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TfDevice_ReturnOp : TfDevice_Op<"return", [Terminator]> {
|
||||
def TfDevice_ReturnOp : TfDevice_Op<"return", [NoSideEffect, ReturnLike, Terminator]> {
|
||||
let summary = [{
|
||||
The `tf_device.return` operation terminates and returns values from a
|
||||
`tf_device` dialect operation.
|
||||
|
@ -30,8 +30,8 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -140,6 +140,7 @@ def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
|
||||
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
|
||||
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
|
||||
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
|
||||
def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;
|
||||
|
||||
def TF_VariableRead : MemRead<TF_VariableResource>;
|
||||
def TF_StackRead : MemRead<TF_StackResource>;
|
||||
@ -174,6 +175,8 @@ def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
|
||||
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
|
||||
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;
|
||||
|
||||
def TF_TPUEmbeddingSideEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -82,9 +82,9 @@ struct ResourceHandle {
|
||||
// Make ResourceHandle hashable.
|
||||
friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle);
|
||||
|
||||
std::string container;
|
||||
std::string name;
|
||||
std::string device;
|
||||
StringRef container;
|
||||
StringRef name;
|
||||
StringRef device;
|
||||
Operation* op = nullptr;
|
||||
};
|
||||
|
||||
|
@ -41,9 +41,9 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
|
@ -22,14 +22,14 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
|
||||
|
@ -30,8 +30,10 @@ limitations under the License.
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
|
||||
@ -352,7 +354,8 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
}
|
||||
|
||||
def TF_YieldOp : TF_Op<"Yield",
|
||||
[Terminator, ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
|
||||
[NoSideEffect, ReturnLike, Terminator,
|
||||
ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
|
||||
let summary = "Yield operation";
|
||||
|
||||
let description = [{
|
||||
|
@ -43,9 +43,9 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
|
@ -19,14 +19,14 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
|
||||
|
@ -44,9 +44,9 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
@ -1947,6 +1947,44 @@ ResourceHandleValueAndId SummaryWriterOp::GetResourceHandleValueAndId(
|
||||
next_id);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TPUExecuteAndUpdateVariablesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies a TPUExecuteAndUpdateVariablesOp.
//
// Both the `device_var_reads_indices` and `device_var_updates_indices`
// attributes must have exactly one entry per resource-typed value in `args`.
// Every entry of `device_var_reads_indices` must be >= 0 and every entry of
// `device_var_updates_indices` must be >= -1. Emits an op error and returns
// failure on the first violation found; returns success otherwise.
// NOTE(review): presumably -1 in `device_var_updates_indices` marks a
// variable that is not updated -- confirm against the op definition.
static LogicalResult Verify(TPUExecuteAndUpdateVariablesOp op) {
|
||||
// Count the `args` values whose tensor element type is a resource.
int num_resource_args = 0;
|
||||
for (Type arg_type : op.args().getTypes())
|
||||
if (arg_type.cast<TensorType>().getElementType().isa<ResourceType>())
|
||||
++num_resource_args;
|
||||
|
||||
// Checks that `indices` has `num_resource_args` entries and that every entry
// is at least `min`; `name` is only used in the emitted error messages.
auto check_attr = [&](ArrayAttr indices, llvm::StringRef name,
|
||||
int min) -> LogicalResult {
|
||||
if (indices.size() != num_resource_args)
|
||||
return op.emitOpError()
|
||||
<< "requires '" << name
|
||||
<< "' to be the same size as number of resource handles in 'args' "
|
||||
"("
|
||||
<< num_resource_args << "), but got " << indices.size();
|
||||
|
||||
for (auto entry : llvm::enumerate(indices.getValue())) {
|
||||
auto int_attr = entry.value().cast<IntegerAttr>();
|
||||
if (int_attr.getInt() < min)
|
||||
return op.emitOpError()
|
||||
<< "requires '" << name << "' to contain values of at least "
|
||||
<< min << ", but got " << int_attr.getInt() << " at index "
|
||||
<< entry.index();
|
||||
}
|
||||
|
||||
return success();
|
||||
};
|
||||
|
||||
// `||` short-circuits: if the reads attribute fails, the updates attribute is
// not checked, so at most one error is emitted.
return failure(
|
||||
failed(check_attr(op.device_var_reads_indices(),
|
||||
/*name=*/"device_var_reads_indices", /*min=*/0)) ||
|
||||
failed(check_attr(op.device_var_updates_indices(),
|
||||
/*name=*/"device_var_updates_indices", /*min=*/-1)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorListReserveOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2128,7 +2166,8 @@ class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
|
||||
// If the input is an unranked tensor, cannot rewrite.
|
||||
if (!type) return failure();
|
||||
|
||||
// Expected return type of the ToBool operation.
|
||||
// Expected return type of the ToBool operation. The return type of ToBool
|
||||
// operation is always 0D tensor of bool type.
|
||||
auto result_type = op.getResult().getType().cast<RankedTensorType>();
|
||||
|
||||
// If input is already a tensor<i1>, it can be folded into an identity.
|
||||
@ -2533,21 +2572,20 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
|
||||
TypeRange body_input,
|
||||
TypeRange body_result) {
|
||||
// Collect all the type lists for the op so that different pairs of type lists
|
||||
// can be compared for the compatibility.
|
||||
constexpr int kNumTypeLists = 5;
|
||||
const std::array<TypeRangeWithDesc, kNumTypeLists> type_lists = {{
|
||||
{op->getOperandTypes(), "input"},
|
||||
TypeRange body_result,
|
||||
bool shape_invariant) {
|
||||
const TypeRangeWithDesc input_type = {op->getOperandTypes(), "input"};
|
||||
const TypeRangeWithDesc result_type = {op->getResultTypes(), "result"};
|
||||
constexpr int kNumRegionTypeLists = 3;
|
||||
const std::array<TypeRangeWithDesc, kNumRegionTypeLists> region_types = {{
|
||||
{body_result, "body result"},
|
||||
{op->getResultTypes(), "result"},
|
||||
{cond_input, "condition input"},
|
||||
{body_input, "body input"},
|
||||
}};
|
||||
|
||||
// A pair of type lists should be cast compatible with each other if one is
|
||||
// converted to the other for a function call or assignment or there is a
|
||||
// common source of inputs for both. Therefore, the While op requires the
|
||||
// common source of inputs for both. Therefore, the While op requires the
|
||||
// following pairs of type lists to be cast compatible for the tensor_cast
|
||||
// operation:
|
||||
//
|
||||
@ -2556,7 +2594,8 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
|
||||
// * Operands and body inputs to call the body function for the first
|
||||
// iteration if the cond functions returns True or equivalent result.
|
||||
// * Operands and results to assign cond function arguments to op results if
|
||||
// the cond function returns False or equivalent result.
|
||||
// the cond function returns False or equivalent result. If the op is shape
|
||||
// invariant, this does not hold as shapes can differ.
|
||||
// * All three pairs using cond inputs, body inputs and results as operand is
|
||||
// a common source for all three.
|
||||
// * Body result and cond inputs to call the cond function for the subsequent
|
||||
@ -2565,17 +2604,28 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
|
||||
//
|
||||
// Note that the operands and body results need not be compatible as they are
|
||||
// never converted from one to the other nor is there a common source
|
||||
// tensors. Compatibility requirement is not transitive.
|
||||
// tensors. Compatibility requirement is not transitive.
|
||||
|
||||
if (!shape_invariant &&
|
||||
failed(VerifyTypeRangesAreCompatible(op, input_type, result_type)))
|
||||
return failure();
|
||||
|
||||
// Skip the first pair as the While op operands and body function results do
|
||||
// not need to be compatible with each other.
|
||||
for (int i = 1; i < kNumRegionTypeLists; ++i)
|
||||
if (failed(VerifyTypeRangesAreCompatible(op, input_type, region_types[i])))
|
||||
return failure();
|
||||
|
||||
for (int i = 0; i < kNumRegionTypeLists; ++i)
|
||||
if (failed(VerifyTypeRangesAreCompatible(op, result_type, region_types[i])))
|
||||
return failure();
|
||||
|
||||
for (int i = 0; i < kNumRegionTypeLists; ++i)
|
||||
for (int j = i + 1; j < kNumRegionTypeLists; ++j)
|
||||
if (failed(VerifyTypeRangesAreCompatible(op, region_types[i],
|
||||
region_types[j])))
|
||||
return failure();
|
||||
|
||||
for (int i = 0; i < kNumTypeLists; ++i) {
|
||||
// Skip the first pair as the While op operands and body function results
|
||||
// does not need to be compatible with each other.
|
||||
for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
|
||||
auto &a = type_lists[i];
|
||||
auto &b = type_lists[j];
|
||||
if (failed(VerifyTypeRangesAreCompatible(op, a, b))) return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -2600,7 +2650,8 @@ static LogicalResult Verify(WhileOp op) {
|
||||
|
||||
if (failed(VerifyWhileTypes(op, /*cond_input=*/cond_fn_type.getInputs(),
|
||||
/*body_input=*/body_fn_type.getInputs(),
|
||||
/*body_result=*/body_fn_type.getResults())))
|
||||
/*body_result=*/body_fn_type.getResults(),
|
||||
op.shape_invariant())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
@ -2625,7 +2676,8 @@ static LogicalResult Verify(WhileRegionOp op) {
|
||||
Operation *body_yield = op.body().front().getTerminator();
|
||||
if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(),
|
||||
/*body_input=*/op.body().getArgumentTypes(),
|
||||
/*body_result=*/body_yield->getOperandTypes())))
|
||||
/*body_result=*/body_yield->getOperandTypes(),
|
||||
op.shape_invariant())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user