Merge remote-tracking branch 'upstream/master' into detection_postprocess

Advait Jain 2020-11-23 16:09:31 -08:00
commit 82eeda304b
629 changed files with 12570 additions and 13992 deletions
ISSUES.md
RELEASE.md
tensorflow/
  BUILD
  c/eager
  compiler/
    jit
    mlir/
      hlo
      lite
      tensorflow
      tfr/examples/mnist
      tools/kernel_gen
      xla
    tf2tensorrt
    tf2xla
    xla


@@ -1,7 +1,9 @@
-If you open a GitHub Issue, here is our policy: 1. It must be a bug/performance
-issue or a feature request or a build issue or a documentation issue (for small
-doc fixes please send a PR instead). 2. Make sure the Issue Template is filled
-out. 3. The issue should be related to the repo it is created in.
+If you open a GitHub Issue, here is our policy:
+
+1. It must be a bug/performance issue or a feature request or a build issue or
+   a documentation issue (for small doc fixes please send a PR instead).
+1. Make sure the Issue Template is filled out.
+1. The issue should be related to the repo it is created in.
 
 **Here's why we have this policy:** We want to focus on the work that benefits
 the whole community, e.g., fixing bugs and adding features. Individual support


@@ -54,6 +54,7 @@
 * Corrected higher-order gradients of control flow constructs (`tf.cond`,
   `tf.while_loop`, and compositions like `tf.foldl`) computed with
   `tf.GradientTape` inside a `tf.function`.
+* Changed the default step size in `gradient_checker_v2.compute_gradients` to be exactly representable as a binary floating point number. This avoids polluting gradient approximations needlessly, which in some cases leads to false negatives in op gradient tests.
 * `tf.summary`:
   * New `tf.summary.graph` allows manual write of TensorFlow graph
@@ -65,6 +66,19 @@
   supported MSVC version to 16.4 (current: 16.8).
   * See: https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
+* TensorRT
+  * Removed the deprecated `session_config` parameter for the TF1-TRT
+    converter `TrtGraphConverter`. Previously, we issued a warning when the
+    value of the parameter was not None.
+  * The TF2-TRT converter `TrtGraphConverterV2` takes an object of class
+    TrtConversionParams as a parameter. Removed three deprecated fields from
+    this class: `rewriter_config_template`, `is_dynamic_op`, and
+    `max_batch_size`. Previously, we issued a warning when the value of
+    `rewriter_config_template` was not None. We issued an error when the
+    value of `is_dynamic_op` was not True. We did not use the value of
+    `max_batch_size` for building TensorRT engines.
+  * Issue a warning when the function `get_tensorrt_rewriter_config` is used.
 
 ## Thanks to our Contributors
 
 This release contains contributions from many people at Google, as well as:
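Two of the release notes in the hunk above lend themselves to short illustrations. First, the step-size change: a decimal step such as 1e-3 is not exactly representable in binary floating point, so the perturbation actually applied to x differs slightly from the nominal delta and adds noise to finite-difference gradient approximations, while a power-of-two step survives the addition exactly. A minimal NumPy sketch (illustrative only, not code from the TensorFlow test suite):

    import numpy as np

    x = np.float32(1.5)
    for delta in (np.float32(1e-3), np.float32(2.0 ** -10)):
        # If delta survives the addition exactly, the finite-difference
        # denominator matches the perturbation that was actually applied.
        recovered = (x + delta) - x
        print(float(delta), "exact" if recovered == delta else "rounded")

Second, the TensorRT note: a hedged sketch of what the TF2 conversion call can look like once the deprecated fields are gone, assuming a SavedModel at the hypothetical path /tmp/my_saved_model and a TensorFlow build with TensorRT support:

    import tensorflow as tf

    # Only still-supported fields are set; rewriter_config_template,
    # is_dynamic_op and max_batch_size no longer exist on ConversionParams.
    params = tf.experimental.tensorrt.ConversionParams(
        precision_mode="FP16", max_workspace_size_bytes=1 << 30)
    converter = tf.experimental.tensorrt.Converter(
        input_saved_model_dir="/tmp/my_saved_model", conversion_params=params)
    converter.convert()
    converter.save("/tmp/my_trt_saved_model")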


@@ -72,6 +72,14 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+# Config setting that disables the default logger, only logging
+# to registered TFLogSinks
+config_setting(
+    name = "no_default_logger",
+    define_values = {"no_default_logger": "true"},
+    visibility = ["//visibility:public"],
+)
+
 # Config setting for determining if we are building for Android.
 config_setting(
     name = "android",
@@ -732,6 +740,7 @@ tf_cc_shared_object(
         "//tensorflow/c/experimental/filesystem:filesystem_interface",
         "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
         "//tensorflow/c:kernels_hdrs",
+        "//tensorflow/c:logging",
         "//tensorflow/c:ops_hdrs",
         "//tensorflow/cc/saved_model:loader_lite_impl",
         "//tensorflow/core/common_runtime:core_cpu_impl",


@@ -301,7 +301,7 @@ tf_cuda_cc_test(
     ],
     args = ["--heap_check=local"],
     linkstatic = tf_kernel_tests_linkstatic(),
-    tags = tf_cuda_tests_tags(),
+    tags = tf_cuda_tests_tags() + ["no_cuda_asan"],  # b/173654156
     deps = [
         ":c_api_experimental",
         ":c_api_unified_internal",
@@ -469,6 +469,7 @@ tf_cuda_cc_test(
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + [
         "nomac",
+        "no_cuda_asan",  # b/173825513
     ],
     deps = [
         ":abstract_tensor_handle",


@@ -61,6 +61,7 @@ limitations under the License.
 // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
+#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed.h"
 #endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
 
 #if !defined(IS_MOBILE_PLATFORM)
@@ -101,7 +102,14 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   if (opts->use_tfrt) {
 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
-    return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
+    tfrt::tf::ContextInterface* tfrt_context =
+        new tfrt::tf::ContextInterface(opts->async);
+#if !defined(IS_MOBILE_PLATFORM)
+    tfrt_context->SetDistributedManager(
+        std::make_unique<tfrt::tf::DistributedManagerContextInterface>(
+            tfrt_context->GetCoreRuntime()->GetHostContext()));
+#endif  // !IS_MOBILE_PLATFORM
+    return tensorflow::wrap(tfrt_context);
 #else
     status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
     return nullptr;


@@ -226,7 +226,7 @@ void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
 // Helper functions which delegate to `AbstractOperation`, update
 // the state of the ForwardOperation and call the tape as appropriate.
-// These APIs are mainly to faciliate testing and are subject to change.
+// These APIs are mainly to facilitate testing and are subject to change.
 namespace internal {
 Status Reset(AbstractOperation* op_, const char* op,
              const char* raw_device_name, ForwardOperation* forward_op_) {


@@ -39,7 +39,7 @@ struct XlaAutoJitFlag {
   int32 optimization_level_general;
 };
 
-// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax
+// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax
 // is:
 // <number>: sets general and single_gpu setting to the provided number.
 // single-gpu(<number>): sets the single_gpu setting to the provided number.


@@ -103,7 +103,9 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
   if (flr->config_proto()) {
     config_proto = *flr->config_proto();
   }
-  if (!IsMlirBridgePassEnabled(*fbody->graph, config_proto)) {
+  MlirBridgeRolloutPolicy policy =
+      GetMlirBridgeRolloutPolicy(*fbody->graph, config_proto);
+  if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
     RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
     if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
       std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>


@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
let results = (outs HLO_StaticShapeTensor); let results = (outs HLO_StaticShapeTensor);
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1;
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
} }


@ -202,9 +202,9 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
let arguments = (ins let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$input, Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch, Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
Arg<I32Buffer, "", [MemWrite]>:$info, Arg<I32Buffer, "", [MemWrite]>:$info,
BoolAttr:$is_upper); BoolAttr:$is_lower);
} }
#endif // LHLO_GPU_OPS #endif // LHLO_GPU_OPS


@ -197,10 +197,11 @@ def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TODO(b/139813999): specify required function signature in a type-safe way. // TODO(b/139813999): specify required function signature in a type-safe way.
def LHLO_ReduceOp: LHLO_Op<"reduce", [ //
SameVariadicOperandSize, // The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are
SingleBlockImplicitTerminator<"TerminatorOp"> // moving towards mhlo.ReturnOp, but some code that needs cleanup still assumes lmhlo.TerminatorOp.
]>, BASE_HLO_ReduceOp { // TODO(timshen): cleanup lmhlo.TerminatorOp.
def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp {
let arguments = (ins let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands, Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values, Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,


@@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
+                                            MLIRContext* context) {
+  results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
+      context);
+}
+
 //===----------------------------------------------------------------------===//
 // Case Op
 //===----------------------------------------------------------------------===//


@@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat<
 def ShapeOfDynamicReshape : Pat<
   (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
   (replaceWithValue $shape)>;
+
+def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
+
+def IdentityBroadcastReshape : Pat<
+  (HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
+  (replaceWithValue $input),
+  [(HasSameType $input, $op)]>;
+
+def IdentityBroadcastInDimReshape : Pat<
+  (HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)),
+  (replaceWithValue $input),
+  [(HasSameType $input, $op)]>;


@@ -61,12 +61,16 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
   // mapValues always takes a function returning APInt, even when the output
   // is actually float.
   using func_type = llvm::APInt(const llvm::APInt&);
+  // TODO(hinsu): Correctly handle unsigned element types.
+  bool is_bool = old_type.isInteger(1);
+
   if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
     // Int -> Float
     return elements.mapValues(
-        new_type, llvm::function_ref<func_type>([&newFloatType](
+        new_type, llvm::function_ref<func_type>([&newFloatType, &is_bool](
                       const llvm::APInt& intVal) {
-          llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
+          int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
+          llvm::APFloat newDouble(static_cast<double>(val));
           bool loses_info = false;
           newDouble.convert(newFloatType.getFloatSemantics(),
                             llvm::APFloat::rmNearestTiesToEven, &loses_info);
@@ -76,9 +80,10 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
   // new_type is Integer
   // Int -> Int
   return elements.mapValues(
-      new_type,
-      llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
-        return llvm::APInt(bit_width, intVal.getSExtValue());
+      new_type, llvm::function_ref<func_type>([&bit_width, &is_bool](
+                    const llvm::APInt& intVal) {
+        int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
+        return llvm::APInt(bit_width, val);
       }));
 }
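The `is_bool` special case above exists because an i1 value is its own sign bit: sign-extending a set bit yields -1, while zero-extending yields the intended 1. A tiny Python sketch of the distinction (plain arithmetic, no MLIR or LLVM dependency; the helper names are made up for illustration):

    # A 1-bit two's-complement value: the single bit is also the sign bit.
    def sext_i1(bit):
        # Mirrors what getSExtValue() produces for an i1 that is set.
        return -1 if bit else 0

    def zext_i1(bit):
        # Mirrors what getZExtValue() produces: true -> 1, false -> 0.
        return 1 if bit else 0

    print(sext_i1(1), zext_i1(1))  # -1 1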


@@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> {
   // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
   // CHECK-SAME: ]> : tensor<4x5xi32>
 }
+
+// CHECK-LABEL: @identity_broadcast_reshape
+func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
+  %0 = "mhlo.broadcast"(%arg0) {
+    broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
+  return %1 : tensor<128xf32>
+  // CHECK: return %arg0 : tensor<128xf32>
+}
+
+// CHECK-LABEL: @identity_broadcast_in_dim_reshape
+func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
+  %0 = "mhlo.broadcast_in_dim"(%arg0) {
+    broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
+  return %1 : tensor<128xf32>
+  // CHECK: return %arg0 : tensor<128xf32>
+}


@@ -123,6 +123,17 @@ func @const_int_bf16() -> tensor<bf16> {
 
 // -----
 
+// CHECK-LABEL: func @const_bool_f32
+func @const_bool_f32() -> tensor<2xf32> {
+  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
+  %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
+  %0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xf32>
+  // CHECK-NEXT: return [[CST]]
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @const_bf16_int
 func @const_bf16_int() -> tensor<i16> {
   // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
@@ -145,8 +156,8 @@ func @const_int_narrowing() -> tensor<i32> {
 
 // -----
 
-// CHECK-LABEL: func @const_int_widening
-func @const_int_widening() -> tensor<i64> {
+// CHECK-LABEL: func @const_bool_widening
+func @const_bool_widening() -> tensor<i64> {
   // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
   %cst = mhlo.constant dense<42> : tensor<i32>
   %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
@@ -156,6 +167,17 @@ func @const_int_widening() -> tensor<i64> {
 
 // -----
 
+// CHECK-LABEL: func @const_int_widening
+func @const_int_widening() -> tensor<2xi32> {
+  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0, 1]> : tensor<2xi32>
+  %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
+  %0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xi32>
+  // CHECK-NEXT: return [[CST]]
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @const_negative_int_widening
 func @const_negative_int_widening() -> tensor<i64> {
   // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>


@@ -93,7 +93,7 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
 func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
   %scratch = alloc() : memref<32xi8>
   %info = alloc() : memref<32xi32>
-  "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
+  "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
       : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
   return
 }


@@ -704,6 +704,7 @@ cc_library(
         ":convert_type",
         ":flatbuffer_tflite_operator_lib",
         ":tensorflow_lite",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",


@@ -64,6 +64,7 @@ limitations under the License.
 #include "mlir/Translation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@@ -149,7 +150,7 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
   int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
                             is_signed, storage_type.getWidth()) +
-                        is_weight_buffer;
+                        static_cast<int>(is_weight_buffer);
   int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
       is_signed, storage_type.getWidth());
   uint32_t flags =
@@ -177,12 +178,25 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
       quant_params.zero_point.at(0), storage_min, storage_max);
 }
 
+// import float tensor with calibration value into calibrated quantized type.
+StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
+                                                   Builder builder) {
+  if (tensor.quantization == nullptr) {
+    return errors::InvalidArgument("The tensor is not quantized.");
+  }
+  auto raw_elem_type = ConvertElementType(tensor.type, builder);
+  float min = tensor.quantization->min[0];
+  float max = tensor.quantization->max[0];
+  return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
+}
+
 // TODO(b/138222071) Remove shapeless_are_scalars once we can reliably
 // make that distinction and don't have to rely on context
 // (input to main and constants must have static shape)
 StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
                                          bool shapeless_are_scalars = false,
-                                         bool is_constant = false) {
+                                         bool is_constant = false,
+                                         bool is_intermediate = false) {
   mlir::Type elem_type = ConvertElementType(tensor.type, builder);
   // TODO(b/139554398) Store min/max (even for non-quantized tensors) somewhere
   // if it's set
@@ -191,6 +205,13 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
                         GetQuantizedType(tensor, builder, is_constant));
   }
 
+  // Intermediate tensors with calibration value (but not scale and zero points)
+  // should return calibrated quantized type.
+  if (is_intermediate && tensor.quantization != nullptr &&
+      !IsQuantized(tensor)) {
+    TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
+  }
+
   if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) {
     return RankedTensorType::get({}, elem_type);
   }
@@ -1033,7 +1054,7 @@ StatusOr<FuncOp> ConvertSubgraph(
     }
   }
 
-  // Intermediate tensors for tfl.lstm are used to carry quantization range
+  // Intermediate tensors for LSTMs are used to carry quantization range
   // in their types, so we only need and extract their types.
   std::vector<mlir::TensorType> intermediate_types;
   intermediate_types.reserve(5);
@@ -1041,7 +1062,8 @@ StatusOr<FuncOp> ConvertSubgraph(
     TF_ASSIGN_OR_RETURN(
         auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
                                  /*shapeless_are_scalars=*/true,
-                                 /*is_constant=*/true));
+                                 /*is_constant=*/false,
+                                 /*is_intermediate=*/true));
     intermediate_types.emplace_back(type);
   }
@@ -1135,7 +1157,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
   auto builder = Builder(context);
 
   std::vector<std::string> func_names;
   for (auto& subgraph : model->subgraphs) {
     func_names.push_back(subgraph->name);


@@ -1978,6 +1978,10 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1);
+  if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
+    return input();
+  }
+
   // For now, only supports cast between integer types.
   auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (!elements_attr) {


@@ -483,9 +483,9 @@ value of each element in `x`. For example, if x is an input element and y is
 an output element, this operation computes \\(y = |x|\\).
   }];
 
-  let arguments = (ins TFL_FpTensor:$x);
+  let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x);
 
-  let results = (outs TFL_FpTensor:$y);
+  let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y);
 
   let hasFolder = 1;
 }
@@ -587,15 +587,15 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
   let arguments = (ins
     TFL_I32Tensor:$output_shape,
-    TFL_TensorOf<[F32, QI8, QUI8]>:$weights,
-    TFL_TensorOf<[F32, QI8, QUI8]>:$input,
-    TFL_TensorOfOrNone<[F32, QI32]>:$bias,
+    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$weights,
+    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
+    TFL_TensorOfOrNone<[F32, QI32, I64]>:$bias,
     TFL_PaddingAttr:$padding,
     Confined<I32Attr, [IntPositive]>:$stride_h,
     Confined<I32Attr, [IntPositive]>:$stride_w
   );
 
-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
 
   let hasOptions = 1;
@@ -624,7 +624,7 @@ def TFL_AveragePool2DOp:
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
+    ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
     I32Attr:$filter_height,
     I32Attr:$filter_width,
     TFL_PaddingAttr:$padding,
@@ -633,7 +633,7 @@ def TFL_AveragePool2DOp:
     TFL_AFAttr:$fused_activation_function
   );
 
-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "Pool2DOptions";
@@ -947,7 +947,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
   let arguments = (ins
     TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input,
-    TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter,
+    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$filter,
     TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,
 
     TFL_AFAttr:$fused_activation_function,
@@ -999,14 +999,14 @@ in the batch dimensions and broadcasting.
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, QI8]>:$x,
-    TFL_TensorOf<[F32, QI8]>:$y,
+    TFL_TensorOf<[F32, QI8, QI16]>:$x,
+    TFL_TensorOf<[F32, QI8, QI16]>:$y,
     DefaultValuedAttr<BoolAttr, "false">:$adj_x,
     DefaultValuedAttr<BoolAttr, "false">:$adj_y
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, QI8]>:$output
+    TFL_TensorOf<[F32, QI8, QI16]>:$output
   );
 
   let hasOptions = 1;
@@ -1026,7 +1026,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$params,
+    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$params,
     TFL_TensorOf<[I32, I64]>:$indices,
     I32Attr:$axis
   );
@@ -1038,7 +1038,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
   ];
 
   let results = (outs
-    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$output
+    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$output
   );
 
   let hasOptions = 1;
@@ -1750,12 +1750,12 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input,
+    ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$input,
     // Slope of the activation function at x < 0.
     F32Attr:$alpha
   );
 
-  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 0b1;
 }
@@ -1977,12 +1977,12 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
-    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
+    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
+    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$max
+    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$max
   );
 
   let builders = [TFL_BroadcastableBinaryBuilder];
@@ -2005,13 +2005,13 @@ def TFL_MeanOp : TFL_Op<"mean", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$input,
    TFL_TensorOf<[I32, I64]>:$axis,
    BoolAttr:$keep_dims
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2090,13 +2090,13 @@ equivalent to setting:
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32OrI64Tensor:$begin,
     TFL_I32OrI64Tensor:$size
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$output
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output
   );
 
   let verifier = [{ return Verify(*this); }];
@@ -2116,12 +2116,12 @@ def TFL_SumOp: TFL_Op<"sum", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
-  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2139,13 +2139,13 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2163,13 +2163,13 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2186,13 +2186,13 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2210,12 +2210,12 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
-    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
+    ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
+    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$min
+    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$min
   );
 
   let builders = [TFL_BroadcastableBinaryBuilder];
@@ -2364,10 +2364,10 @@ def TFL_PadOp : TFL_Op<"pad", [
   ```
   }];
 
-  let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+  let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32OrI64Tensor:$padding);
 
-  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
 }
@@ -2500,9 +2500,9 @@ def TFL_ReluOp: TFL_Op<"relu", [
     x -> max(0, x)
   }];
 
-  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
+  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$x);
 
-  let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
+  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$y);
 
   // This builder doesn't work with quantized type, so it can only be used by
   // non-quantization tablegen patterns. Currently, it is used by the
@@ -2828,11 +2828,11 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
+    ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     F32Attr:$beta
   );
 
-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$output);
 
   let hasOptions = 1;
@@ -3058,12 +3058,12 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$input,
+    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$input,
     TFL_TensorOf<[I32]>:$perm
   );
 
   let results = (outs
-    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$output
+    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$output
   );
 
   let verifier = [{ return Verify(*this); }];
@@ -3330,14 +3330,14 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$input,
+    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$input,
     TFL_I32Tensor:$size,
     BoolAttr:$align_corners,
     DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$output
+    TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$output
   );
 
   let hasOptions = 1;
@@ -3830,6 +3830,9 @@ Ba et al. 'Layer Normalization'
     // Types of the optional intermediate tensors, which exist for fully
     // quantized LSTM op and hold the ranges of the intermediate tensors.
+    // The type for intermediate tensors is quant.calibrated when imported
+    // to only store calibrated min, max values. The proper quantization spec is
+    // determined while going through quantization passes.
     OptionalAttr<TypeAttr>:$input_to_input_intermediate,
     OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
     OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
@@ -3945,6 +3948,9 @@ def TFL_UnidirectionalSequenceLSTMOp :
     // Types of the optional intermediate tensors, which exist for fully
     // quantized op and hold the ranges of the intermediate tensors.
+    // The type for intermediate tensors is quant.calibrated when imported
+    // to only store calibrated min, max values. The proper quantization spec is
+    // determined while going through quantization passes.
     OptionalAttr<TypeAttr>:$input_to_input_intermediate,
     OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
     OptionalAttr<TypeAttr>:$input_to_cell_intermediate,


@@ -89,6 +89,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
   pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
   pass_config.lower_tensor_list_ops = true;
+  // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
+  // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
+  if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
+    pass_config.unfold_batch_matmul = false;
+  }
 
   return internal::ConvertMLIRToTFLiteFlatBuffer(
       toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{},
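The `QUANTIZED_INT16` inference type checked above is, as far as I can tell, what the converter front end sets for the 16x8 quantization mode (16-bit activations with 8-bit weights). A hedged sketch of the user-facing Python path, assuming a SavedModel at the hypothetical path /tmp/my_saved_model whose inputs are 1x5 float32:

    import numpy as np
    import tensorflow as tf

    def representative_dataset():
        # Calibration data; shape must match the model's input signature.
        for _ in range(10):
            yield [np.random.rand(1, 5).astype(np.float32)]

    converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    # Request 16-bit activations with 8-bit weights ("16x8" quantization).
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
    ]
    tflite_model = converter.convert()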


@@ -174,6 +174,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
   pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
   pass_config.lower_tensor_list_ops = true;
+  // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
+  // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
+  if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
+    pass_config.unfold_batch_matmul = false;
+  }
 
   // TODO(b/153507667): Pass the session object when importing logic is removed.
   auto status = internal::ConvertMLIRToTFLiteFlatBuffer(


@@ -53,10 +53,11 @@ struct QuantizationSpecs {
   bool disable_per_channel = false;
 
   // When set to true, the fixed output ranges of the activation ops (tanh,
-  // sigmoid, etc.) are not enforced. Then, to quantize these ops, quantization
-  // emulation ops should be specified after the ops in the input graph. This
-  // flag should be set to false for post-training quantization.
-  bool disable_enforced_fixed_output_range = false;
+  // sigmoid, etc.) and the weight constants are not inferred. Then, to quantize
+  // these ops, quantization emulation ops should be placed after the ops in the
+  // input graph. This flag should be set to false for post-training
+  // quantization.
+  bool disable_infer_tensor_range = false;
 
   // The node type when the model is exported. Currently this is limited to
   // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the


@@ -100,13 +100,13 @@ class QuantizationDriver {
   explicit QuantizationDriver(FuncOp fn, bool is_signed,
                               bool disable_per_channel,
                               OpQuantSpecGetter op_quant_spec_getter,
-                              bool enforce_fixed_output_range)
+                              bool infer_tensor_range)
       : fn_(fn),
         builder_(fn.getBody()),
         is_signed_(is_signed),
         disable_per_channel_(disable_per_channel),
         op_quant_spec_getter_(op_quant_spec_getter),
-        enforce_fixed_output_range_(enforce_fixed_output_range) {}
+        infer_tensor_range_(infer_tensor_range) {}
 
   // The entry point of the quantization parameters propagation.
   void Run();
@@ -384,7 +384,9 @@ class QuantizationDriver {
   OpQuantSpecGetter op_quant_spec_getter_;
 
-  bool enforce_fixed_output_range_;
+  // Infer output ranges for activation ops and constants. This is usually
+  // required for post-training quantization.
+  bool infer_tensor_range_;
 };
 
 }  // namespace
@@ -670,33 +672,43 @@ void QuantizationDriver::PreprocessConstantOps() {
     Value value = cst.getResult();
     builder_.setInsertionPoint(cst);
-    for (auto indexed_use : llvm::enumerate(value.getUses())) {
-      auto &use = indexed_use.value();
-      auto spec = GetQuantSpec(use.getOwner());
-      auto biases = spec->biases_params;
-      Operation *user = use.getOwner();
-      int operand_num = use.getOperandNumber();
+
+    // The following loop will change the value uses, thus we cache all the uses
+    // that need to be changed.
+    llvm::SmallVector<std::pair<Operation *, int>, 4> uses;
+    for (auto &use : value.getUses()) {
+      uses.push_back({use.getOwner(), use.getOperandNumber()});
+    }
+    for (auto indexed_use : llvm::enumerate(uses)) {
+      Operation *user = indexed_use.value().first;
+      int operand_num = indexed_use.value().second;
+      auto spec = GetQuantSpec(user);
+      auto biases = spec->biases_params;
+      // The quantization parameters of a `weight` shouldn't be determined by
+      // other values. So any constants which are not bias, an operand of an
+      // op with same scale requirements, and haven't been quantized are
+      // weights.
       if (biases.find(operand_num) == biases.end() &&
           !llvm::dyn_cast<mlir::SameScalesOpInterface>(user) &&
          !llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
-        // Needs to scan the content to get the quantization parameters if there
-        // are no quantization parameters (FakeQuant ops).
-        // For this case, the weight isn't duplicated.
+        // Needs to scan the content of weights to get the quantization
+        // parameters if there are no quantization parameters (FakeQuant ops).
+        // For this case, the weight will not be duplicated.
         weights_.insert(cst);
         auto affine_user =
             llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
-        if (affine_user &&
-            affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
+        if (affine_user && affine_user.GetAffineOperandIndex() == operand_num &&
            affine_user.RequiredNarrowRangeAffineOperand()) {
          optimized_weights_.insert(
              {cst, affine_user.GetQuantizationDimIndex()});
        }
      } else {
-        // This is a bias, so the quantization parameter isn't determined by the
-        // local content. Same if the user can have quantization parameter
-        // propagated from other places.
-        // Duplicate this constant in case it is shared by different users.
+        // This is a bias or an operand of an op with same scale requirements,
+        // so the quantization parameters are propagated from or determined by
+        // other values. Duplicate this constant in case it is shared by
+        // different users.
        if (indexed_use.index() > 0) {
          cst = builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
        }
@@ -786,12 +798,14 @@ bool QuantizationDriver::PropagateParams() {
     quantized_.insert(op);
 
     if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
-      // If it isn't a weight or has been quantized, skip.
-      if (!IsWeight(cst) || IsQuantized(op)) continue;
-
-      // The quantization parameters are determined by the content of the
-      // constant.
-      changed |= SetConstantResultParams(op);
+      // If the workflow requires inferring ranges from the content
+      // (post-training quantization) and it is weight (filter) and hasn't
+      // been quantized, we infer the quantization parameters from the content.
+      if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) {
+        // The quantization parameters are determined by the content of the
+        // constant.
+        changed |= SetConstantResultParams(op);
+      }
       continue;
     }
@@ -826,7 +840,9 @@ bool QuantizationDriver::PropagateParams() {
     // TODO(fengliuai): make the bit width configurable.
     auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
-    if (restricted && enforce_fixed_output_range_) {
+    if (restricted && infer_tensor_range_) {
+      // Infer ranges from the activation ops. This is usually required for
+      // the post-training quantization workflow.
       // TODO(fengliuai): different result can have different fixed range.
       auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
       for (auto i = 0; i < op->getNumResults(); ++i) {
@@ -903,9 +919,9 @@ void QuantizationDriver::Run() {
 void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                         bool disable_per_channel,
                                         OpQuantSpecGetter op_quant_spec_getter,
-                                        bool post_training_quantization) {
+                                        bool infer_tensor_ranges) {
   QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
-                     post_training_quantization)
+                     infer_tensor_ranges)
       .Run();
 }


@@ -490,13 +490,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
 // and the propagation results are materialized by inserting pairs of quantize
 // and dequantize ops to this function. Set `disable_per_channel` to true to not
 // use per channel quantization even the op supports it.
-// Setting `enforce_fixed_output_range` to true, to infer quantization
-// parameters from the fixed output range ops. This is only used for
-// post-training quantization.
+// Setting `infer_tensor_range` to true, to infer quantization parameters from
+// the activation ops and weight constants. This is only used for post-training
+// quantization.
 void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                         bool disable_per_channel,
                                         OpQuantSpecGetter op_quant_spec_getter,
-                                        bool enforce_fixed_output_range);
+                                        bool infer_tensor_ranges);
 
 // The function might contain more stats ops than required, and it will
 // introduce requantize if the calibration stats have conflicts. This method


@@ -638,4 +638,10 @@ func @cast_ui8_to_i32() -> tensor<4xi32> {
   // CHECK: return %[[CST]]
 }
+
+// CHECK-LABEL: @cast_identity
+func @cast_identity(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
+  %0 = "tfl.cast"(%arg0) : (tensor<7xf32>) -> tensor<7xf32>
+  return %0 : tensor<7xf32>
+  // CHECK: return %arg0 : tensor<7xf32>
+}


@@ -0,0 +1,335 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s
// CHECK: effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01, 5.000000e-01>>>
// CHECK: input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00, 4.000000e+00>>>
// CHECK: input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01, 1.600000e+01>>>
// CHECK: input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01, 3.200000e+01>>>
// CHECK: input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00, 1.000000e+00>>>
{
"version": 3,
"operator_codes": [
{
"builtin_code": "UNIDIRECTIONAL_SEQUENCE_LSTM"
}
],
"subgraphs": [
{
"tensors": [
{
"shape": [1, 5],
"name": "input0"
},
{
"shape": [2, 5],
"buffer": 1,
"name": "input2input_weights1"
},
{
"shape": [2, 5],
"buffer": 2,
"name": "input2forget_weights2"
},
{
"shape": [2, 5],
"buffer": 3,
"name": "input2cell_weights3"
},
{
"shape": [2, 5],
"buffer": 4,
"name": "input2output_weights4"
},
{
"shape": [2, 4],
"buffer": 5,
"name": "rec2input_weights5"
},
{
"shape": [2, 4],
"buffer": 6,
"name": "rec2forget_weights6"
},
{
"shape": [2, 4],
"buffer": 7,
"name": "rec2cell_weights7"
},
{
"shape": [2, 4],
"buffer": 8,
"name": "rec2output_weights8"
},
{
"shape": [2],
"buffer": 9,
"name": "cell2input_weights9"
},
{
"shape": [2],
"buffer": 10,
"name": "cell2forget_weights10"
},
{
"shape": [2],
"buffer": 11,
"name": "cell2output_weights11"
},
{
"shape": [2],
"buffer": 12,
"name": "input_gate_bias12"
},
{
"shape": [2],
"buffer": 13,
"name": "forget_gate_bias13"
},
{
"shape": [2],
"buffer": 14,
"name": "cell_gate_bias14"
},
{
"shape": [2],
"buffer": 15,
"name": "output_gate_bias15"
},
{
"shape": [4, 2],
"buffer": 16,
"name": "proj_weights16"
},
{
"shape": [4],
"buffer": 17,
"name": "proj_bias17"
},
{
"shape": [1, 4],
"name": "input_activation_state18",
"is_variable": true,
"quantization": {
"min": [-0.9],
"max": [0.9]
}
},
{
"shape": [1, 2],
"name": "input_cell_state19",
"is_variable": true,
"quantization": {
"min": [-0.8],
"max": [0.8]
}
},
{
"shape": [2],
"buffer": 18,
"name": "input_norm20"
},
{
"shape": [2],
"buffer": 19,
"name": "forget_norm21"
},
{
"shape": [2],
"buffer": 20,
"name": "cell_norm22"
},
{
"shape": [2],
"buffer": 21,
"name": "output_norm23"
},
{
"shape": [],
"name": "output24"
},
{
"shape": [],
"name": "intermediate_0",
"is_variable": true,
"quantization": {
"min": [-32],
"max": [32]
}
},
{
"shape": [],
"name": "intermediate_1",
"is_variable": true,
"quantization": {
"min": [-16],
"max": [16]
}
},
{
"shape": [],
"name": "intermediate_2",
"is_variable": true,
"quantization": {
"min": [-4],
"max": [4]
}
},
{
"shape": [],
"name": "intermediate_3",
"is_variable": true,
"quantization": {
"min": [-1.0],
"max": [1.0]
}
},
{
"shape": [],
"name": "intermediate_4",
"is_variable": true,
"quantization": {
"min": [-0.5],
"max": [0.5]
}
}
],
"inputs": [0],
"outputs": [24],
"operators": [
{
"inputs": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23
],
"outputs": [24],
"intermediates": [
25, 26, 27, 28, 29
],
"builtin_options_type": "UnidirectionalSequenceLSTMOptions",
"builtin_options": {
"fused_activation_function": "TANH",
"cell_clip": 50.0
},
"mutating_variable_inputs": [
false,
false, false, false, false,
false, false, false, false,
false, false, false,
false, false, false, false,
true, true,
false, false, false, false
]
}
]
}
],
"buffers": [
{
"data": []
},
{
"data": [
36, 167, 168, 63, 0, 140, 72, 191, 120, 20, 147, 62, 20, 152, 196, 190, 121, 98, 82, 187, 95, 128, 213, 61, 189, 3, 138, 63, 54, 103, 13, 62, 46, 224, 66, 63, 157, 204, 180, 191
]
},
{
"data": [
223, 20, 21, 64, 246, 166, 31, 191, 6, 51, 157, 188, 114, 90, 167, 62, 118, 240, 59, 63, 49, 162, 255, 62, 17, 91, 160, 63, 32, 47, 26, 63, 40, 136, 178, 191, 243, 154, 236, 61
]
},
{
"data": [
137, 231, 86, 63, 41, 154, 16, 63, 239, 37, 77, 191, 55, 189, 24, 189, 86, 63, 18, 63, 42, 55, 13, 191, 110, 139, 138, 191, 219, 148, 181, 63, 71, 232, 108, 191, 66, 226, 145, 191
]
},
{
"data": [
245, 179, 225, 190, 51, 202, 176, 189, 132, 47, 53, 191, 155, 25, 50, 191, 197, 130, 240, 191, 98, 125, 45, 62, 243, 70, 83, 62, 85, 155, 139, 63, 113, 239, 11, 192, 35, 251, 139, 62
]
},
{
"data": [
248, 188, 211, 191, 142, 11, 73, 62, 36, 8, 84, 63, 186, 208, 11, 191, 76, 208, 190, 191, 223, 200, 210, 63, 183, 170, 103, 63, 116, 129, 145, 63
]
},
{
"data": [
235, 202, 222, 190, 159, 201, 112, 191, 217, 248, 166, 63, 165, 199, 131, 191, 130, 59, 47, 63, 179, 11, 186, 62, 55, 168, 18, 192, 152, 213, 26, 64
]
},
{
"data": [
245, 123, 138, 62, 213, 106, 231, 59, 211, 218, 250, 62, 25, 157, 134, 63, 147, 22, 164, 63, 25, 221, 139, 62, 1, 230, 247, 62, 210, 185, 142, 63
]
},
{
"data": [
197, 123, 23, 192, 45, 96, 178, 190, 174, 87, 165, 62, 213, 225, 200, 191, 119, 248, 15, 191, 128, 125, 171, 189, 90, 125, 222, 63, 4, 76, 95, 62
]
},
{
"data": [
210, 73, 183, 63, 248, 177, 13, 191
]
},
{
"data": [
78, 251, 212, 191, 169, 29, 147, 63
]
},
{
"data": [
178, 227, 203, 191, 247, 155, 103, 63
]
},
{
"data": [
206, 111, 165, 190, 153, 77, 227, 63
]
},
{
"data": [
255, 114, 132, 191, 253, 202, 140, 191
]
},
{
"data": [
90, 247, 1, 192, 125, 120, 209, 191
]
},
{
"data": [
65, 75, 243, 191, 58, 122, 146, 190
]
},
{
"data": [
40, 135, 20, 63, 109, 50, 220, 191, 56, 241, 189, 63, 65, 12, 92, 63, 61, 14, 162, 62, 157, 138, 81, 63, 125, 61, 191, 61, 102, 231, 20, 63
]
},
{
"data": [
145, 79, 49, 189, 175, 235, 220, 190, 182, 111, 157, 190, 144, 236, 97, 191
]
},
{
"data": [
76, 188, 109, 63, 228, 150, 201, 190
]
},
{
"data": [
6, 146, 66, 191, 122, 127, 100, 191
]
},
{
"data": [
216, 59, 169, 190, 161, 178, 215, 191
]
},
{
"data": [
208, 144, 101, 191, 127, 233, 195, 190
]
}
]
}


@@ -20,20 +20,20 @@ func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32
func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> { func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
%cst = constant unit %cst = constant unit
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, 
tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
// CHECK-LABEL: testFullyQuantizedLSTM // CHECK-LABEL: testFullyQuantizedLSTM
// CHECK: %cst = constant unit // CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) // CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18)
// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> // CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, 
tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
} }
// ----- // -----
// CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates // CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
@ -318,6 +318,15 @@ func @any(%arg0: tensor<2x2xi1>, %arg1: tensor<i32>) -> tensor<i1> {
// CHECK: "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1> // CHECK: "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
} }
func @any_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tensor<?xi1> {
%0 = "tf.Any"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi64>) -> tensor<?xi1>
return %0 : tensor<?xi1>
// CHECK-LABEL: any_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.reduce_any"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor<?xi1>
}
func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> %0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32> return %0 : tensor<8x16xf32>
@ -972,6 +981,15 @@ func @sum_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32
// CHECK: "tfl.sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> // CHECK: "tfl.sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
} }
func @sum_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
%0 = "tf.Sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: sum_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.sum"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> { func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> %0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
@ -988,6 +1006,15 @@ func @reduce_min_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
// CHECK: "tfl.reduce_min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> // CHECK: "tfl.reduce_min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
} }
func @reduce_min_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
%0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: reduce_min_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.reduce_min"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> { func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> %0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
@ -1004,6 +1031,15 @@ func @reduce_max_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
// CHECK: "tfl.reduce_max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> // CHECK: "tfl.reduce_max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
} }
func @reduce_max_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
%0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: reduce_max_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.reduce_max"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @reduce_prod(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> { func @reduce_prod(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> %0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
@ -1020,6 +1056,15 @@ func @reduce_prod_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tens
// CHECK: "tfl.reduce_prod"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32> // CHECK: "tfl.reduce_prod"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
} }
func @reduce_prod_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
%0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: reduce_prod_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.reduce_prod"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
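The `*_i64axes` cases above all exercise the same legalization detail: the TF reduction ops accept 64-bit axis tensors, while the TFLite reduce kernels work with 32-bit axes, so the converter inserts a `tfl.cast` in front of the reduction. A minimal NumPy sketch of the intended numerics (shapes mirror the tests; this is illustrative, not converter code):

```python
import numpy as np

x = np.ones((8, 16, 16), dtype=np.float32)
axes_i64 = np.array([1, 2], dtype=np.int64)   # axes as tf.Sum/tf.Min/... receive them

axes_i32 = axes_i64.astype(np.int32)          # what the inserted tfl.cast produces
reduced = x.sum(axis=tuple(axes_i32))         # what tfl.sum then computes
assert reduced.shape == (8,)                  # matches the tensor<?xf32> result
```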
func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> { func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32> %0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
@ -1172,10 +1217,10 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
return %0 : tensor<10x10xf32> return %0 : tensor<10x10xf32>
// CHECK-LABEL: strided_slice_with_constant_attributes // CHECK-LABEL: strided_slice_with_constant_attributes
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32>
// CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32> // CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32>
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32> // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32>
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32> // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
} }
func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> { func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
@ -1185,6 +1230,39 @@ func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tenso
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
} }
func @strided_slice_with_unranked_input_and_i64_parameters(%arg0: tensor<*xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<*xf32> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: strided_slice_with_unranked_input_and_i64_parameters
// CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<*xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xf32>
}
func @strided_slice_with_i64_parameters(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<1x2x2x5xf32> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1x2x2x5xf32>
return %0 : tensor<1x2x2x5xf32>
// CHECK-LABEL: strided_slice_with_i64_parameters
// CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
}
func @strided_slice_with_i64_constant_attributes(%arg0: tensor<10x10x10xf32>) -> tensor<10x10xf32> {
%cst = constant dense<-1> : tensor<1xi64>
%cst_1 = constant dense<0> : tensor<1xi64>
%cst_2 = constant dense<1> : tensor<1xi64>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<10x10xf32>
return %0 : tensor<10x10xf32>
// CHECK-LABEL: strided_slice_with_i64_constant_attributes
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32>
// CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32>
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32>
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
}
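The i64 strided-slice variants above depend on the same narrowing: begin/end/strides (and the i64 mask attributes) are converted to 32-bit before reaching `tfl.strided_slice`. A hedged note on why that is normally lossless, with illustrative values only:

```python
import numpy as np

begin = np.array([-1], dtype=np.int64)
begin_i32 = begin.astype(np.int32)        # what the inserted tfl.cast computes
assert np.array_equal(begin_i32, begin)   # exact while the indices fit in int32
# Indices outside the int32 range would wrap, so the legalization implicitly
# assumes slice parameters stay within 32-bit bounds.
```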
func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> { func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
%0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32> %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32> return %0 : tensor<?x3x5xf32>
@ -625,6 +625,35 @@ func @QuantizeSharedBiases2(
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]]) // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
} }
// Make sure constants are duplicated for all users (see the sketch after this function).
// CHECK-LABEL: QuantizeSharedConstantsMultipleUsers
func @QuantizeSharedConstantsMultipleUsers(
%arg0: tensor<32x!quant.uniform<u8:f32, 1.0>>,
%arg1: tensor<32x!quant.uniform<u8:f32, 2.0>>,
%arg2: tensor<32x!quant.uniform<u8:f32, 3.0>>,
%arg3: tensor<32x!quant.uniform<u8:f32, 4.0>>) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) {
%cst = constant dense<0.0> : tensor<32xf32>
%0 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform<u8:f32, 1.0>>) -> tensor<32xf32>
%1 = "tfl.dequantize"(%arg1) : (tensor<32x!quant.uniform<u8:f32, 2.0>>) -> tensor<32xf32>
%2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<u8:f32, 3.0>>) -> tensor<32xf32>
%3 = "tfl.dequantize"(%arg3) : (tensor<32x!quant.uniform<u8:f32, 4.0>>) -> tensor<32xf32>
%4 = "tfl.minimum"(%0, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%5 = "tfl.minimum"(%1, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%6 = "tfl.minimum"(%2, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%7 = "tfl.minimum"(%3, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
return %4, %5, %6, %7 : tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>
// CHECK: %[[cst1:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<32xf32>
// CHECK: %[[cst2:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 3.000000e+00>>) -> tensor<32xf32>
// CHECK: %[[cst3:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 4.000000e+00>>) -> tensor<32xf32>
// CHECK: %[[cst4:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<32xf32>
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst4]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst1]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst2]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst3]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
}
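A rough picture of why the shared zero constant in this test must be duplicated: each user carries its own quantization parameters, and a single buffer cannot hold two scales. A small NumPy sketch of uint8 affine quantization under two of the scales used above (illustrative arithmetic, not the pass implementation):

```python
import numpy as np

cst = np.zeros(32, dtype=np.float32)          # the shared float constant

def quantize_u8(x, scale, zero_point=0):
    # uint8 affine quantization: q = clamp(round(x / scale) + zero_point, 0, 255)
    return np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)

q_scale_1 = quantize_u8(cst, scale=1.0)       # copy for the user quantized at scale 1.0
q_scale_2 = quantize_u8(cst, scale=2.0)       # copy for the user quantized at scale 2.0
# The payload happens to match here (all zeros), but the attached scale does not,
# so each user still needs its own quantized copy of the constant.
```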
// Make sure quantization parameters are scanned from weight, but not from bias. // Make sure quantization parameters are scanned from weight, but not from bias.
// CHECK-LABEL: QuantizeWeight // CHECK-LABEL: QuantizeWeight
func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
@ -551,6 +551,19 @@ func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf
return %0 : tensor<8x4x16x1xf32> return %0 : tensor<8x4x16x1xf32>
} }
func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<10x10xf32> {
%cst = constant dense<-1> : tensor<1xi32>
%cst_1 = constant dense<0> : tensor<1xi32>
%cst_2 = constant dense<1> : tensor<1xi32>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
return %0 : tensor<10x10xf32>
// CHECK-LABEL: strided_slice_with_constant_attributes
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32>
// CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32>
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32>
// CHECK-NEXT: "tf.StridedSlice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
}
func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32> return %0: tensor<3x3xf32>
@ -328,23 +328,28 @@ def LegalizeMean : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2),
(TFL_MeanOp $arg0, $arg1, $arg2)>; (TFL_MeanOp $arg0, $arg1, $arg2)>;
def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2),
(TFL_SumOp $arg, $axes, $arg2)>; (TFL_SumOp $arg, (CreateTFCastToInt32Op $axes), $arg2)>;
// TopK in TFL is always sorted so we ignore that attribute here. // TopK in TFL is always sorted so we ignore that attribute here.
def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted),
(TFL_TopKV2Op $input, $k)>; (TFL_TopKV2Op $input, $k)>;
def LegalizeMin : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), def LegalizeMin : Pat<
(TFL_ReduceMinOp $arg0, $arg1, $arg2)>; (TF_MinOp $arg0, $axes, BoolAttr:$arg2),
(TFL_ReduceMinOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
def LegalizeMax : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), def LegalizeMax : Pat<
(TFL_ReduceMaxOp $arg0, $arg1, $arg2)>; (TF_MaxOp $arg0, $axes, BoolAttr:$arg2),
(TFL_ReduceMaxOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
def LegalizeProd : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), def LegalizeProd : Pat<
(TFL_ReduceProdOp $arg0, $arg1, $arg2)>; (TF_ProdOp $arg0, $axes, BoolAttr:$arg2),
(TFL_ReduceProdOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
def LegalizeAny : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims), def LegalizeAny : Pat<
(TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>; (TF_AnyOp $input, $reduction_indices, $keep_dims),
(TFL_ReduceAnyOp $input, (CreateTFCastToInt32Op $reduction_indices),
$keep_dims)>;
def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>; def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
@ -471,3 +476,14 @@ def LegalizeCumsum : Pat<
def LegalizeReshape : Pat< def LegalizeReshape : Pat<
(TF_ReshapeOp $input, $shape), (TF_ReshapeOp $input, $shape),
(TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>; (TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>;
def LegalizeStridedSlice : Pat<
(TF_StridedSliceOp
$input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask,
$new_axis_mask, $shrink_axis_mask),
(TFL_StridedSliceOp $input,
(CreateTFCastToInt32Op $begin), (CreateTFCastToInt32Op $end),
(CreateTFCastToInt32Op $strides), (convertIntAttrTo32Bit $begin_mask),
(convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
(convertIntAttrTo32Bit $new_axis_mask),
(convertIntAttrTo32Bit $shrink_axis_mask))>;
@ -148,7 +148,6 @@ DECL_CONVERT_OP(MatrixDiagV3);
DECL_CONVERT_OP(Pack); DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Split); DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV); DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Unpack); DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(RandomUniform); DECL_CONVERT_OP(RandomUniform);
@ -325,81 +324,6 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite(
return success(); return success();
} }
Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
Value attribute,
ArrayRef<int32_t> padding_val, int* mask) {
DenseIntElementsAttr dense_elem_attr;
SmallVector<int32_t, 8> padded_val;
auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
if (!ranked_attr_type ||
!matchPattern(attribute, m_Constant(&dense_elem_attr))) {
// If the input attribute is neither ranked type nor constant, we
// can't do any padding. Instead we just return it.
return attribute;
}
for (const auto& idx : dense_elem_attr.getIntValues()) {
padded_val.push_back(idx.getSExtValue());
}
auto attr_dim_count = ranked_attr_type.getShape()[0];
int full_dim_count = padding_val.size();
for (int i = attr_dim_count; i < full_dim_count; ++i) {
padded_val.push_back(padding_val[i]);
if (mask) *mask |= 1 << i;
}
auto type =
RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
auto attr = DenseElementsAttr::get<int32_t>(type, padded_val);
return rewriter.create<ConstantOp>(op->getLoc(), type, attr);
}
LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
auto ranked_input_type =
tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
if (!ranked_input_type) {
// If input is not a ranked tensor, we can't deduce the padding dimensions
// from it, so we just do a plain conversion here.
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
tf_strided_slice_op.strides(),
rewriter.getI32IntegerAttr(tf_strided_slice_op.begin_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.end_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
return success();
}
int num_input_dims = ranked_input_type.getRank();
// Pad `begin` array with zero values and update the `begin_mask`.
SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0);
int begin_mask = tf_strided_slice_op.begin_mask();
Value padded_begin = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask);
// Pad `end` array with `input_shape` and update the `end_mask`.
int end_mask = tf_strided_slice_op.end_mask();
auto input_shape = ranked_input_type.getShape();
SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end());
Value padded_end = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask);
// Pad `strides` array with ones.
SmallVector<int32_t, 8> strides_pad_val(num_input_dims, 1);
Value padded_strides = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
padded_begin, padded_end, padded_strides,
rewriter.getI32IntegerAttr(begin_mask),
rewriter.getI32IntegerAttr(end_mask),
rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
return success();
}
LogicalResult ConvertTFUnpackOp::matchAndRewrite( LogicalResult ConvertTFUnpackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const { Operation* op, PatternRewriter& rewriter) const {
auto tf_unpack_op = cast<TF::UnpackOp>(op); auto tf_unpack_op = cast<TF::UnpackOp>(op);
@ -769,8 +693,8 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
patterns patterns
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op, .insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp, ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
ConvertTFAssertOp, ConvertTFRandomUniformOp>(context); ConvertTFRandomUniformOp>(context);
// Ophint python converter converted tf node pattern. // Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm, patterns.insert<LegalizeUnidirectionalSequenceLstm,
@ -376,14 +376,17 @@ void PrepareQuantizePass::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType(); bool is_signed = quant_specs_.IsSignedInferenceType();
int bit_width = quant_specs_.GetQuantizationTypeWidth(); int bit_width = quant_specs_.GetQuantizationTypeWidth();
bool quantization_aware_training_mode = ContainsQuantizeOps(func); // When this is true, the quantizer will try its best to extract the
// Enforce fixed output range for post-training quantization and // quantization parameters from the op quantization property and constant
// when the model has quantization emulation ops, unless it was disabled // content. This is also set to true when the `quantize_allowlist` and
// explicitly by the flag. // `quantize_signed` test flags are enabled.
bool enforced_output_range = bool eager_quantize = ContainsQuantizeOps(func) ||
(quant_specs_.post_training_quantization || (!quantize_allowlist.empty() || quantize_signed);
quantization_aware_training_mode) && // Infer the tensor range for the activation ops and weight constants unless
!quant_specs_.disable_enforced_fixed_output_range; // it is disabled explicitly.
bool infer_tensor_range =
(quant_specs_.post_training_quantization || eager_quantize) &&
!quant_specs_.disable_infer_tensor_range;
if (is_signed) { if (is_signed) {
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx); patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters. // Convert quant stats to int8 quantization parameters.
@ -403,7 +406,7 @@ void PrepareQuantizePass::runOnFunction() {
// values (tensors). // values (tensors).
ApplyQuantizationParamsPropagation( ApplyQuantizationParamsPropagation(
func, is_signed, disable_per_channel || quant_specs_.disable_per_channel, func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
GetOpQuantSpec, enforced_output_range); GetOpQuantSpec, infer_tensor_range);
ConvertMlirQuantOpsToTFLQuantOps(func); ConvertMlirQuantOpsToTFLQuantOps(func);
} }
@ -712,6 +712,23 @@ struct ConvertTFStridedSlice : public RewritePattern {
return success(); return success();
} }
void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
SmallVectorImpl<int32_t> &val,
SmallVectorImpl<int32_t> &padded_val,
ArrayRef<int32_t> padding_val,
int *mask) const {
for (const auto &idx : dense_elem_attr.getIntValues()) {
val.push_back(idx.getSExtValue());
padded_val.push_back(idx.getSExtValue());
}
int attr_dim_count = val.size();
int full_dim_count = padding_val.size();
for (int i = attr_dim_count; i < full_dim_count; ++i) {
padded_val.push_back(padding_val[i]);
if (mask) *mask |= 1 << i;
}
}
LogicalResult matchAndRewrite(Operation *op, LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op); TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
@ -719,17 +736,102 @@ struct ConvertTFStridedSlice : public RewritePattern {
// Handle new axis mask. // Handle new axis mask.
if (strided_slice_op.new_axis_mask() != 0) { if (strided_slice_op.new_axis_mask() != 0) {
// We currently don't handle simultaneous shrink_ and new_axis masks. // We currently don't handle simultaneous shrink_ and new_axis masks.
if (strided_slice_op.shrink_axis_mask()) { if (!strided_slice_op.shrink_axis_mask()) {
return failure(); return RewriteNewAxisMask(strided_slice_op, rewriter);
} }
return RewriteNewAxisMask(strided_slice_op, rewriter);
} }
// Handle ellipsis mask. // Handle ellipsis mask.
if (strided_slice_op.ellipsis_mask() != 0) { if (strided_slice_op.ellipsis_mask() != 0) {
return RewriteEllipsisMask(strided_slice_op, rewriter); return RewriteEllipsisMask(strided_slice_op, rewriter);
} }
return failure();
auto ranked_input_type =
strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
if (!ranked_input_type) {
return failure();
}
auto begin_attr = strided_slice_op.begin();
auto end_attr = strided_slice_op.end();
auto strides_attr = strided_slice_op.strides();
auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
auto strides_attr_type =
strides_attr.getType().dyn_cast<RankedTensorType>();
DenseIntElementsAttr begin_elem_attr;
DenseIntElementsAttr end_elem_attr;
DenseIntElementsAttr strides_elem_attr;
if (!begin_attr_type ||
!matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
return failure();
}
if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
return failure();
}
if (!strides_attr_type ||
!matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
return failure();
}
SmallVector<int32_t, 4> begin, end, strides;
SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;
int num_input_dims = ranked_input_type.getRank();
SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
auto input_shape = ranked_input_type.getShape();
SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);
int begin_mask = strided_slice_op.begin_mask();
int end_mask = strided_slice_op.end_mask();
PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
padding_begin, &begin_mask);
PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
&end_mask);
PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
padding_strides, nullptr);
if (begin == padded_begin && end == padded_end &&
strides == padded_strides &&
begin_mask == strided_slice_op.begin_mask() &&
end_mask == strided_slice_op.end_mask()) {
return failure();
}
auto begin_end_type =
RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
auto new_begin_attr = rewriter.create<ConstantOp>(
op->getLoc(), begin_end_type,
DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
auto new_end_attr = rewriter.create<ConstantOp>(
op->getLoc(), begin_end_type,
DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
auto strides_type =
RankedTensorType::get({static_cast<long>(padded_strides.size())},
rewriter.getIntegerType(32));
auto new_strides_attr = rewriter.create<ConstantOp>(
op->getLoc(), strides_type,
DenseElementsAttr::get<int32_t>(strides_type, padded_strides));
auto attribute_type = rewriter.getIntegerType(64);
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
op, strided_slice_op.output().getType(), strided_slice_op.input(),
new_begin_attr, new_end_attr, new_strides_attr,
rewriter.getIntegerAttr(attribute_type, begin_mask),
rewriter.getIntegerAttr(attribute_type, end_mask),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.ellipsis_mask()),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.new_axis_mask()),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.shrink_axis_mask()));
return success();
} }
}; };
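As a cross-check of the padding logic added above (and of the `begin_mask = 6` / `end_mask = 6` expectations in the `strided_slice_with_constant_attributes` test earlier in this diff), here is a short Python restatement of `PadStridedSliceAttributeArray`; it is a sketch of the intent, not the production code:

```python
def pad_strided_slice_attrs(vals, padding_vals, mask):
    # Pad `vals` up to the input rank using `padding_vals`; every padded
    # dimension also sets the corresponding bit in `mask` (when one is given).
    padded = list(vals)
    for i in range(len(vals), len(padding_vals)):
        padded.append(padding_vals[i])
        if mask is not None:
            mask |= 1 << i
    return padded, mask

# Rank-3 input of shape 10x10x10 with rank-1 begin/end/strides, as in the test.
begin, begin_mask = pad_strided_slice_attrs([-1], [0, 0, 0], 0)
end, end_mask = pad_strided_slice_attrs([0], [10, 10, 10], 0)
strides, _ = pad_strided_slice_attrs([1], [1, 1, 1], None)

assert (begin, end, strides) == ([-1, 0, 0], [0, 10, 10], [1, 1, 1])
assert begin_mask == end_mask == 6            # bits 1 and 2: (1 << 1) | (1 << 2)
```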
@ -337,6 +337,7 @@ cc_library(
":tensorflow", ":tensorflow",
"//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect", "@llvm-project//mlir:Dialect",
@ -859,6 +860,7 @@ cc_library(
"transforms/mark_ops_for_outside_compilation.cc", "transforms/mark_ops_for_outside_compilation.cc",
"transforms/materialize_mlir_passthrough_op.cc", "transforms/materialize_mlir_passthrough_op.cc",
"transforms/optimize.cc", "transforms/optimize.cc",
"transforms/outside_compiled_to_host_launch.cc",
"transforms/parallel_execute_to_islands.cc", "transforms/parallel_execute_to_islands.cc",
"transforms/parallelize_embedding_params_ops_pass.cc", "transforms/parallelize_embedding_params_ops_pass.cc",
"transforms/promote_resources_to_args.cc", "transforms/promote_resources_to_args.cc",
@ -31,7 +31,7 @@ limitations under the License.
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td"
def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> { def TF_AbsOp : TF_Op<"Abs", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the absolute value of a tensor."; let summary = "Computes the absolute value of a tensor.";
let description = [{ let description = [{
@ -1002,13 +1002,13 @@ reverse of SpaceToBatch. See below for a precise description.
TF_Tensor:$output TF_Tensor:$output
); );
let verifier = [{
return Verify(*this);
}];
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>; TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>;
TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
} }
def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> { def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> {
@ -1486,7 +1486,7 @@ def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
let hasFolder = 1; let hasFolder = 1;
} }
def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> { def TF_CeilOp : TF_Op<"Ceil", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns element-wise smallest integer not less than x."; let summary = "Returns element-wise smallest integer not less than x.";
let arguments = (ins let arguments = (ins
@ -3502,8 +3502,8 @@ tf.math.equal(x, y) ==> array([True, True])
}]; }];
let arguments = (ins let arguments = (ins
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, TF_Tensor:$x,
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, TF_Tensor:$y,
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
); );
@ -3843,8 +3843,8 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation."; let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation.";
let arguments = (ins let arguments = (ins
F32Tensor:$gradients, TF_Float32Tensor:$gradients,
F32Tensor:$inputs, TF_Float32Tensor:$inputs,
DefaultValuedAttr<F32Attr, "-6.0f">:$min, DefaultValuedAttr<F32Attr, "-6.0f">:$min,
DefaultValuedAttr<F32Attr, "6.0f">:$max, DefaultValuedAttr<F32Attr, "6.0f">:$max,
@ -3853,7 +3853,7 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
); );
let results = (outs let results = (outs
F32Tensor:$backprops TF_Float32Tensor:$backprops
); );
} }
@ -3911,19 +3911,19 @@ def TF_FakeQuantWithMinMaxVarsGradientOp : TF_Op<"FakeQuantWithMinMaxVarsGradien
let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation."; let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation.";
let arguments = (ins let arguments = (ins
F32Tensor:$gradients, TF_Float32Tensor:$gradients,
F32Tensor:$inputs, TF_Float32Tensor:$inputs,
F32Tensor:$min, TF_Float32Tensor:$min,
F32Tensor:$max, TF_Float32Tensor:$max,
DefaultValuedAttr<I64Attr, "8">:$num_bits, DefaultValuedAttr<I64Attr, "8">:$num_bits,
DefaultValuedAttr<BoolAttr, "false">:$narrow_range DefaultValuedAttr<BoolAttr, "false">:$narrow_range
); );
let results = (outs let results = (outs
F32Tensor:$backprops_wrt_input, TF_Float32Tensor:$backprops_wrt_input,
F32Tensor:$backprop_wrt_min, TF_Float32Tensor:$backprop_wrt_min,
F32Tensor:$backprop_wrt_max TF_Float32Tensor:$backprop_wrt_max
); );
} }
@ -4026,7 +4026,7 @@ fill([2, 3], 9) ==> [[9, 9, 9]
]; ];
} }
def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> { def TF_FloorOp : TF_Op<"Floor", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns element-wise largest integer not greater than x."; let summary = "Returns element-wise largest integer not greater than x.";
let arguments = (ins let arguments = (ins
@ -4977,13 +4977,13 @@ $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
}]; }];
let arguments = (ins let arguments = (ins
F32Tensor:$predictions, TF_Float32Tensor:$predictions,
TF_I32OrI64Tensor:$targets, TF_I32OrI64Tensor:$targets,
TF_I32OrI64Tensor:$k TF_I32OrI64Tensor:$k
); );
let results = (outs let results = (outs
I1Tensor:$precision TF_BoolTensor:$precision
); );
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
@ -6912,7 +6912,7 @@ retained with length 1.
]; ];
} }
def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
let summary = "Performs max pooling on the input."; let summary = "Performs max pooling on the input.";
let arguments = (ins let arguments = (ins
@ -6936,6 +6936,9 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInter
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; } SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; } SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation); LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
// TF_LayoutSensitiveInterface:
StringRef GetOptimalLayout(const RuntimeDevices& devices);
LogicalResult UpdateDataFormat(StringRef data_format);
}]; }];
} }
@ -7852,8 +7855,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> {
}]; }];
let arguments = (ins let arguments = (ins
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, TF_Tensor:$x,
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, TF_Tensor:$y,
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
); );
@ -8031,7 +8034,7 @@ times by rerunning "MakeIterator".
); );
} }
def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> { def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns a tensor of ones with the same shape and type as x."; let summary = "Returns a tensor of ones with the same shape and type as x.";
let arguments = (ins let arguments = (ins
@ -8429,6 +8432,10 @@ def TF_QrOp : TF_Op<"Qr", [NoSideEffect]> {
Computes the QR decomposition of each inner matrix in `tensor` such that Computes the QR decomposition of each inner matrix in `tensor` such that
`tensor[..., :, :] = q[..., :, :] * r[..., :,:])` `tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
Currently, the gradient for the QR decomposition is well-defined only when
the first `P` columns of the inner matrix are linearly independent, where
`P` is the minimum of `M` and `N`, the 2 inner-most dimensions of `tensor`.
```python ```python
# a is a tensor. # a is a tensor.
# q is a tensor of orthonormal matrices. # q is a tensor of orthonormal matrices.
@ -9114,7 +9121,7 @@ most one RecvTPUEmbeddingActivations op in the TPU graph.
TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>; TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>;
} }
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> { def TF_ReluOp : TF_Op<"Relu", [Idempotent, NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
let summary = "Computes rectified linear: `max(features, 0)`."; let summary = "Computes rectified linear: `max(features, 0)`.";
let description = [{ let description = [{
@ -9140,7 +9147,7 @@ array([ 0., 0., -0., 3.], dtype=float32)
}]; }];
} }
def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> { def TF_Relu6Op : TF_Op<"Relu6", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`."; let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`.";
let arguments = (ins let arguments = (ins
@ -10538,7 +10545,7 @@ bitwise_ops.right_shift(lhs, rhs)
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> { def TF_RintOp : TF_Op<"Rint", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns element-wise integer closest to x."; let summary = "Returns element-wise integer closest to x.";
let description = [{ let description = [{
@ -10605,7 +10612,7 @@ roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>; TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>;
} }
def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> { def TF_RoundOp : TF_Op<"Round", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = [{ let summary = [{
Rounds the values of a tensor to the nearest integer, element-wise. Rounds the values of a tensor to the nearest integer, element-wise.
}]; }];
@ -11335,7 +11342,7 @@ Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> { def TF_SignOp : TF_Op<"Sign", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns an element-wise indication of the sign of a number."; let summary = "Returns an element-wise indication of the sign of a number.";
let description = [{ let description = [{
@ -12345,14 +12352,14 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
} }
def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> { def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> {
let summary = [{ let summary = "Picks the best counter-based RNG algorithm based on device.";
Picks the best counter-based RNG algorithm based on device.
}];
let description = [{ let description = [{
This op picks the best counter-based RNG algorithm based on device. This op picks the best counter-based RNG algorithm based on device.
}]; }];
let arguments = (ins);
let results = (outs let results = (outs
TF_Int32Tensor:$alg TF_Int32Tensor:$alg
); );
@ -14088,73 +14095,35 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are
scattered onto an existing tensor (as opposed to a zero-tensor). If the memory scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
for the existing tensor cannot be re-used, a copy is made and updated. for the existing tensor cannot be re-used, a copy is made and updated.
If `indices` contains duplicates, then their updates are accumulated (summed). If `indices` contains duplicates, then we pick the last update for the index.
**WARNING**: The order in which updates are applied is nondeterministic, so the If an out of bound index is found on CPU, an error is returned.
output will be nondeterministic if `indices` contains duplicates -- because
of some numerical approximation issues, numbers summed in different order **WARNING**: There are some GPU specific semantics for this operation.
may yield different results. - If an out of bound index is found, the index is ignored.
- The order in which updates are applied is nondeterministic, so the output
will be nondeterministic if `indices` contains duplicates.
`indices` is an integer tensor containing indices into a new tensor of shape `indices` is an integer tensor containing indices into a new tensor of shape
`shape`. The last dimension of `indices` can be at most the rank of `shape`: `shape`.
indices.shape[-1] <= shape.rank * `indices` must have at least 2 axes: `(num_updates, index_depth)`.
* The last axis of `indices` is how deep to index into `tensor` so this index
depth must be less than the rank of `tensor`: `indices.shape[-1] <= tensor.ndim`
The last dimension of `indices` corresponds to indices into elements if `indices.shape[-1] = tensor.rank` this Op indexes and updates scalar elements.
(if `indices.shape[-1] = shape.rank`) or slices if `indices.shape[-1] < tensor.rank` it indexes and updates slices of the input
(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of `tensor`.
`shape`. `updates` is a tensor with shape
indices.shape[:-1] + shape[indices.shape[-1]:] Each `update` has a rank of `tensor.rank - indices.shape[-1]`.
The overall shape of `updates` is:
The simplest form of scatter is to insert individual elements in a tensor by ```
index. For example, say we want to insert 4 scattered elements in a rank-1 indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
tensor with 8 elements. ```
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> For usage examples see the python [tf.tensor_scatter_nd_update](
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd1.png" alt> https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
</div>
In Python, this scatter operation would look like this:
>>> indices = tf.constant([[4], [3], [1], [7]])
>>> updates = tf.constant([9, 10, 11, 12])
>>> tensor = tf.ones([8], dtype=tf.int32)
>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32)
We can also, insert entire slices of a higher rank tensor all at once. For
example, if we wanted to insert two slices in the first dimension of a
rank-3 tensor with two matrices of new values.
In Python, this scatter operation would look like this:
>>> indices = tf.constant([[0], [2]])
>>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
... [7, 7, 7, 7], [8, 8, 8, 8]],
... [[5, 5, 5, 5], [6, 6, 6, 6],
... [7, 7, 7, 7], [8, 8, 8, 8]]])
>>> tensor = tf.ones([4, 4, 4], dtype=tf.int32)
>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy())
[[[5 5 5 5]
[6 6 6 6]
[7 7 7 7]
[8 8 8 8]]
[[1 1 1 1]
[1 1 1 1]
[1 1 1 1]
[1 1 1 1]]
[[5 5 5 5]
[6 6 6 6]
[7 7 7 7]
[8 8 8 8]]
[[1 1 1 1]
[1 1 1 1]
[1 1 1 1]
[1 1 1 1]]]
Note that on CPU, if an out of bound index is found, an error is returned.
On GPU, if an out of bound index is found, the index is ignored.
}]; }];
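For intuition, the scalar-update example that the old description carried (and that the new text delegates to the `tf.tensor_scatter_nd_update` docs) still illustrates the shape contract above; a minimal sketch, assuming TF2 eager execution:

```python
import tensorflow as tf

# index_depth (the last axis of `indices`) equals tensor.rank, so each of the
# four updates addresses a single scalar element of the rank-1 tensor.
indices = tf.constant([[4], [3], [1], [7]])  # shape (num_updates, index_depth) = (4, 1)
updates = tf.constant([9, 10, 11, 12])       # shape indices.shape[:-1] + tensor.shape[1:] = (4,)
tensor = tf.ones([8], dtype=tf.int32)
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
# tf.Tensor([ 1 11  1 10  9  1  1 12], shape=(8,), dtype=int32)
```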
let arguments = (ins let arguments = (ins
@ -15080,7 +15049,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
}]; }];
let arguments = (ins let arguments = (ins
Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand, Arg<TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
Arg<TF_I32OrI64Tensor, [{Array containing the starting indices of the slices we gather.}]>:$start_indices, Arg<TF_I32OrI64Tensor, [{Array containing the starting indices of the slices we gather.}]>:$start_indices,
Arg<TF_I32OrI64Tensor, [{slice_sizes[i] is the bounds for the slice on dimension i.}]>:$slice_sizes, Arg<TF_I32OrI64Tensor, [{slice_sizes[i] is the bounds for the slice on dimension i.}]>:$slice_sizes,
@ -15089,7 +15058,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
); );
let results = (outs let results = (outs
TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
); );
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
@ -15457,7 +15426,7 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> { def TF_ZerosLikeOp : TF_Op<"ZerosLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns a tensor of zeros with the same shape and type as x."; let summary = "Returns a tensor of zeros with the same shape and type as x.";
let arguments = (ins let arguments = (ins

View File

@ -684,12 +684,23 @@ body: A function that takes a list of tensors and returns another
FlatSymbolRefAttr:$cond, FlatSymbolRefAttr:$cond,
FlatSymbolRefAttr:$body, FlatSymbolRefAttr:$body,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations, DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
// Used to map StatelessWhile and While op defined in TensorFlow to a common // Used to map StatelessWhile and While op defined in TensorFlow to a common
// op. // op.
BoolAttr:$is_stateless BoolAttr:$is_stateless,
// In TensorFlow, While has a special behavior where, if the `output_shapes`
// attribute is not empty, those shapes are used in its shape function
// as result shapes instead of propagating operand shapes as result shapes.
// This allows result shapes to differ from operand shapes. While these
// shapes are imported and set as a part of the result type, there is no
// indicator differentiating between having no output shapes and having
// all unranked shapes. Thus this attribute determines which shape function
// behavior to use for this op: propagate operand shapes as result shapes
// when this attribute is not set, or preserve result shapes as is when
// this attribute is set.
UnitAttr:$shape_invariant
); );
let results = (outs let results = (outs
@ -697,6 +708,7 @@ body: A function that takes a list of tensors and returns another
); );
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
let verifier = [{ let verifier = [{
return Verify(*this); return Verify(*this);
@ -752,12 +764,23 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
let arguments = (ins let arguments = (ins
Variadic<AnyTensor>:$input, Variadic<AnyTensor>:$input,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations, DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
// Used to map StatelessWhile and While op defined in TensorFlow to a common // Used to map StatelessWhile and While op defined in TensorFlow to a common
// op. // op.
BoolAttr:$is_stateless BoolAttr:$is_stateless,
// In TensorFlow, While has a special behavior where, if the `output_shapes`
// attribute is not empty, those shapes are used in its shape function
// as result shapes instead of propagating operand shapes as result shapes.
// This allows result shapes to differ from operand shapes. While these
// shapes are imported and set as a part of the result type, there is no
// indicator differentiating between having no output shapes and having
// all unranked shapes. Thus this attribute determines which shape function
// behavior to use for this op: propagate operand shapes as result shapes
// when this attribute is not set, or preserve result shapes as is when
// this attribute is set.
UnitAttr:$shape_invariant
); );
let results = (outs Variadic<AnyTensor>:$output); let results = (outs Variadic<AnyTensor>:$output);
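The comment above describes the graph-level contract; as a rough sketch of the user-facing behavior it models (the `shape_invariants` argument of `tf.while_loop`, which lets a loop variable's result shape differ from its operand shape — the correspondence is assumed here for illustration only):

```python
import tensorflow as tf

def cond(x):
  return tf.shape(x)[0] < 4

def body(x):
  # The loop variable's shape changes every iteration, so its result shape
  # cannot simply be propagated from its operand shape.
  return [tf.concat([x, x], axis=0)]

x = tf.constant([1.0])
# The explicit shape invariant relaxes the per-iteration shape to (None,).
result = tf.while_loop(cond, body, [x],
                       shape_invariants=[tf.TensorShape([None])])
print(result[0].shape)  # (4,) when run eagerly
```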


@ -2467,6 +2467,33 @@ LogicalResult MaxPoolOp::FoldOperandsPermutation(
permutation, this, {{"strides", strides()}, {"ksize", ksize()}}); permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
} }
LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) {
StringRef src_data_format = data_format();
auto perm = GetDataFormatPermutation(src_data_format, new_data_format);
if (perm.empty()) return failure();
// Update data_format attribute and result types.
if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this)))
return failure();
stridesAttr(ShuffleArrayAttr(strides(), perm));
explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2));
ksizeAttr(ShuffleArrayAttr(ksize(), perm));
return success();
}
StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) {
// Keep current data format if no GPUs are available or if explicit placement
// does not allow the GPU to be used for this operation.
if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
return data_format();
// Defaults to NCHW.
return "NCHW";
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MaxPoolGradOp // MaxPoolGradOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//


@ -41,3 +41,34 @@ func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32
// CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32> // CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32>
// CHECK: %[[V1]] : tensor<3x5x7xf32> // CHECK: %[[V1]] : tensor<3x5x7xf32>
} }
// CHECK-LABEL: @broadcast_eq
func @broadcast_eq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> {
%cst = constant dense<[5, 7]> : tensor<2xi32>
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
%1 = "tf.Equal"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
return %1 : tensor<5x7xi1>
// CHECK: %[[V0:.*]] = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
// CHECK: %[[V0]] : tensor<5x7xi1>
}
// CHECK-LABEL: @broadcast_neq
func @broadcast_neq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> {
%cst = constant dense<[5, 7]> : tensor<2xi32>
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
%1 = "tf.NotEqual"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
return %1 : tensor<5x7xi1>
// CHECK: %[[V0:.*]] = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
// CHECK: %[[V0]] : tensor<5x7xi1>
}
// CHECK-LABEL: @broadcast_both_operand
func @broadcast_both_operand(%arg0: tensor<7xf32>, %arg1: tensor<5x1xf32>) -> tensor<5x7xf32> {
%cst = constant dense<[5, 7]> : tensor<2xi64>
%0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi64>) -> tensor<5x7xf32>
%1 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5x1xf32>, tensor<2xi64>) -> tensor<5x7xf32>
%2 = "tf.Add"(%0, %1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
return %2 : tensor<5x7xf32>
// CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
// CHECK: %[[V0]] : tensor<5x7xf32>
}
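These folds lean on TF's implicit broadcasting: an explicit `tf.BroadcastTo` feeding a binary elementwise op is redundant when the op would broadcast the operands anyway. A plain-TF sketch of the equivalence being tested (illustrative only, not the pass itself):

```python
import tensorflow as tf

x = tf.random.normal([5, 7])
y = tf.random.normal([7])

explicit = tf.equal(x, tf.broadcast_to(y, [5, 7]))  # what the input IR computes
implicit = tf.equal(x, y)                           # what the folded IR computes
print(explicit.shape, bool(tf.reduce_all(explicit == implicit)))  # (5, 7) True
```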


@ -1,4 +1,4 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s # RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1,WhileWithOutputShapes:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
# Verify that TensorFlow While and StatelessWhile ops are mapped to the # Verify that TensorFlow While and StatelessWhile ops are mapped to the
# composite While op in MLIR with is_stateless attribute set accordingly to # composite While op in MLIR with is_stateless attribute set accordingly to
@ -6,6 +6,7 @@
# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile") # CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile")
# CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile") # CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile")
# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} shape_invariant{{.*}} -> (tensor<i32>, tensor<*xf32>) loc("WhileWithOutputShapes")
node { node {
name: "StatefulWhile" name: "StatefulWhile"
@ -73,6 +74,51 @@ node {
experimental_debug_info { experimental_debug_info {
} }
} }
node {
name: "WhileWithOutputShapes"
op: "While"
input: "iter"
input: "val"
attr {
key: "T"
value {
list {
type: DT_INT32
type: DT_FLOAT
}
}
}
attr {
key: "body"
value {
func {
name: "body"
}
}
}
attr {
key: "cond"
value {
func {
name: "cond"
}
}
}
attr {
key: "output_shapes"
value {
list {
shape {
}
shape {
unknown_rank: true
}
}
}
}
experimental_debug_info {
}
}
node { node {
name: "main" name: "main"
op: "_Retval" op: "_Retval"
@ -107,6 +153,23 @@ node {
} }
} }
} }
node {
name: "main2"
op: "_Retval"
input: "WhileWithOutputShapes:1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 2
}
}
}
node { node {
name: "iter" name: "iter"
op: "Placeholder" op: "Placeholder"


@ -82,3 +82,20 @@ func @bias_add_nchw(%arg0: tensor<1x256x150x150xf32>, %arg1: tensor<256xf32>) ->
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW", device = ""} : (tensor<1x256x150x150xf32>, tensor<256xf32>) -> tensor<1x256x150x150xf32> %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW", device = ""} : (tensor<1x256x150x150xf32>, tensor<256xf32>) -> tensor<1x256x150x150xf32>
return %0 : tensor<1x256x150x150xf32> return %0 : tensor<1x256x150x150xf32>
} }
// CHECK-LABEL: maxpool_nchw
func @maxpool_nchw(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> {
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
// CHECK: %[[R0:.*]] = "tf.Transpose"(%arg0, %[[CST]])
// CHECK: %[[R1:.*]] = "tf.MaxPool"(%[[R0]]) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}
// CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: "tf.Transpose"(%[[R1]], %[[CST_0]])
%0 = "tf.MaxPool"(%arg0)
{
data_format = "NCHW",
ksize = [1, 1, 3, 3],
padding = "SAME",
strides = [1, 1, 2, 2]
} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
return %0 : tensor<1x64x56x56xf32>
}
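The expected `ksize`/`strides` above follow from permuting the NCHW-ordered attributes with the same `[0, 2, 3, 1]` permutation used for the transpose; a hypothetical re-implementation of that shuffle, for illustration only:

```python
# new[i] = old[perm[i]], mirroring what the pass does to per-dimension attributes
# when switching MaxPool from NCHW to NHWC.
perm = [0, 2, 3, 1]                # NCHW -> NHWC
ksize_nchw = [1, 1, 3, 3]
strides_nchw = [1, 1, 2, 2]
ksize_nhwc = [ksize_nchw[p] for p in perm]       # [1, 3, 3, 1]
strides_nhwc = [strides_nchw[p] for p in perm]   # [1, 2, 2, 1]
print(ksize_nhwc, strides_nhwc)
```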


@ -1703,3 +1703,110 @@ func @convert_iota_3d() -> tensor<5x7x9xi32> {
return %0 : tensor<5x7x9xi32> return %0 : tensor<5x7x9xi32>
} }
// CHECK-LABEL: func @convert_avgpool_valid(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
// CHECK: return %[[VAL_1]] : tensor<4x7x7x8xf32>
// CHECK: }
func @convert_avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
%0 = mhlo.constant dense<0.0> : tensor<f32>
%1 = mhlo.constant dense<9.0> : tensor<4x7x7x8xf32>
%2 = "mhlo.reduce_window"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%5 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%5) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
%3 = mhlo.divide %2, %1 : tensor<4x7x7x8xf32>
return %3 : tensor<4x7x7x8xf32>
}
// CHECK-LABEL: func @convert_avgpool_valid_rw(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
// CHECK: return %[[VAL_1]] : tensor<4x7x7x8xf32>
// CHECK: }
func @convert_avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
%0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
%1 = mhlo.constant dense<0.0> : tensor<f32>
%2 = "mhlo.reduce_window"(%arg0, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%6) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
%3 = "mhlo.reduce_window"(%0, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%6) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
%4 = mhlo.divide %2, %3 : tensor<4x7x7x8xf32>
return %4 : tensor<4x7x7x8xf32>
}
// CHECK-LABEL: func @convert_avgpool_valid_3d(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) {data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]} : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
// CHECK: return %[[VAL_1]] : tensor<4x7x7x7x8xf32>
// CHECK: }
func @convert_avgpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
%0 = mhlo.constant dense<0.0> : tensor<f32>
%1 = mhlo.constant dense<27.0> : tensor<4x7x7x7x8xf32>
%2 = "mhlo.reduce_window"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%5 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%5) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<5xi64>,
padding = dense<0> : tensor<5x2xi64>,
window_dilations = dense<1> : tensor<5xi64>,
window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<4x16x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x7x8xf32>
%3 = mhlo.divide %2, %1 : tensor<4x7x7x7x8xf32>
return %3 : tensor<4x7x7x7x8xf32>
}
// CHECK-LABEL: func @convert_avgpool_same(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
// CHECK: return %[[VAL_1]] : tensor<4x8x8x8xf32>
// CHECK: }
func @convert_avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
%0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
%1 = mhlo.constant dense<0.0> : tensor<f32>
%2 = "mhlo.reduce_window"(%arg0, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%6) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
%3 = "mhlo.reduce_window"(%0, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%6) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
%4 = mhlo.divide %2, %3 : tensor<4x8x8x8xf32>
return %4 : tensor<4x8x8x8xf32>
}


@ -244,9 +244,9 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
// CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
// CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]]) // CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
// CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64} // CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64}
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1) // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.AddV2"([[PADDINGS]]#0, [[PADDINGS]]#1)
// CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>} // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
// CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]]) // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.AddV2"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
// CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]]) // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
// CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:2 = "tf.Split"([[ZERO_I32]], %arg1) // CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:2 = "tf.Split"([[ZERO_I32]], %arg1)
// CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0) // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0)
@ -338,10 +338,10 @@ func @fake_quant_with_min_max_args(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG: [[VAL5:%.+]] = "tf.ClipByValue"(%arg0, [[VAL2]], [[VAL1]]) // CHECK-DAG: [[VAL5:%.+]] = "tf.ClipByValue"(%arg0, [[VAL2]], [[VAL1]])
// CHECK-DAG: [[VAL6:%.+]] = "tf.Sub"([[VAL5]], [[VAL2]]) // CHECK-DAG: [[VAL6:%.+]] = "tf.Sub"([[VAL5]], [[VAL2]])
// CHECK-DAG: [[VAL7:%.+]] = "tf.Mul"([[VAL6]], [[VAL0]]) // CHECK-DAG: [[VAL7:%.+]] = "tf.Mul"([[VAL6]], [[VAL0]])
// CHECK-DAG: [[VAL8:%.+]] = "tf.Add"([[VAL7]], [[VAL4]]) // CHECK-DAG: [[VAL8:%.+]] = "tf.AddV2"([[VAL7]], [[VAL4]])
// CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]]) // CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]])
// CHECK-DAG: [[VAL10:%.+]] = "tf.Mul"([[VAL9]], [[VAL3]]) // CHECK-DAG: [[VAL10:%.+]] = "tf.Mul"([[VAL9]], [[VAL3]])
// CHECK-DAG: [[VAL11:%.+]] = "tf.Add"([[VAL10]], [[VAL2]]) // CHECK-DAG: [[VAL11:%.+]] = "tf.AddV2"([[VAL10]], [[VAL2]])
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {max = 1.0 : f32, min = -1.0 : f32, narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32> %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {max = 1.0 : f32, min = -1.0 : f32, narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: return [[VAL11]] // CHECK: return [[VAL11]]
@ -361,7 +361,7 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>,
// CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]]) // CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]])
// CHECK-DAG: [[VAL10:%.+]] = "tf.Sub"([[VAL8]], [[VAL9]]) // CHECK-DAG: [[VAL10:%.+]] = "tf.Sub"([[VAL8]], [[VAL9]])
// CHECK-DAG: [[VAL11:%.+]] = "tf.Less"([[VAL10]], [[VAL3]]) // CHECK-DAG: [[VAL11:%.+]] = "tf.Less"([[VAL10]], [[VAL3]])
// CHECK-DAG: [[VAL12:%.+]] = "tf.Add"([[VAL2]], [[VAL9]]) // CHECK-DAG: [[VAL12:%.+]] = "tf.AddV2"([[VAL9]], [[VAL2]])
// CHECK-DAG: [[VAL13:%.+]] = "tf.Select"([[VAL11]], [[VAL9]], [[VAL12]]) // CHECK-DAG: [[VAL13:%.+]] = "tf.Select"([[VAL11]], [[VAL9]], [[VAL12]])
// CHECK-DAG: [[VAL14:%.+]] = "tf.ClipByValue"([[VAL13]], [[VAL0]], [[VAL1]]) : // CHECK-DAG: [[VAL14:%.+]] = "tf.ClipByValue"([[VAL13]], [[VAL0]], [[VAL1]]) :
// CHECK-DAG: [[VAL15:%.+]] = "tf.Sub"([[VAL0]], [[VAL14]]) // CHECK-DAG: [[VAL15:%.+]] = "tf.Sub"([[VAL0]], [[VAL14]])
@ -371,10 +371,10 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>,
// CHECK-DAG: [[VAL19:%.+]] = "tf.ClipByValue"(%arg0, [[VAL17]], [[VAL18]]) // CHECK-DAG: [[VAL19:%.+]] = "tf.ClipByValue"(%arg0, [[VAL17]], [[VAL18]])
// CHECK-DAG: [[VAL20:%.+]] = "tf.Sub"([[VAL19]], [[VAL17]]) // CHECK-DAG: [[VAL20:%.+]] = "tf.Sub"([[VAL19]], [[VAL17]])
// CHECK-DAG: [[VAL21:%.+]] = "tf.Mul"([[VAL20]], [[VAL6]]) // CHECK-DAG: [[VAL21:%.+]] = "tf.Mul"([[VAL20]], [[VAL6]])
// CHECK-DAG: [[VAL22:%.+]] = "tf.Add"([[VAL21]], [[VAL3]]) // CHECK-DAG: [[VAL22:%.+]] = "tf.AddV2"([[VAL21]], [[VAL3]])
// CHECK-DAG: [[VAL23:%.+]] = "tf.Floor"([[VAL22]]) // CHECK-DAG: [[VAL23:%.+]] = "tf.Floor"([[VAL22]])
// CHECK-DAG: [[VAL24:%.+]] = "tf.Mul"([[VAL23]], [[VAL5]]) // CHECK-DAG: [[VAL24:%.+]] = "tf.Mul"([[VAL23]], [[VAL5]])
// CHECK-DAG: [[VAL25:%.+]] = "tf.Add"([[VAL24]], [[VAL17]]) // CHECK-DAG: [[VAL25:%.+]] = "tf.AddV2"([[VAL24]], [[VAL17]])
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32> %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: return [[VAL25]] // CHECK: return [[VAL25]]
@ -746,7 +746,7 @@ func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>} // CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
// CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]]) // CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
// CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} // CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
// CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]]) // CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]])
// CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]]) // CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
%0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
@ -761,7 +761,7 @@ func @round_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>} // CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
// CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]]) // CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
// CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} // CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
// CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]]) // CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]])
// CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]]) // CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
%0 = "tf.Round"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "tf.Round"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>


@ -1,12 +1,13 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s // RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) {
%0:2 = tf_executor.graph { %0:3 = tf_executor.graph {
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile") %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile") %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32> %outputs_6:2, %control_7 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, shape_invariant} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("WhileWithOutputShapes")
tf_executor.fetch %outputs_2#1, %outputs_4#1, %outputs_6#1 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
} }
return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32> return %0#0, %0#1, %0#2 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
} }
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> { func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
@ -36,6 +37,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
// CHECK-NOT: name: // CHECK-NOT: name:
// CHECK: op: "While" // CHECK: op: "While"
// CHECK-NOT: is_stateless // CHECK-NOT: is_stateless
// CHECK-NOT: shape_invariant
// CHECK: attr { // CHECK: attr {
// CHECK: key: "output_shapes" // CHECK: key: "output_shapes"
// CHECK: value { // CHECK: value {
@ -54,6 +56,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
// CHECK-NOT: name: // CHECK-NOT: name:
// CHECK: op: "StatelessWhile" // CHECK: op: "StatelessWhile"
// CHECK-NOT: is_stateless // CHECK-NOT: is_stateless
// CHECK-NOT: shape_invariant
// CHECK: attr { // CHECK: attr {
// CHECK: key: "output_shapes" // CHECK: key: "output_shapes"
// CHECK: value { // CHECK: value {
@ -67,3 +70,20 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
// CHECK: } // CHECK: }
// CHECK: } // CHECK: }
// CHECK: name: "WhileWithOutputShapes"
// CHECK-NOT: name:
// CHECK: op: "While"
// CHECK-NOT: is_stateless
// CHECK-NOT: shape_invariant
// CHECK: attr {
// CHECK: key: "output_shapes"
// CHECK: value {
// CHECK: list {
// CHECK: shape {
// CHECK: dim {
// CHECK: size: 5
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }


@ -0,0 +1,153 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-outside-compiled-to-host-launch | FILECHECK_OPTS="" FileCheck %s
// expected-error@+1 {{'module' op bad 'tf.devices' attribute at index 0, not a string}}
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = [1]} {
// Tests that a non-string `tf.devices` attribute on the module results in an error.
func @invalid_device_attribute() -> tensor<?xi32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"() : () -> tensor<?xi32>
%2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
return %0 : tensor<?xi32>
}
}
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// Tests that an empty `_xla_outside_compilation` attribute value results in an error.
func @empty_outside_compilation_attribute() -> tensor<?xi32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"() : () -> tensor<?xi32>
// expected-error@+1 {{'tf.B' op requires non empty '_xla_outside_compilation' string attribute}}
%2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
return %0 : tensor<?xi32>
}
}
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// Tests that a TPU cluster with no outside compilation does not generate a launch op.
// CHECK-LABEL: func @no_outside_compilation
// CHECK-NOT: "tf_device.launch"
func @no_outside_compilation() -> tensor<?xi32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"() : () -> tensor<?xi32>
%2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// Tests the launch wrap of a single outside compiled cluster with no input or output dependencies.
// CHECK-LABEL: func @nodep_single_outside_compilation
func @nodep_single_outside_compilation() -> () {
// CHECK: "tf.A"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.B"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.C"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
return
}
// Tests the launch wrap of a single outside compiled cluster with data parallelism.
// CHECK-LABEL: func @single_outside_compilation_with_replicate
func @single_outside_compilation_with_replicate(%arg0: tensor<?xi32>) -> () {
// CHECK: "tf.A"
// CHECK: tf_device.replicate
// CHECK-NEXT: "tf_device.cluster"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.C"
// CHECK-NOT: _xla_outside_compilation
// CHECK: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
// CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
"tf_device.cluster"() ( {
"tf.B"() : () -> ()
"tf.C"(%ri_0) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.D"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
tf_device.return
}
return
}
// Tests launch wrap of a single outside compiled cluster with input/output.
// CHECK-LABEL: func @single_outside_compilation_input_output
func @single_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
// CHECK: tf_device.return %[[B_OUTPUT]]
// CHECK: "tf.C"(%[[LAUNCH_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32>
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests launch wrap of multiple outside compiled cluster with input/output.
// CHECK-LABEL: func @multiple_outside_compilation_input_output
func @multiple_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
// CHECK: tf_device.return %[[B_OUTPUT]]
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[LAUNCH_OUTPUT]])
// CHECK-NEXT: %[[LAUNCH_OUTPUT2:[0-9]*]] = "tf_device.launch"
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]])
// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[D_OUTPUT]])
// CHECK: tf_device.return %[[E_OUTPUT]]
// CHECK: "tf.F"(%[[LAUNCH_OUTPUT2]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32>
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
%6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32>
%7 = "tf.E"(%6) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32>
%8 = "tf.F"(%7) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %8 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
}


@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project
@ -38,6 +39,12 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern {
LogicalResult matchAndRewrite(Operation* op, LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override; PatternRewriter& rewriter) const override;
private:
template <typename Op>
LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter) const;
}; };
class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> { class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
@ -47,7 +54,27 @@ class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite( LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const { Operation* op, PatternRewriter& rewriter) const {
if (!op->hasTrait<OpTrait::ResultsBroadcastableShape>()) return failure(); if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
return RewriteOp(op, rewriter);
// tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
// incompatible_shape_error is `true` (which is also checked by the verifier).
if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
return failure();
}
template <typename Op>
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
Operation* op, PatternRewriter& rewriter) const {
auto eq_op = llvm::dyn_cast_or_null<Op>(op);
if (eq_op && eq_op.incompatible_shape_error()) return RewriteOp(op, rewriter);
return failure();
}
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
Operation* op, PatternRewriter& rewriter) const {
if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1) if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
return failure(); return failure();
@ -56,6 +83,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>(); op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
if (!result_type || !result_type.hasStaticShape()) return failure(); if (!result_type || !result_type.hasStaticShape()) return failure();
bool changed = false;
for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) { for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
// Check that the i'th operand is a broadcast. // Check that the i'th operand is a broadcast.
auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>( auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
@ -89,10 +117,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
// Update the operand of the op to be the operand of the broadcast. // Update the operand of the op to be the operand of the broadcast.
rewriter.updateRootInPlace( rewriter.updateRootInPlace(
op, [&]() { op->getOpOperand(i).set(broadcast.input()); }); op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
return success(); changed = true;
} }
return success(changed);
return failure();
} }
void BroadcastFoldPass::runOnFunction() { void BroadcastFoldPass::runOnFunction() {


@ -112,8 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
LogicalResult ConvertWhileOp(WhileOp while_op) { LogicalResult ConvertWhileOp(WhileOp while_op) {
auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>( auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
while_op.getLoc(), while_op.getResultTypes(), while_op.input(), while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
while_op.output_shapes(), while_op.parallel_iterations(), while_op.parallel_iterations(), while_op.is_stateless(),
while_op.is_stateless()); while_op.shape_invariant());
CopyDeviceAndUnderscoredAttributes(while_op, while_region); CopyDeviceAndUnderscoredAttributes(while_op, while_region);
YieldOp cond_yield = YieldOp cond_yield =


@ -21,15 +21,18 @@ limitations under the License.
#include <numeric> #include <numeric>
#include <vector> #include <vector>
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
@ -44,6 +47,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/lib/math/math_util.h"
namespace mlir { namespace mlir {
namespace TF { namespace TF {
@ -479,18 +483,16 @@ template <typename ReductionOp>
LogicalResult MatchBinaryReduceFunction(mlir::Region &function) { LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
Block &body = function.front(); Block &body = function.front();
if (body.getNumArguments() != 2) return failure(); if (body.getNumArguments() != 2) return failure();
if (body.getOperations().size() != 2) return failure();
ReductionOp reduce_op = dyn_cast<ReductionOp>(body.front());
if (!reduce_op) return failure();
if (reduce_op.lhs() != body.getArgument(0) ||
reduce_op.rhs() != body.getArgument(1))
return failure();
mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back()); mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
if (!return_op) return failure(); if (!return_op) return failure();
if (return_op.getNumOperands() != 1 || if (return_op.getNumOperands() != 1) return failure();
return_op.results().front() != reduce_op)
ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
return_op.getOperands().front().getDefiningOp());
if (!reduce_op) return failure();
if (reduce_op.lhs() != body.getArgument(0) ||
reduce_op.rhs() != body.getArgument(1))
return failure(); return failure();
return success(); return success();
@ -654,6 +656,190 @@ class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
} }
}; };
// Maps the following representations of AvgPool in MHLO into a tf.AvgPool{3D}
// operation when they cleanly map to 2D or 3D average pool with VALID or SAME
// padding:
// * div(reduce_sum_window(x), constant(sizeof(window)))
// * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::DivOp div_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
auto rw =
dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
if (!rw) return failure();
// Check that the reduce-window is a sum-reduce-window.
if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
return failure();
// Check that this is a floating point reduce window with a rank of 4 or 5.
RankedTensorType rw_type = rw.getType().dyn_cast<RankedTensorType>();
if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
rw_type.getRank() <= 3 || rw_type.getRank() > 5)
return failure();
// Check that the Div op doesn't do broadcasting on the output of the reduce
// window.
if (div_op.getType() != rw.getType()) return failure();
// tf.avg_pool needs at least 3 dimensions (batch, spatial, channel)
const uint64_t rank = rw.window_dimensions().size();
if (rank <= 2) return failure();
// If the init value isn't zero then it can't be an average pool.
if (!isFloatZero(rw.init_value())) return failure();
llvm::SmallVector<int64_t, 5> window_strides;
if (rw.window_strides().hasValue()) {
window_strides.insert(window_strides.end(),
rw.window_strides()->getValues<int64_t>().begin(),
rw.window_strides()->getValues<int64_t>().end());
} else {
window_strides.resize(rank, 1);
}
llvm::SmallVector<int64_t, 10> padding;
if (rw.padding().hasValue()) {
padding.insert(padding.begin(),
rw.padding()->getValues<int64_t>().begin(),
rw.padding()->getValues<int64_t>().end());
} else {
padding.resize(2 * rank, 0);
}
// Check that we don't do any reduction along the batch (first) and channel
// (last) dimensions.
const uint64_t batch_dim = 0;
const uint64_t channel_dim = rank - 1;
if (rw.window_dimensions().getValue<int64_t>({batch_dim}) != 1 ||
rw.window_dimensions().getValue<int64_t>({channel_dim}) != 1 ||
window_strides[batch_dim] != 1 || window_strides[channel_dim] != 1 ||
padding[2 * batch_dim] != 0 || padding[2 * batch_dim + 1] != 0 ||
padding[2 * channel_dim] != 0 || padding[2 * channel_dim + 1] != 0)
return failure();
if (rw.window_dilations().hasValue() &&
!(rw.window_dilations()->isSplat() &&
rw.window_dilations()->getSplatValue<APInt>() == 1))
return failure();
if (rw.base_dilations().hasValue() &&
!(rw.base_dilations()->isSplat() &&
rw.base_dilations()->getSplatValue<APInt>() == 1))
return failure();
DenseFPElementsAttr divisor;
if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
// If the divisor is a constant then check that it matches the number
// of elements inside the window, which is required for a VALID AvgPool.
if (!divisor.isSplat()) return failure();
int64_t window_size = 1;
for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
window_size *= w;
}
if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
return failure();
// Check that we have no padding.
if (!llvm::all_of(padding, [](int64_t i) { return i == 0; }))
return failure();
return replaceWithAvgPool(
div_op, rw.operand(),
llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
window_strides, "VALID", rewriter);
}
auto rw_rhs =
dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
if (rw_rhs) {
// Check that RHS is a sum-reduce-window.
if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
return failure();
// Check that the RHS is a reduce_window over a constant 1 input with 0 as
// the init value.
DenseFPElementsAttr rhs_input;
if (!isFloatZero(rw_rhs.init_value()) ||
!matchPattern(rw_rhs.operand(), m_Constant(&rhs_input)) ||
!rhs_input.isSplat() ||
!rhs_input.getSplatValue<APFloat>().isExactlyValue(1.0))
return failure();
// Check that the two reduce windows have the same window configuration.
if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
rw.window_strides() != rw_rhs.window_strides() ||
rw.window_dilations() != rw_rhs.window_dilations() ||
rw.base_dilations() != rw_rhs.base_dilations() ||
rw.padding() != rw_rhs.padding())
return failure();
if (llvm::all_of(padding, [](int64_t i) { return i == 0; }))
return replaceWithAvgPool(
div_op, rw.operand(),
llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
window_strides, "VALID", rewriter);
RankedTensorType input_type =
rw.operand().getType().dyn_cast<RankedTensorType>();
RankedTensorType output_type = rw.getType().dyn_cast<RankedTensorType>();
if (!input_type || !output_type) return failure();
// Check that the individual padding values correspond to SAME
// padding from TensorFlow.
for (uint64_t i = 1; i < rank - 1; ++i) {
int64_t padding_size =
(output_type.getShape()[i] - 1) * window_strides[i] +
rw.window_dimensions().getValue<int64_t>({i}) -
input_type.getShape()[i];
if (padding[2 * i] !=
tensorflow::MathUtil::FloorOfRatio(padding_size, int64_t(2)) ||
padding[2 * i + 1] !=
tensorflow::MathUtil::CeilOfRatio(padding_size, int64_t(2)))
return failure();
}
return replaceWithAvgPool(
div_op, rw.operand(),
llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
window_strides, "SAME", rewriter);
}
return failure();
}
private:
bool isFloatZero(Value value) const {
DenseFPElementsAttr initial_value;
return matchPattern(value, m_Constant(&initial_value)) &&
initial_value.getNumElements() == 1 &&
initial_value.getValue<APFloat>({}).isZero();
}
LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
llvm::ArrayRef<int64_t> ksizes,
llvm::ArrayRef<int64_t> kstrides,
llvm::StringRef padding,
ConversionPatternRewriter &rewriter) const {
if (ksizes.size() == 4) {
rewriter.replaceOpWithNewOp<AvgPoolOp>(
op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
rewriter.getStringAttr("NHWC"));
return success();
} else if (ksizes.size() == 5) {
rewriter.replaceOpWithNewOp<AvgPool3DOp>(
op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
rewriter.getStringAttr("NDHWC"));
return success();
}
return failure();
}
};
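As a worked instance of the SAME-padding check in `ConvertAvgPoolOp` above (and of why the `convert_avgpool_same` test earlier in this diff carries `padding = [[0, 0], [0, 1], [0, 1], [0, 0]]`), the expected low/high padding for one spatial dimension is:

```python
import math

# Spatial dimension from the convert_avgpool_same test: input 16, window 3,
# stride 2, output 8.
in_size, window, stride, out_size = 16, 3, 2, 8
padding_size = (out_size - 1) * stride + window - in_size  # = 1
pad_low = padding_size // 2             # FloorOfRatio -> 0
pad_high = math.ceil(padding_size / 2)  # CeilOfRatio  -> 1
print(pad_low, pad_high)  # 0 1, i.e. the [0, 1] entries above
```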
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> { class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TF::TensorFlowDialect>(); registry.insert<TF::TensorFlowDialect>();
@ -794,10 +980,10 @@ static PassRegistration<LegalizeHloToTf> pass(
void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns, void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
MLIRContext *context) { MLIRContext *context) {
patterns->insert<ConvertAvgPoolOp, ConvertConvOp, ConvertSliceOp,
ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context);
populateWithGenerated(context, *patterns); populateWithGenerated(context, *patterns);
patterns->insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
ConvertReduceOpToTfMin, ConvertReduceOpToTfSum,
ConvertIotaOpToTfRange>(context);
} }
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() { std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {

View File

@ -332,12 +332,13 @@ class LowerDynamicStitchOp : public RewritePattern {
class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern { class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
public: public:
explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context) explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
: RewritePattern(FakeQuantWithMinMaxVarsOp::getOperationName(), : RewritePattern(
{SubOp::getOperationName(), ConstOp::getOperationName(), FakeQuantWithMinMaxVarsOp::getOperationName(),
MulOp::getOperationName(), FloorOp::getOperationName(), {AddV2Op::getOperationName(), SubOp::getOperationName(),
ClipByValueOp::getOperationName(), ConstOp::getOperationName(), MulOp::getOperationName(),
DivOp::getOperationName(), RoundOp::getOperationName()}, FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
1, context) {} DivOp::getOperationName(), RoundOp::getOperationName()},
1, context) {}
LogicalResult matchAndRewrite(Operation *src_op, LogicalResult matchAndRewrite(Operation *src_op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
@ -419,8 +420,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
op.getLoc(), op.getLoc(),
DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty))); DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));
quantized_input = rewriter.create<AddOp>(op.getLoc(), input_ty, quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
quantized_input, half_val); quantized_input, half_val);
quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input); quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);
@ -428,8 +429,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
Value output = rewriter.create<MulOp>(op.getLoc(), input_ty, Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
quantized_input, quant_to_float); quantized_input, quant_to_float);
output = output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
rewriter.create<AddOp>(op.getLoc(), input_ty, output, nudged_float_min); nudged_float_min);
rewriter.replaceOp(op, {output}); rewriter.replaceOp(op, {output});
return success(); return success();
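// A small self-contained check, for illustration only, of the
// add-0.5-then-floor rounding used above (round half up to the nearest
// integer); the sample values are arbitrary.
#include <cassert>
#include <cmath>
int main() {
  assert(std::floor(2.3 + 0.5) == 2.0);  // rounds down
  assert(std::floor(2.5 + 0.5) == 3.0);  // ties round up
  assert(std::floor(2.7 + 0.5) == 3.0);  // rounds up
  return 0;
}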
@ -811,7 +812,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
CastOp::getOperationName(), CastOp::getOperationName(),
ConstOp::getOperationName(), ConstOp::getOperationName(),
ConcatV2Op::getOperationName(), ConcatV2Op::getOperationName(),
AddOp::getOperationName(), AddV2Op::getOperationName(),
PadOp::getOperationName(), PadOp::getOperationName(),
SplitOp::getOperationName(), SplitOp::getOperationName(),
UnpackOp::getOperationName(), UnpackOp::getOperationName(),
@ -907,8 +908,8 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
auto paddings_split = rewriter.create<UnpackOp>( auto paddings_split = rewriter.create<UnpackOp>(
loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings, loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
rewriter.getI64IntegerAttr(1)); rewriter.getI64IntegerAttr(1));
auto paddings_sum = rewriter.create<AddOp>(loc, paddings_split.getResult(0), auto paddings_sum = rewriter.create<AddV2Op>(
paddings_split.getResult(1)); loc, paddings_split.getResult(0), paddings_split.getResult(1));
auto input_shape_tensor = rewriter.create<ConstOp>( auto input_shape_tensor = rewriter.create<ConstOp>(
loc, loc,
@ -918,7 +919,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
// padded_shape_tensor is the shape of padded. // padded_shape_tensor is the shape of padded.
auto padded_shape_tensor = auto padded_shape_tensor =
rewriter.create<AddOp>(loc, paddings_sum, input_shape_tensor); rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);
auto zero_i32 = rewriter.create<ConstOp>( auto zero_i32 = rewriter.create<ConstOp>(
loc, GetScalarOfType(rewriter.getIntegerType(32), 0)); loc, GetScalarOfType(rewriter.getIntegerType(32), 0));

View File

@ -237,7 +237,7 @@ def : Pat<(TF_RoundOp:$res TF_FloatTensor:$input),
(TF_SubOp $input, (TF_FloorOp:$floor $input)), (TF_SubOp $input, (TF_FloorOp:$floor $input)),
(TF_ConstOp (GetScalarOfFloatType<"0.5"> $input))), (TF_ConstOp (GetScalarOfFloatType<"0.5"> $input))),
$floor, $floor,
(TF_AddOp (TF_AddV2Op
(TF_ConstOp (GetScalarOfType<1> $input)), $floor))>; (TF_ConstOp (GetScalarOfType<1> $input)), $floor))>;

View File

@ -0,0 +1,184 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
namespace mlir {
namespace TFTPU {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
// Mapping for `_xla_outside_compilation` attribute to ops of a cluster.
using OutsideClusterMap =
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<Operation*, 8>, 8>;
// This pass wraps ops with the same `_xla_outside_compilation`
// attribute value in a tf_device.launch op with host device assignment.
//
// A simple example:
// "tf_device.cluster"() ( {
// "tf.A"()
// "tf.B"() {_xla_outside_compilation = "cluster1"}
// "tf.C"()
// tf_device.return
// }) {num_cores_per_replica = 1, topology = "", device_assignment = []}
//
// Would become the following ops (unimportant attributes and types are omitted):
// "tf_device.cluster"() ( {
// "tf.A"()
// "tf_device.launch"() {
// "tf.B"() {_xla_outside_compilation = "cluster1"}
// tf_device.return
// } {device = "TPU_REPLICATED_HOST"} : () -> ()
// "tf.C"()
// tf_device.return
// }) {num_cores_per_replica = 1, topology = "", device_assignment = []}
//
struct OutsideCompiledToHostLaunch
: public PassWrapper<OutsideCompiledToHostLaunch, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
// Collects and clusters ops in `block` with the same `_xla_outside_compilation`
// attribute into `clusters`. This returns an error if an op's
// `_xla_outside_compilation` attribute is empty.
LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
OutsideClusterMap* clusters) {
auto walk_result = block->walk([&](Operation* op) {
if (auto attr = op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
if (attr.getValue().empty()) {
op->emitOpError() << "requires non empty '"
<< kXlaOutsideCompilationAttr << "' string attribute";
return WalkResult::interrupt();
}
auto it = clusters->try_emplace(attr.getValue());
it.first->getSecond().push_back(op);
}
return WalkResult::advance();
});
return failure(walk_result.wasInterrupted());
}
// Extracts all externally used outputs of `cluster_ops`.
llvm::SmallSetVector<Value, 4> GetExternalOutputs(
llvm::ArrayRef<Operation*> cluster_ops) {
llvm::SmallSetVector<Value, 4> external_outputs;
llvm::SmallPtrSet<Operation*, 4> host_cluster_ops_set;
for (auto op : cluster_ops) {
op->walk([&](Operation* host_cluster_op) {
host_cluster_ops_set.insert(host_cluster_op);
});
}
for (Operation* op : cluster_ops) {
for (Operation* user : op->getUsers()) {
bool is_external = llvm::none_of(
host_cluster_ops_set,
[&](Operation* cluster_op) { return user == cluster_op; });
if (!is_external) continue;
for (Value v : user->getOperands()) {
if (v.getDefiningOp() == op) external_outputs.insert(v);
}
}
}
return external_outputs;
}
void WrapClusterInLaunch(llvm::ArrayRef<Operation*> cluster_ops,
llvm::StringRef host_device) {
auto* last_cluster_op = cluster_ops.back();
OpBuilder builder(last_cluster_op);
llvm::SmallVector<Type, 4> launch_output_types;
auto external_outputs = GetExternalOutputs(cluster_ops);
for (const auto& external_output : external_outputs)
launch_output_types.push_back(external_output.getType());
auto launch_op = builder.create<tf_device::LaunchOp>(
last_cluster_op->getLoc(), builder.getStringAttr(host_device),
/*result_types=*/launch_output_types);
for (auto result : llvm::zip(external_outputs, launch_op.getResults())) {
std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
}
launch_op.body().push_back(new Block);
builder.setInsertionPointToEnd(&launch_op.GetBody());
auto* return_op =
builder
.create<tf_device::ReturnOp>(last_cluster_op->getLoc(),
external_outputs.getArrayRef())
.getOperation();
MLIRContext* context = launch_op.getContext();
for (Operation* cluster_op : cluster_ops) {
cluster_op->removeAttr(
Identifier::get(kXlaOutsideCompilationAttr, context));
cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
cluster_op->moveBefore(return_op);
}
}
void OutsideCompiledToHostLaunch::runOnOperation() {
// Get runtime devices information from the closest parent module.
auto module = getOperation();
mlir::TF::RuntimeDevices devices;
if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
return signalPassFailure();
auto result = module.walk([&](tf_device::ClusterOp tpu_cluster) {
OutsideClusterMap clusters;
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
&clusters)))
return WalkResult::interrupt();
if (clusters.empty()) return WalkResult::advance();
std::string host_device;
tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster,
&host_device);
for (const auto& cluster : clusters) {
WrapClusterInLaunch(cluster.getSecond(), host_device);
}
return WalkResult::advance();
});
if (result.wasInterrupted()) return signalPassFailure();
}
} // anonymous namespace
std::unique_ptr<OperationPass<ModuleOp>>
CreateOutsideCompiledToHostLaunchPass() {
return std::make_unique<OutsideCompiledToHostLaunch>();
}
static PassRegistration<OutsideCompiledToHostLaunch> pass(
"tf-outside-compiled-to-host-launch",
"Wraps ops with ithe same _xla_outside_compiled attribute in "
"tf_device.launch on replicated host device.");
} // namespace TFTPU
} // namespace mlir
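// A minimal sketch of how the pass defined above could be run on its own; the
// helper name is hypothetical and the include paths are assumed to match the
// TF MLIR passes header that declares CreateOutsideCompiledToHostLaunchPass.
#include "mlir/IR/Module.h"         // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
mlir::LogicalResult RunOutsideCompiledToHostLaunch(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::TFTPU::CreateOutsideCompiledToHostLaunchPass());
  return pm.run(module);
}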

View File

@ -345,6 +345,11 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUColocateCompositeResourceOps();
// run-time according to compilation result. // run-time according to compilation result.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass(); std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass();
// Creates a pass that wraps ops with the same `_xla_outside_compilation`
// attribute value in a tf_device.launch op with host device assignment.
std::unique_ptr<OperationPass<ModuleOp>>
CreateOutsideCompiledToHostLaunchPass();
// Creates a pass that groups outside compiled operations (CPU ops inside TPU // Creates a pass that groups outside compiled operations (CPU ops inside TPU
// cluster) into clusters that can be extracted and run on the CPU. // cluster) into clusters that can be extracted and run on the CPU.
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>

View File

@ -398,8 +398,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
OpBuilder builder(while_region); OpBuilder builder(while_region);
auto while_op = builder.create<WhileOp>( auto while_op = builder.create<WhileOp>(
while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name, while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
while_region.output_shapes(), while_region.parallel_iterations(), while_region.parallel_iterations(), while_region.is_stateless(),
while_region.is_stateless()); while_region.shape_invariant());
CopyDeviceAndUnderscoredAttributes(while_region, while_op); CopyDeviceAndUnderscoredAttributes(while_region, while_op);
// Redirect old results to new results. // Redirect old results to new results.

View File

@ -255,8 +255,7 @@ TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
OpBuilder& builder) { OpBuilder& builder) {
auto host_side_while = builder.create<TF::WhileRegionOp>( auto host_side_while = builder.create<TF::WhileRegionOp>(
loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{}, loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
/*output_shapes=*/builder.getArrayAttr({}), parallel_iterations, parallel_iterations, is_stateless, /*shape_invariant=*/false);
is_stateless);
// Create empty else branch region. // Create empty else branch region.
auto& body = host_side_while.body(); auto& body = host_side_while.body();

View File

@ -155,6 +155,9 @@ StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore(
if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst)) if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst))
attrs_to_ignore.insert("is_stateless"); attrs_to_ignore.insert("is_stateless");
if (llvm::isa<mlir::TF::WhileOp>(inst))
attrs_to_ignore.insert("shape_invariant");
return attrs_to_ignore; return attrs_to_ignore;
} }

View File

@ -971,6 +971,16 @@ StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
etype.getContext())); etype.getContext()));
} }
if (node.IsWhileNode()) {
auto* output_shapes = node.attrs().Find("output_shapes");
auto* element_types = node.attrs().Find("T");
if (output_shapes && !output_shapes->list().shape().empty()) {
const auto& output_shape = output_shapes->list().shape(idx);
const auto& element_type = element_types->list().type(idx);
return ConvertToMlirTensorType(output_shape, element_type, &builder);
}
}
// Returns a simple, more conservative unranked tensor type. // Returns a simple, more conservative unranked tensor type.
auto default_type = [&]() -> StatusOr<mlir::Type> { auto default_type = [&]() -> StatusOr<mlir::Type> {
mlir::Type element_type; mlir::Type element_type;
@ -1907,7 +1917,13 @@ Status ImporterBase::ConvertNode(const Node& node) {
// Case/If/While op in MLIR and add the differentiating attribute. // Case/If/While op in MLIR and add the differentiating attribute.
if (node.IsCaseNode()) composite_control_flow_op("Case"); if (node.IsCaseNode()) composite_control_flow_op("Case");
if (node.IsIfNode()) composite_control_flow_op("If"); if (node.IsIfNode()) composite_control_flow_op("If");
if (node.IsWhileNode()) composite_control_flow_op("While"); if (node.IsWhileNode()) {
composite_control_flow_op("While");
auto* output_shapes = node.attrs().Find("output_shapes");
if (output_shapes && !output_shapes->list().shape().empty())
result.attributes.push_back(
builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
}
// Register the mapping between the TF node and the newly created operation. // Register the mapping between the TF node and the newly created operation.
node_values_[node.id()] = node_values_[node.id()] =
@ -3420,6 +3436,8 @@ SavedModelSignatureDefImporterLite::ConvertGraph(
const std::vector<std::pair<std::string, TensorInfo>>& inputs, const std::vector<std::pair<std::string, TensorInfo>>& inputs,
const std::vector<std::pair<std::string, TensorInfo>>& outputs, const std::vector<std::pair<std::string, TensorInfo>>& outputs,
const std::vector<std::string> control_outputs) { const std::vector<std::string> control_outputs) {
VLOG(1) << "Importing Signature: " << name;
GraphImportConfig specs; GraphImportConfig specs;
specs.prune_unused_nodes = true; specs.prune_unused_nodes = true;
specs.inputs = ParseInputArrays(inputs); specs.inputs = ParseInputArrays(inputs);
@ -3491,6 +3509,9 @@ SavedModelSignatureDefImporterLite::ParseInputArrays(
// Only dense tensor is supported. // Only dense tensor is supported.
DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName); DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName);
VLOG(1) << "Importing Signature Input: input_name = " << iter.first
<< ", tensor_info = " << tensor_info.DebugString();
ArrayInfo array_info; ArrayInfo array_info;
array_info.imported_dtype = tensor_info.dtype(); array_info.imported_dtype = tensor_info.dtype();
array_info.shape = tensor_info.tensor_shape(); array_info.shape = tensor_info.tensor_shape();

View File

@ -68,6 +68,7 @@ distribute_py_test(
tags = [ tags = [
"no_cuda_asan", # b/173431253 "no_cuda_asan", # b/173431253
"no_oss", "no_oss",
"notap", # b/173661843
"notsan", # b/173246447 "notsan", # b/173246447
], ],
xla_enable_strict_auto_jit = False, # b/173254861 xla_enable_strict_auto_jit = False, # b/173254861

View File

@ -90,8 +90,17 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
} }
// Legalize only hlo operations to lhlo, keep the rest as tensors. // Partial bufferization: Transforms in particular HLO operations to their
pm.addPass(mlir::kernel_gen::transforms::CreateHloBufferizePass()); // corresponding LHLO operations and converts the function signature. Leaves
// shape operations untouched.
pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass(
/*allow_partial_bufferization=*/true));
// Run CSE to ensure that loads and stores to the same location get recognized
// as such.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Forward stores to buffers to loads.
pm.addNestedPass<mlir::FuncOp>(xla::mlir_gpu::createStoreForwardingPass());
// Clean up the IR for further processing. // Clean up the IR for further processing.
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass()); pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
@ -100,18 +109,22 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
// needed. // needed.
llvm::SmallVector<unsigned, 4> tiling_for_unrolling; llvm::SmallVector<unsigned, 4> tiling_for_unrolling;
llvm::SmallVector<int64_t, 4> as_int64; llvm::SmallVector<int64_t, 4> as_int64;
if (!unroll_factors.empty()) { tiling_for_unrolling.reserve(tile_sizes.size());
tiling_for_unrolling.reserve(tile_sizes.size()); for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
for (auto pair : llvm::zip(tile_sizes, unroll_factors)) { tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair)); as_int64.push_back(std::get<1>(pair));
as_int64.push_back(std::get<1>(pair));
}
} else {
tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
} }
tiling_for_unrolling.append(
tile_sizes.drop_front(unroll_factors.size()).begin(), tile_sizes.end());
// Transform LHLO operations to LinAlg. // Transform LHLO operations to LinAlg.
pm.addNestedPass<mlir::FuncOp>( pm.addNestedPass<mlir::FuncOp>(
::mlir::lmhlo::createLegalizeLhloToLinalgPass()); ::mlir::lmhlo::createLegalizeLhloToLinalgPass());
if (!gpu_binary_only) {
// Find candidates for buffer reuse. This is only successful if buffer size
// equality can be determined based on `linalg.generic` operations.
pm.addNestedPass<mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateBufferReusePass());
}
// Fuse linalg operations. // Fuse linalg operations.
pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLhloFuseLinalgPass( pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLhloFuseLinalgPass(
/*use_parallel_loops=*/true, tiling_for_unrolling)); /*use_parallel_loops=*/true, tiling_for_unrolling));
@ -141,11 +154,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
// Some basic cleanup. // Some basic cleanup.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
if (!gpu_binary_only) {
// Find candidates for buffer reuse.
pm.addNestedPass<mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateBufferReusePass());
}
// Greedily map the remaining loop to GPU hardware dimensions. // Greedily map the remaining loop to GPU hardware dimensions.
pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass()); pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass());
@ -157,7 +165,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass()); pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass()); pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass()); pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass());
pm.addNestedPass<mlir::FuncOp>(mlir::createPromoteBuffersToStackPass(64)); pm.addNestedPass<mlir::FuncOp>(mlir::createPromoteBuffersToStackPass(64));
// TODO(herhut): Enabled this to avoid leaks once fixed. // TODO(herhut): Enabled this to avoid leaks once fixed.
// pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferDeallocationPass()); // pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferDeallocationPass());
@ -189,10 +197,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
pm.addPass(::mlir::createLowerAffinePass()); pm.addPass(::mlir::createLowerAffinePass());
// Constraints are removed as late as possible and before lowering to CFG. // Constraints are removed as late as possible and before lowering to CFG.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createConvertShapeConstraintsPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createConvertShapeConstraintsPass());
if (embed_memref_prints) {
pm.addNestedPass<::mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass());
}
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
// TODO(herhut): Remove this pass once the LowerToCFG pass can handle it. // TODO(herhut): Remove this pass once the LowerToCFG pass can handle it.
pm.addNestedPass<mlir::FuncOp>( pm.addNestedPass<mlir::FuncOp>(
@ -200,6 +204,10 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
pm.addPass(::mlir::createLowerToCFGPass()); pm.addPass(::mlir::createLowerToCFGPass());
// Map allocs, asserts, etc. to the tensorflow framework. // Map allocs, asserts, etc. to the tensorflow framework.
pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass()); pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass());
if (embed_memref_prints) {
pm.addNestedPass<::mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass());
}
if (failed(pm.run(module))) { if (failed(pm.run(module))) {
return InternalError("Lowering to GPU kernels failed."); return InternalError("Lowering to GPU kernels failed.");
} }
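// A standalone illustration of the tiling_for_unrolling computation earlier in
// this function, using plain std::vector instead of the LLVM containers; it
// assumes unroll_factors.size() <= tile_sizes.size(), as in the pipeline
// above, and the helper name and example values are hypothetical.
#include <cstdint>
#include <vector>
std::vector<int64_t> TilingForUnrolling(
    const std::vector<int64_t>& tile_sizes,
    const std::vector<int64_t>& unroll_factors) {
  std::vector<int64_t> result;
  result.reserve(tile_sizes.size());
  // Zipped prefix: tile size times unroll factor.
  for (size_t i = 0; i < unroll_factors.size(); ++i)
    result.push_back(tile_sizes[i] * unroll_factors[i]);
  // Remaining tile sizes are kept unchanged.
  for (size_t i = unroll_factors.size(); i < tile_sizes.size(); ++i)
    result.push_back(tile_sizes[i]);
  return result;  // e.g. {16, 64, 4} with unroll factors {4} -> {64, 64, 4}
}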

View File

@ -49,12 +49,58 @@ func @local_reuse_with_memref_maps(
iterator_types = ["parallel"] iterator_types = ["parallel"]
} ins(%arg : memref<?xi64, offset: 2, strides: [3]>) } ins(%arg : memref<?xi64, offset: 2, strides: [3]>)
outs(%result : memref<?xi64, offset: 2, strides: [3]>) { outs(%result : memref<?xi64, offset: 2, strides: [3]>) {
^bb0(%a: i64, %b: i64): ^bb0(%a : i64, %b : i64):
linalg.yield %a : i64 linalg.yield %a : i64
} }
return %result : memref<?xi64, offset: 2, strides: [3]> return %result : memref<?xi64, offset: 2, strides: [3]>
} }
// CHECK-LABEL: @memref_reinterpret_cast_alias
func @memref_reinterpret_cast_alias(%arg : memref<f32>, %n : index)
-> memref<?xf32> attributes {tf_entry} {
%c0 = constant 0 : index
%reinterpreted = memref_reinterpret_cast %arg to
offset: [0],
sizes: [%n],
strides: [%c0]: memref<f32> to memref<?xf32>
// CHECK: alloc
// CHECK-SAME: reuse_input_candidates = [0 : index]
%result = alloc(%n) : memref<?xf32>
// reinterpreted (arg) and result are of the same size.
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%reinterpreted : memref<?xf32>) outs(%result : memref<?xf32>) {
^bb0(%a : f32, %b : f32):
linalg.yield %a : f32
}
return %result : memref<?xf32>
}
// CHECK-LABEL: @memref_cast_alias
func @memref_cast_alias(%arg : memref<*xf32>, %n : index)
-> memref<?xf32> attributes {tf_entry} {
%casted = memref_cast %arg : memref<*xf32> to memref<?xf32>
// CHECK: alloc
// CHECK-SAME: reuse_input_candidates = [0 : index]
%result = alloc(%n) : memref<?xf32>
// casted (arg) and result are of the same size.
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%casted : memref<?xf32>) outs(%result : memref<?xf32>) {
^bb0(%a : f32, %b : f32):
linalg.yield %a : f32
}
return %result : memref<?xf32>
}
// CHECK-LABEL: @indirect_size_equality // CHECK-LABEL: @indirect_size_equality
func @indirect_size_equality(%arg0 : memref<?xi64>, func @indirect_size_equality(%arg0 : memref<?xi64>,
%arg1 : memref<?xi64>, %arg1 : memref<?xi64>,
@ -65,7 +111,7 @@ func @indirect_size_equality(%arg0 : memref<?xi64>,
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"] iterator_types = ["parallel"]
} ins(%arg0 : memref<?xi64>) outs(%arg1 : memref<?xi64>) { } ins(%arg0 : memref<?xi64>) outs(%arg1 : memref<?xi64>) {
^bb0(%a: i64, %b: i64): ^bb0(%a : i64, %b : i64):
linalg.yield %a : i64 linalg.yield %a : i64
} }
@ -78,7 +124,7 @@ func @indirect_size_equality(%arg0 : memref<?xi64>,
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"] iterator_types = ["parallel"]
} ins(%arg0 : memref<?xi64>) outs(%result : memref<?xi64>) { } ins(%arg0 : memref<?xi64>) outs(%result : memref<?xi64>) {
^bb0(%a: i64, %b: i64): ^bb0(%a : i64, %b : i64):
linalg.yield %a : i64 linalg.yield %a : i64
} }
@ -317,7 +363,7 @@ func @abs_unranked_i64(%arg : memref<*xi64>,
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"] iterator_types = ["parallel"]
} ins(%flat_arg : memref<?xi64>) outs(%flat_result : memref<?xi64>) { } ins(%flat_arg : memref<?xi64>) outs(%flat_result : memref<?xi64>) {
^bb0(%a: i64, %b: i64): ^bb0(%a : i64, %b : i64):
%c0 = constant 0 : i64 %c0 = constant 0 : i64
%a_pos = cmpi "sge", %a, %c0 : i64 %a_pos = cmpi "sge", %a, %c0 : i64
%a_neg = subi %c0, %a : i64 %a_neg = subi %c0, %a : i64
@ -360,3 +406,41 @@ func @index_element_type(%arg : memref<2x3xindex>) -> memref<2x3xindex>
%result = alloc() : memref<2x3xindex> %result = alloc() : memref<2x3xindex>
return %result : memref<2x3xindex> return %result : memref<2x3xindex>
} }
// Example as it occurs in the `tf.Abs` kernel for `f32`.
// CHECK-LABEL: @abs_f32
func @abs_f32(%arg0: memref<*xf32>) -> memref<*xf32>
attributes {llvm.emit_c_interface, tf_entry} {
%c0 = constant 0 : index
%0 = shape.shape_of %arg0 : memref<*xf32> -> tensor<?xindex>
%1 = shape.num_elements %0 : tensor<?xindex> -> index
// CHECK-LABEL: alloc
// CHECK-SAME: reuse_input_candidates = []
%2 = alloc() : memref<1xindex>
store %1, %2[%c0] : memref<1xindex>
%3 = memref_reshape %arg0(%2)
: (memref<*xf32>, memref<1xindex>) -> memref<?xf32>
%4 = dim %3, %c0 : memref<?xf32>
%5 = index_cast %4 : index to i64
// CHECK-LABEL: alloc
// CHECK-SAME: reuse_input_candidates = []
%6 = alloc() : memref<1xi64>
store %5, %6[%c0] : memref<1xi64>
%7 = load %6[%c0] : memref<1xi64>
%8 = index_cast %7 : i64 to index
// CHECK-LABEL: alloc
// CHECK-SAME: reuse_input_candidates = [0 : index]
%9 = alloc(%8) : memref<?xf32>
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%3 : memref<?xf32>) outs(%9 : memref<?xf32>) {
^bb0(%arg1: f32, %arg2: f32): // no predecessors
%12 = absf %arg1 : f32
linalg.yield %12 : f32
}
%10 = tensor_to_memref %0 : memref<?xindex>
%11 = memref_reshape %9(%10)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
return %11 : memref<*xf32>
}

View File

@ -1,4 +1,4 @@
// RUN: kernel-gen-opt %s --final-bufferize | FileCheck %s // RUN: kernel-gen-opt %s --bufferize | FileCheck %s
// CHECK-LABEL: @extract_element // CHECK-LABEL: @extract_element
// CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>) -> f32 // CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>) -> f32

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --canonicalize --shape-to-descriptors --canonicalize --hlo-bufferize --std-bufferize --func-bufferize | FileCheck %s // RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --canonicalize --shape-to-descriptors --canonicalize --bufferize | FileCheck %s
// Test whether all shape computations required for isinf can be lowered to // Test whether all shape computations required for isinf can be lowered to
// the standard dialect, scf and descriptors. // the standard dialect, scf and descriptors.

View File

@ -24,7 +24,7 @@ func @print_memrefs(
return %output : memref<*xf16> return %output : memref<*xf16>
} }
// CHECK: func private @print_memref_index(memref<*xindex>) // CHECK: func private @print_memref_i64(memref<*xi64>)
// CHECK-LABEL: func @print_memrefs // CHECK-LABEL: func @print_memrefs
@ -33,12 +33,16 @@ func @print_memrefs(
// CHECK: [[NUM_ELEM:%.*]] = alloca() : memref<1xindex> // CHECK: [[NUM_ELEM:%.*]] = alloca() : memref<1xindex>
// CHECK: store {{%.*}}, [[NUM_ELEM]] // CHECK: store {{%.*}}, [[NUM_ELEM]]
// CHECK: [[UNRANKED_NUM_ELEM:%.*]] = memref_cast [[NUM_ELEM]] // CHECK: [[NUM_ELEM_I64:%.*]] = index_cast [[NUM_ELEM]]
// CHECK-NEXT: call @print_memref_index([[UNRANKED_NUM_ELEM]]) // CHECK-SAME: : memref<1xindex> to memref<1xi64>
// CHECK-NEXT: [[UNRANKED_NUM_ELEM:%.*]] = memref_cast [[NUM_ELEM_I64]]
// CHECK-NEXT: call @print_memref_i64([[UNRANKED_NUM_ELEM]])
// CHECK: memref_reshape // CHECK: memref_reshape
// CHECK: tf_framework.alloc // CHECK: tf_framework.alloc
// CHECK: [[UNRANKED_SHAPE:%.*]] = memref_cast [[SHAPE]] // CHECK: [[SHAPE_I64:%.*]] = index_cast [[SHAPE]]
// CHECK-NEXT: call @print_memref_index([[UNRANKED_SHAPE]]) // CHECK-SAME: : memref<?xindex> to memref<?xi64>
// CHECK-NEXT: [[UNRANKED_SHAPE:%.*]] = memref_cast [[SHAPE_I64]]
// CHECK-NEXT: call @print_memref_i64([[UNRANKED_SHAPE]])
// CHECK: memref_reshape // CHECK: memref_reshape

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --shape-to-descriptors --canonicalize --hlo-bufferize --std-bufferize --scf-bufferize --func-bufferize | FileCheck %s // RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --shape-to-descriptors --canonicalize --bufferize | FileCheck %s
// Test whether all shape computations required for tanh can be lowered to // Test whether all shape computations required for tanh can be lowered to
// the standard dialect, scf and descriptors. We check for a sparse pattern here, // the standard dialect, scf and descriptors. We check for a sparse pattern here,

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --hlo-bufferize // RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --bufferize
func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Acos"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Acos"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>

View File

@ -75,7 +75,7 @@ extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) {
extern "C" void _mlir_ciface_tf_report_error(void* op_kernel_ctx, extern "C" void _mlir_ciface_tf_report_error(void* op_kernel_ctx,
int32_t error_code, char* msg) { int32_t error_code, char* msg) {
Optional<ErrorCode> symbol = symbolizeErrorCode(error_code); Optional<ErrorCode> symbol = symbolizeErrorCode(error_code);
if (symbol.hasValue()) { if (!symbol.hasValue()) {
LOG(ERROR) << "No valid conversion from integer value = " << error_code LOG(ERROR) << "No valid conversion from integer value = " << error_code
<< "to ErrorCode attribute"; << "to ErrorCode attribute";
return; return;

View File

@ -55,12 +55,14 @@ namespace {
/// A temporary buffer size analysis that is correct but may be incomplete. /// A temporary buffer size analysis that is correct but may be incomplete.
class BufferSizeAnalysis { class BufferSizeAnalysis {
public: public:
explicit BufferSizeAnalysis(FuncOp f) { build(f); } BufferSizeAnalysis(FuncOp f, const BufferAliasAnalysis &aliases) {
build(f, aliases);
}
bool is_same_size(Value a, Value b) { return ecs_.isEquivalent(a, b); } bool is_same_size(Value a, Value b) { return ecs_.isEquivalent(a, b); }
private: private:
void build(FuncOp &f) { void build(FuncOp &f, const BufferAliasAnalysis &aliases) {
auto buffers = find_buffer_values(f); auto buffers = find_buffer_values(f);
// Memrefs with statically known same shape and same symbol-free affine maps // Memrefs with statically known same shape and same symbol-free affine maps
@ -102,10 +104,16 @@ class BufferSizeAnalysis {
} }
}); });
// Operand and result of `reshape_memref_cast` must be of same size. // All aliases of a memref must be of the same underlying buffer size.
f.walk([&](MemRefReshapeOp reshapeOp) { for (auto e : aliases) {
ecs_.unionSets(reshapeOp.result(), reshapeOp.source()); Value value = e.getFirst();
}); if (!value.getType().isa<BaseMemRefType>()) continue;
for (Value alias : e.getSecond()) {
assert(alias.getType().isa<BaseMemRefType>() &&
"Expected aliases of memref to be memrefs.");
ecs_.unionSets(value, alias);
}
}
} }
bool affine_maps_symbol_free_and_equal(ArrayRef<AffineMap> as, bool affine_maps_symbol_free_and_equal(ArrayRef<AffineMap> as,
@ -181,7 +189,7 @@ class BufferReuseAnalysis {
void find_reuse_candiates(FuncOp &f, BufferAliasAnalysis &aliases) { void find_reuse_candiates(FuncOp &f, BufferAliasAnalysis &aliases) {
Liveness liveness(f); Liveness liveness(f);
BufferSizeAnalysis size_equivalences(f); BufferSizeAnalysis size_equivalences(f, aliases);
f.walk([&](Block *block) { f.walk([&](Block *block) {
find_reuse_candiates(block, aliases, liveness.getLiveness(block), find_reuse_candiates(block, aliases, liveness.getLiveness(block),
size_equivalences, f.getArguments()); size_equivalences, f.getArguments());
@ -204,50 +212,53 @@ class BufferReuseAnalysis {
// Find reuse candidates for the regarded allocation. // Find reuse candidates for the regarded allocation.
SmallVector<int64_t, 2> local_reuse_candidates; SmallVector<int64_t, 2> local_reuse_candidates;
for (auto it : llvm::enumerate(arguments)) { for (BlockArgument old_buffer : arguments) {
int64_t old_buffer_index = it.index();
Value old_buffer = it.value();
if (!old_buffer.getType().isa<BaseMemRefType>()) continue; if (!old_buffer.getType().isa<BaseMemRefType>()) continue;
// Will not reuse buffers of different size as they may be too small. // Size criterion: Do not reuse buffers of different size as they may be
// too small.
if (!size_equivalences.is_same_size(new_buffer, old_buffer)) continue; if (!size_equivalences.is_same_size(new_buffer, old_buffer)) continue;
// Only reuse buffers that are no longer used on first reuse, i.e. they // Lifetime criterion: Only reuse buffers that are no longer used on
// are no longer alive. // first reuse, i.e. they are no longer alive.
bool livetimes_compatible = true; bool lifetimes_compatible = true;
for (Value old_buffer_alias : aliases.resolve(old_buffer)) { for (Value old_buffer_alias : aliases.resolve(old_buffer)) {
if (first_reuse == nullptr) { if (first_reuse == nullptr) {
// If the first use is beyond the end of this block we look at the // If the first use is beyond the end of this block we look at the
// block end. An argument buffer that is already reusable there is // block end. An argument buffer that is already reusable there is
// certainly reusable at any later actual use. // certainly reusable at any later actual use. Otherwise, lifetimes
// are incompatible.
if (liveness->isLiveOut(old_buffer_alias)) { if (liveness->isLiveOut(old_buffer_alias)) {
livetimes_compatible = false; lifetimes_compatible = false;
break; break;
} }
} else { } else {
// A buffer is *not* reusable if // A buffer is reusable if
// i) its last use is after the point of reuse, or // i) its last use is before the point of reuse, or
// ii) its last use is also its first reuse but the operation // ii) its last use is also its first reuse and the operation
// does not allow for local reuse. // allows for local reuse.
// Otherwise, lifetimes are incompatible.
Operation *last_use = Operation *last_use =
liveness->getEndOperation(old_buffer_alias, &block->front()); liveness->getEndOperation(old_buffer_alias, &block->front());
assert(last_use != nullptr && last_use->getBlock() == block && assert(last_use != nullptr && last_use->getBlock() == block &&
"Expected last use in same block."); "Expected last use in same block.");
if (first_reuse->isBeforeInBlock(last_use)) { if (first_reuse->isBeforeInBlock(last_use)) {
livetimes_compatible = false; lifetimes_compatible = false;
break; break;
} }
if (first_reuse == last_use && if (first_reuse == last_use &&
!can_reuse_locally(first_reuse, old_buffer_alias, new_buffer)) { !can_reuse_locally(first_reuse, old_buffer_alias, new_buffer)) {
livetimes_compatible = false; lifetimes_compatible = false;
break; break;
} }
} }
} }
// All criteria are fulfilled 🙂. if (lifetimes_compatible) {
if (livetimes_compatible) // All criteria are fulfilled 🙂.
int64_t old_buffer_index = old_buffer.getArgNumber();
local_reuse_candidates.push_back(old_buffer_index); local_reuse_candidates.push_back(old_buffer_index);
}
} }
reuse_candidates_[&op] = local_reuse_candidates; reuse_candidates_[&op] = local_reuse_candidates;

View File

@ -48,40 +48,6 @@ namespace kernel_gen {
namespace transforms { namespace transforms {
namespace { namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
struct HloBufferizePass : public HloBufferizePassBase<HloBufferizePass> {
// TODO(b/173201243): Move to tablegen.
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo::LmhloDialect>();
}
public:
void runOnOperation() override {
OwningRewritePatternList patterns;
auto& context = getContext();
ConversionTarget target(context);
target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addIllegalDialect<mhlo::MhloDialect>();
BufferizeTypeConverter converter;
// Configure bufferize pattern for functions and lhlo.
mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
// Configure legality and structural patterns.
populateBufferizeMaterializationLegality(target);
populateShapeStructuralTypeConversionsAndLegality(&context, converter,
patterns, target);
scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
patterns, target);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
// TODO(herhut) : This could become a real pattern in bufferize pass. What we // TODO(herhut) : This could become a real pattern in bufferize pass. What we
// would need to do is insert a copy to model the semantics correctly. The same // would need to do is insert a copy to model the semantics correctly. The same
// is true for the TensorLoad pattern that is already in there. Then buffer // is true for the TensorLoad pattern that is already in there. Then buffer
@ -104,11 +70,34 @@ class UnrankedTensorStoreTestOnlyPattern
} }
}; };
struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> { // TODO(frgossen): Move this upstream to `populateFuncOpTypeConversionPattern`
// This pattern is merely needed to materialize type casts for return values so
// that they match the function signature conversion.
class ReturnOpTypeConversionPattern
: public OpConversionPattern<mlir::ReturnOp> {
public:
using OpConversionPattern<mlir::ReturnOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
return success();
}
};
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
struct BufferizePass : public BufferizePassBase<BufferizePass> {
explicit BufferizePass(bool allow_partial_bufferization) {
allow_partial_bufferization_ = allow_partial_bufferization;
}
// TODO(b/173201243): Move to tablegen. // TODO(b/173201243): Move to tablegen.
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, scf::SCFDialect, shape::ShapeDialect, registry.insert<AffineDialect, scf::SCFDialect, shape::ShapeDialect,
tf_framework::TFFrameworkDialect>(); tf_framework::TFFrameworkDialect, lmhlo::LmhloDialect>();
} }
public: public:
@ -117,12 +106,17 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
ConversionTarget target(context); ConversionTarget target(context);
target.addLegalDialect<scf::SCFDialect, StandardOpsDialect, target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
tf_framework::TFFrameworkDialect, AffineDialect, tf_framework::TFFrameworkDialect, AffineDialect,
shape::ShapeDialect>(); shape::ShapeDialect, lmhlo::LmhloDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<mhlo::MhloDialect>(); target.addIllegalDialect<mhlo::MhloDialect>();
target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp, target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
TensorFromElementsOp, TensorCastOp, TensorLoadOp, TensorFromElementsOp, TensorCastOp>();
TensorToMemrefOp>();
if (!allow_partial_bufferization_) {
target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
}
// Certain operations are no longer legal on tensors but otherwise are. // Certain operations are no longer legal on tensors but otherwise are.
target.addDynamicallyLegalOp<ConstantOp, SelectOp>([&](Operation* op) { target.addDynamicallyLegalOp<ConstantOp, SelectOp>([&](Operation* op) {
return llvm::none_of(op->getResultTypes(), return llvm::none_of(op->getResultTypes(),
@ -144,8 +138,8 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
return converter.isLegal(inputs) && converter.isLegal(results) && return converter.isLegal(inputs) && converter.isLegal(results) &&
converter.isLegal(&op.getBody()); converter.isLegal(&op.getBody());
}); });
target.addDynamicallyLegalOp<CallOp, ConstantOp, DimOp, RankOp, SelectOp>( target.addDynamicallyLegalOp<CallOp, ConstantOp, DimOp, RankOp, SelectOp,
typesAreLegal); ReturnOp>(typesAreLegal);
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns); mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
@ -160,22 +154,27 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter, scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
patterns, target); patterns, target);
patterns.insert<UnrankedTensorStoreTestOnlyPattern>(&context); patterns.insert<UnrankedTensorStoreTestOnlyPattern>(&context);
patterns.insert<ReturnOpTypeConversionPattern>(converter, &context);
auto module = getOperation(); auto module = getOperation();
if (failed(applyFullConversion(module, target, std::move(patterns)))) { if (allow_partial_bufferization_) {
signalPassFailure(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
} else {
if (failed(
mlir::applyFullConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
} }
} }
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp> > CreateHloBufferizePass() { std::unique_ptr<OperationPass<ModuleOp> > CreateBufferizePass(
return std::make_unique<HloBufferizePass>(); bool allow_partial_bufferization) {
} return std::make_unique<BufferizePass>(allow_partial_bufferization);
std::unique_ptr<OperationPass<ModuleOp> > CreateFinalBufferizePass() {
return std::make_unique<FinalBufferizePass>();
} }
} // namespace transforms } // namespace transforms
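// A minimal sketch of the two-stage use of the unified bufferize pass above,
// mirroring the kernel generator pipeline; AddBufferizePasses is a
// hypothetical helper, the intermediate lowering passes are elided, and the
// kernel_gen header path is assumed.
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
void AddBufferizePasses(mlir::PassManager& pm) {
  // Early, partial bufferization: HLO ops are lowered to LHLO while shape
  // computations may remain on tensors.
  pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass(
      /*allow_partial_bufferization=*/true));
  // (Shape and loop lowering passes would run in between.)
  // Late, full bufferization: no tensor-typed operations may remain.
  pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass());
}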

View File

@ -37,9 +37,8 @@ namespace {
using tf_framework::TFFrameworkDialect; using tf_framework::TFFrameworkDialect;
Operation* emitCallToFunc(Location loc, StringRef func_name, Operation* emitCallToPrint(Location loc, StringRef func_name, Value arg,
ArrayRef<Type> return_types, ValueRange args, OpBuilder* b) {
OpBuilder* b) {
auto caller_func = auto caller_func =
b->getInsertionBlock()->getParent()->getParentOfType<FuncOp>(); b->getInsertionBlock()->getParent()->getParentOfType<FuncOp>();
auto callee_func = auto callee_func =
@ -49,12 +48,12 @@ Operation* emitCallToFunc(Location loc, StringRef func_name,
auto module = caller_func.getParentOfType<ModuleOp>(); auto module = caller_func.getParentOfType<ModuleOp>();
b->setInsertionPointToStart(module.getBody()); b->setInsertionPointToStart(module.getBody());
auto func_type = auto func_type = FunctionType::get(arg.getType(), /*results=*/llvm::None,
FunctionType::get(args.getTypes(), return_types, b->getContext()); b->getContext());
callee_func = b->create<FuncOp>(module.getLoc(), func_name, func_type); callee_func = b->create<FuncOp>(module.getLoc(), func_name, func_type);
callee_func.setPrivate(); callee_func.setPrivate();
} }
return b->create<CallOp>(loc, callee_func, args); return b->create<CallOp>(loc, callee_func, arg);
} }
void EmitPrint(Operation* op, Liveness& liveness, OpBuilder* b) { void EmitPrint(Operation* op, Liveness& liveness, OpBuilder* b) {
@ -70,28 +69,32 @@ void EmitPrint(Operation* op, Liveness& liveness, OpBuilder* b) {
liveness.getLiveness(op->getBlock())->getEndOperation(memref, op); liveness.getLiveness(op->getBlock())->getEndOperation(memref, op);
b->setInsertionPoint(end_op); b->setInsertionPoint(end_op);
if (element_type.isIndex()) {
element_type = b->getI64Type();
memref_type = MemRefType::get(memref_type.getShape(), element_type,
memref_type.getAffineMaps(),
memref_type.getMemorySpace());
memref = b->create<IndexCastOp>(loc, memref, memref_type);
}
auto unranked_type = auto unranked_type =
UnrankedMemRefType::get(element_type, memref_type.getMemorySpace()); UnrankedMemRefType::get(element_type, memref_type.getMemorySpace());
Value unranked_memref = b->create<MemRefCastOp>(loc, memref, unranked_type); Value unranked_memref = b->create<MemRefCastOp>(loc, memref, unranked_type);
if (element_type.isF32()) { if (element_type.isF32()) {
emitCallToFunc(loc, "print_memref_f32", {}, {unranked_memref}, b); emitCallToPrint(loc, "print_memref_f32", unranked_memref, b);
return; return;
} }
if (element_type.isF64()) { if (element_type.isF64()) {
emitCallToFunc(loc, "print_memref_f64", {}, {unranked_memref}, b); emitCallToPrint(loc, "print_memref_f64", unranked_memref, b);
return;
}
if (element_type.isIndex()) {
emitCallToFunc(loc, "print_memref_index", {}, {unranked_memref}, b);
return; return;
} }
if (element_type.isInteger(32)) { if (element_type.isInteger(32)) {
emitCallToFunc(loc, "print_memref_i32", {}, {unranked_memref}, b); emitCallToPrint(loc, "print_memref_i32", unranked_memref, b);
return; return;
} }
if (element_type.isInteger(64)) { if (element_type.isInteger(64) || element_type.isIndex()) {
emitCallToFunc(loc, "print_memref_i64", {}, {unranked_memref}, b); emitCallToPrint(loc, "print_memref_i64", unranked_memref, b);
return; return;
} }
} }

View File

@ -47,13 +47,10 @@ std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass();
// using memref descriptors. // using memref descriptors.
std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass(); std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass();
// Pass to transform hlo-level computations on values to their corresponding // Pass to transform operations on values to their corresponding parts on
// parts on buffers. // buffers.
std::unique_ptr<OperationPass<ModuleOp>> CreateHloBufferizePass(); std::unique_ptr<OperationPass<ModuleOp>> CreateBufferizePass(
bool allow_partial_bufferization = false);
// Pass to transform late-dialect level computations (essentially all non-hlo
// dialects) on values to their corresponding parts on buffers.
std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass();
// Pass to materialize broadcasts. // Pass to materialize broadcasts.
std::unique_ptr<FunctionPass> CreateMaterializeBroadcastsPass(); std::unique_ptr<FunctionPass> CreateMaterializeBroadcastsPass();

View File

@ -38,16 +38,14 @@ def ShapeToDescriptorsPass : Pass<"shape-to-descriptors", "ModuleOp"> {
let constructor = "transforms::CreateShapeToDescriptorsPass()"; let constructor = "transforms::CreateShapeToDescriptorsPass()";
} }
def HloBufferizePass : Pass<"hlo-bufferize", "ModuleOp"> { def BufferizePass : Pass<"bufferize", "ModuleOp"> {
let summary = "Pass to transform hlo operations on values to buffer based " let summary = "Pass to transform operations on values to buffer-based ones.";
"ones."; let options = [
let constructor = "transforms::CreateHloBufferizePass()"; Option<"allow_partial_bufferization_", "allow-partial-bufferization",
} "bool", /*default=*/"false", "Allow partial bufferization. "
"Value-based operations may remain, e.g. for shape operations.">,
def FinalBufferizePass : Pass<"final-bufferize", "ModuleOp"> { ];
let summary = "Pass to transform operations from all non-hlo dialects on " let constructor = "transforms::CreateBufferizePass()";
"values to buffer-based ones.";
let constructor = "transforms::CreateFinalBufferizePass()";
} }
def MaterializeBroadcastsPass : FunctionPass<"materialize-broadcast"> { def MaterializeBroadcastsPass : FunctionPass<"materialize-broadcast"> {

View File

@ -140,6 +140,7 @@ cc_library(
":translate_cl_options", ":translate_cl_options",
"//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
"//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
@ -148,6 +149,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",

View File

@ -63,8 +63,11 @@ HloModule SelectAndScatter
ROOT %add.11 = f32[] add(f32[] %lhs.9, f32[] %rhs.10) ROOT %add.11 = f32[] add(f32[] %lhs.9, f32[] %rhs.10)
} }
// CHECK-LABEL: module
// CHECK: global_memref "private" constant @[[$GLOBAL:.*]] : memref<f32> = dense<0.000000e+00>
// CHECK-LABEL: func @main // CHECK-LABEL: func @main
// CHECK: "lmhlo.select_and_scatter" // CHECK: %[[GLOBAL_MEMREF:.*]] = get_global_memref @[[$GLOBAL]] : memref<f32>
// CHECK: "lmhlo.select_and_scatter"(%{{.*}}, %{{.*}}, %[[GLOBAL_MEMREF]], %{{.*}})
// CHECK: ^bb0(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>): // CHECK: ^bb0(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>):
// CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG0]], %[[ARG1]]) {comparison_direction = "GE"} // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG0]], %[[ARG1]]) {comparison_direction = "GE"}
// CHECK: "mhlo.return"(%[[COMPARE]]) : (tensor<i1>) -> () // CHECK: "mhlo.return"(%[[COMPARE]]) : (tensor<i1>) -> ()
@ -78,7 +81,7 @@ HloModule SelectAndScatter
ENTRY main () -> f32[6] { ENTRY main () -> f32[6] {
%operand = f32[6]{0} parameter(0) %operand = f32[6]{0} parameter(0)
%source = f32[2]{0} parameter(1) %source = f32[2]{0} parameter(1)
%init = f32[] parameter(2) %init = f32[] constant(0)
ROOT %select-and-scatter.12 = f32[6]{0} select-and-scatter(f32[6]{0} %operand, f32[2]{0} %source, f32[] %init), window={size=3 stride=3}, select=%ge_F32, scatter=%add_F32 ROOT %select-and-scatter.12 = f32[6]{0} select-and-scatter(f32[6]{0} %operand, f32[2]{0} %source, f32[] %init), window={size=3 stride=3}, select=%ge_F32, scatter=%add_F32
} }
@ -101,4 +104,20 @@ ENTRY main {
s32[] %static), s32[] %static),
custom_call_target="SliceToDynamic", custom_call_target="SliceToDynamic",
backend_config="" backend_config=""
} }
// -----
HloModule Cholesky
// CHECK-LABEL: func @main
// CHECK: "lmhlo_gpu.cholesky"
// CHECK-SAME: is_lower = true
ENTRY main {
%param = f32[3,3] parameter(0)
ROOT %custom-call = (f32[3,3], f32[3], s32[]) custom-call(f32[3,3] %param),
custom_call_target="__cusolver$cholesky",
operand_layout_constraints={f32[3,3]},
backend_config="{\"lower\":true}"
}

View File

@ -374,3 +374,23 @@ func @main(%arg0: tuple<tuple<tensor<f32>>, tensor<f32>>, %arg1: tuple<tensor<f3
return %result : tuple<tensor<f32>, tensor<f32>, tensor<f32>> return %result : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
} }
// -----
// CHECK-LABEL: func @main
// CHECK: "lmhlo.reduce"({{.*}}) ( {
// CHECK: ^bb0(%[[VAL1:.*]]: tensor<f32>, %[[VAL2:.*]]: tensor<i32>, %[[VAL3:.*]]: tensor<f32>, %[[VAL4:.*]]: tensor<i32>): // no predecessors
// CHECK: %[[VAL5:.*]] = mhlo.maximum %[[VAL1]], %[[VAL3]] : tensor<f32>
// CHECK: %[[VAL6:.*]] = mhlo.maximum %[[VAL2]], %[[VAL4:.*]] : tensor<i32>
// CHECK: %[[VAL7:.*]] = "mhlo.tuple"(%[[VAL5]], %[[VAL6:.*]]) : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
// CHECK: "mhlo.return"(%[[VAL7:.*]]) : (tuple<tensor<f32>, tensor<i32>>) -> ()
// CHECK: })
func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
%result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
%fmax = "mhlo.maximum"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%imax = "mhlo.maximum"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"mhlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
}

View File

@ -13,10 +13,8 @@
// CHECK-LABEL: func @add // CHECK-LABEL: func @add
func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %[[SUM0:.*]] = mhlo.add %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: %[[SUM0:.*]] = mhlo.add %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %[[SUM1:.*]] = mhlo.add %[[SUM0]], %arg0 : tensor<2xi32> // CHECK-NEXT: return %[[SUM0]] : tensor<2xi32>
// CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> %1 = "tf.AddV2"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %1: tensor<2xi32> return %1: tensor<2xi32>
} }
@ -27,7 +25,7 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1 // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0: tensor<1x2xi32> return %0: tensor<1x2xi32>
} }
@ -37,7 +35,7 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
// CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1 // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
return %0: tensor<4x4x4x4xi32> return %0: tensor<4x4x4x4xi32>
} }
@ -55,7 +53,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
// CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor<?x?xi32> // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor<?x?xi32>
// CHECK-NEXT: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: shape.assuming_yield %[[RESULT]]
%0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0: tensor<?x?xi32> return %0: tensor<?x?xi32>
} }
@ -64,7 +62,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
func @broadcast_add_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { func @broadcast_add_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: tf.Add // CHECK: tf.Add
// CHLO: chlo.broadcast_add %arg0, %arg1 // CHLO: chlo.broadcast_add %arg0, %arg1
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32> return %0: tensor<*xi32>
} }
@ -264,6 +262,13 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1>
return %0: tensor<*xi1> return %0: tensor<*xi1>
} }
// CHECK-LABEL: func @equal_unsupported_type
func @equal_unsupported_type(%arg0: tensor<*x!tf.string>, %arg1: tensor<*x!tf.string>) -> tensor<*xi1> {
// CHECK: "tf.Equal"
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*x!tf.string>, tensor<*x!tf.string>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @notequal // CHECK-LABEL: func @notequal
func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"}


@ -26,7 +26,7 @@ func @tf_unknown_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// ----- // -----
func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32> return %0: tensor<2xi32>
} }
@ -38,7 +38,7 @@ func @tf_unknown_known_mix(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// expected-error@+1 {{'tf.OpA' op is not legalizable}} // expected-error@+1 {{'tf.OpA' op is not legalizable}}
%0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%2 = "tf.Add"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %2 = "tf.AddV2"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %2: tensor<2xi32> return %2: tensor<2xi32>
} }


@ -3984,31 +3984,35 @@ func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) {
// tf.Unpack legalization // tf.Unpack legalization
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TODO(b/156340000): Re-enable when fixed. // CHECK-LABEL: @unpack
// // C-HECK-LABEL: @unpack func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
// func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
// // C-HECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> // CHECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
// // C-HECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
// // C-HECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> // CHECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
// // C-HECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
// // C-HECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> // CHECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
// // C-HECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
// %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
// // return %[[RES1]], %[[RES2]], %[[RES3]] // return %[[RES1]], %[[RES2]], %[[RES3]]
// return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32> return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>
// } }
// // C-HECK-LABEL: @unpack_dynamic // CHECK-LABEL: @unpack_dynamic
// func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// // C-HECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32>
// // C-HECK: "mhlo.reshape"(%[[SLICE1]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32>
// // C-HECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32>
// // C-HECK: "mhlo.reshape"(%[[SLICE2]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32>
// %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) // CHECK: tf.Unpack
// return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32> %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
// } return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// CHECK-LABEL: @unpack_unranked
func @unpack_unranked(%input: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// CHECK: tf.Unpack
%0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// tf.UnsortedSegment{Max|Min|Prod|Sum} legalization // tf.UnsortedSegment{Max|Min|Prod|Sum} legalization


@ -3311,7 +3311,7 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
auto input_val = GetScalarConstOfType(begin_element_ty, loc, auto input_val = GetScalarConstOfType(begin_element_ty, loc,
input_shape[d], &rewriter); input_shape[d], &rewriter);
auto wrapped_index = auto wrapped_index =
rewriter.create<TF::AddOp>(loc, input_val, reshaped_index); rewriter.create<TF::AddV2Op>(loc, input_val, reshaped_index);
auto final_index = rewriter.create<SelectOp>( auto final_index = rewriter.create<SelectOp>(
loc, type, index_negative, wrapped_index, reshaped_index); loc, type, index_negative, wrapped_index, reshaped_index);
slice_begin_indices.push_back(final_index); slice_begin_indices.push_back(final_index);
@ -4808,7 +4808,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
LogicalResult matchAndRewrite(TF::UnpackOp op, LogicalResult matchAndRewrite(TF::UnpackOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto value_type = op.value().getType().cast<RankedTensorType>(); auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!value_type) return failure(); if (!value_type) return failure();
int64_t value_rank = value_type.getRank(); int64_t value_rank = value_type.getRank();
@ -4820,7 +4820,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
auto end_indices = llvm::to_vector<4>(value_type.getShape()); auto end_indices = llvm::to_vector<4>(value_type.getShape());
SmallVector<int64_t, 4> strides(value_rank, 1); SmallVector<int64_t, 4> strides(value_rank, 1);
// All HLO slice+reshape results used to replace the original tf.Unpack op. // All HLO slice+squeeze results used to replace the original tf.Unpack op.
SmallVector<Value, 4> results; SmallVector<Value, 4> results;
results.reserve(op.getNumResults()); results.reserve(op.getNumResults());
@ -4833,9 +4833,10 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
GetI64ElementsAttr(end_indices, &rewriter), GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter)); GetI64ElementsAttr(strides, &rewriter));
// Reshape to drop the axis dimension. // Reshape to drop the axis dimension.
auto reshape_op = rewriter.create<mhlo::ReshapeOp>( auto result =
op.getLoc(), op.getType(i), slice_op); rewriter.create<TF::SqueezeOp>(op.getLoc(), op.getType(i), slice_op,
results.push_back(reshape_op); rewriter.getI64ArrayAttr(op.axis()));
results.push_back(result);
} }
rewriter.replaceOp(op, results); rewriter.replaceOp(op, results);
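
The cast-to-dyn_cast change above is what makes the unranked case bail out gracefully instead of asserting, which is exactly what the new @unpack_unranked test checks; the idiom in isolation, assuming the MLIR headers already used by this file:

#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project

// cast<> asserts on a type mismatch; dyn_cast<> returns null, letting the
// pattern report failure so ops with unranked operands (tensor<*xf32>) are
// left for a later legalization or the fallback path.
static mlir::LogicalResult RequireRanked(mlir::Value value) {
  auto ranked = value.getType().dyn_cast<mlir::RankedTensorType>();
  if (!ranked) return mlir::failure();
  return mlir::success();
}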


@ -89,8 +89,7 @@ class DirectBinaryPat<Op FromOp, Op ToOp>
: Pat<(FromOp AnyTensor:$l, AnyTensor:$r), : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
(ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;
foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp], foreach fromToBinPair = [[TF_AddV2Op, HLOClient_BroadcastAddOp],
[TF_AddV2Op, HLOClient_BroadcastAddOp],
[TF_DivOp, HLOClient_BroadcastDivOp], [TF_DivOp, HLOClient_BroadcastDivOp],
[TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp], [TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp],
[TF_MaximumOp, HLOClient_BroadcastMaxOp], [TF_MaximumOp, HLOClient_BroadcastMaxOp],
@ -225,7 +224,7 @@ class EqualityPat<Op FromOp, StrEnumAttrCase direction>
(HLOClient_BroadcastCompareOp (HLOClient_BroadcastCompareOp
$l, $r, (BinBroadcastDimensions $l, $r), direction, $l, $r, (BinBroadcastDimensions $l, $r), direction,
(HLO_DEFAULT_COMPARISON_TYPE)), (HLO_DEFAULT_COMPARISON_TYPE)),
[(AreBroadcastCompatible $l, $r)]>; [(AreBroadcastCompatible $l, $r), (HLO_Tensor $l)]>;
def : EqualityPat<TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ>; def : EqualityPat<TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ>;
def : EqualityPat<TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE>; def : EqualityPat<TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE>;


@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include <climits>
#include <memory> #include <memory>
#include <tuple> #include <tuple>
@ -38,6 +39,7 @@ limitations under the License.
#include "mlir/Pass/PassOptions.h" // from @llvm-project #include "mlir/Pass/PassOptions.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
@ -46,18 +48,21 @@ limitations under the License.
#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
using xla::BufferAllocation; using xla::BufferAllocation;
using xla::BufferAssignment; using xla::BufferAssignment;
using xla::HloComputation; using xla::HloComputation;
using xla::HloCustomCallInstruction;
using xla::HloInstruction; using xla::HloInstruction;
using xla::HloModule; using xla::HloModule;
using xla::HloModuleProto; using xla::HloModuleProto;
@ -140,8 +145,9 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
class XlaHloToLhloPass class XlaHloToLhloPass
: public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> { : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect, registry
mlir::lmhlo::LmhloDialect>(); .insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
mlir::lmhlo::LmhloDialect, mlir::lmhlo_gpu::LmhloGpuDialect>();
} }
public: public:
@ -274,6 +280,10 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
return EmitSelectAndScatterOp(instr); return EmitSelectAndScatterOp(instr);
case HloOpcode::kCustomCall: case HloOpcode::kCustomCall:
return EmitCustomCallOp(instr); return EmitCustomCallOp(instr);
case HloOpcode::kConstant:
return EmitConstant(instr);
case HloOpcode::kReduce:
return EmitReduceOp(instr);
default: default:
llvm::errs() << instr->ToString(); llvm::errs() << instr->ToString();
return tensorflow::errors::Internal( return tensorflow::errors::Internal(
@ -485,13 +495,18 @@ StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
return select_and_scatter; return select_and_scatter;
} }
StatusOr<lmhlo::CustomCallOp> LhloDialectEmitter::EmitCustomCallOp( StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
HloInstruction* instr) { HloInstruction* instr) {
auto* custom_call_instr = ::xla::Cast<::xla::HloCustomCallInstruction>(instr);
if (xla::gpu::IsCustomCallToCusolver(*instr)) {
return EmitCholesky(custom_call_instr);
}
size_t num_arguments, num_results; size_t num_arguments, num_results;
TF_ASSIGN_OR_RETURN(auto custom_call, TF_ASSIGN_OR_RETURN(auto custom_call,
CreateOpWithoutAttrs<lmhlo::CustomCallOp>( CreateOpWithoutAttrs<lmhlo::CustomCallOp>(
instr, num_arguments, num_results)); instr, num_arguments, num_results));
auto* custom_call_instr = ::xla::Cast<::xla::HloCustomCallInstruction>(instr);
custom_call.call_target_nameAttr( custom_call.call_target_nameAttr(
builder_.getStringAttr(custom_call_instr->custom_call_target())); builder_.getStringAttr(custom_call_instr->custom_call_target()));
custom_call.backend_configAttr( custom_call.backend_configAttr(
@ -500,12 +515,102 @@ StatusOr<lmhlo::CustomCallOp> LhloDialectEmitter::EmitCustomCallOp(
static_cast<int32_t>(num_results)}; static_cast<int32_t>(num_results)};
custom_call.setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(), custom_call.setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
builder_.getI32VectorAttr(segments)); builder_.getI32VectorAttr(segments));
return custom_call; return custom_call.getOperation();
}
StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
HloCustomCallInstruction* custom_call) {
TF_ASSIGN_OR_RETURN(auto cholesky_op,
CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
custom_call->backend_config<xla::CholeskyOptions>());
cholesky_op.is_lowerAttr(builder_.getBoolAttr(options.lower()));
return cholesky_op;
}
// Convert an XLA HLO constant to a global_memref + get_global_memref pair.
StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant(
const HloInstruction* instr) {
// Insert a global_memref in the module.
Location loc = getLocation(instr);
auto const_instr = ::xla::Cast<::xla::HloConstantInstruction>(instr);
TF_RET_CHECK(const_instr->shape().IsArray() &&
const_instr->shape().is_static());
TF_ASSIGN_OR_RETURN(Type type, ::xla::ConvertShapeToType<MemRefType>(
const_instr->shape(), builder_));
auto memref_type = type.dyn_cast<MemRefType>();
TF_RET_CHECK(memref_type != nullptr);
TF_ASSIGN_OR_RETURN(
DenseElementsAttr initial_value,
CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_));
std::string constant_name = ::xla::llvm_ir::ConstantHloToGlobalName(*instr);
// Insert the global memref at the top level.
{
OpBuilder::InsertionGuard guard(builder_);
builder_.clearInsertionPoint();
auto global_var = builder_.create<GlobalMemrefOp>(
loc, constant_name, builder_.getStringAttr("private"),
TypeAttr::get(memref_type), initial_value, true);
SymbolTable(module_).insert(global_var);
global_var.getOperation()->moveBefore(&module_.front());
}
auto get_global_memref =
builder_.create<GetGlobalMemrefOp>(loc, memref_type, constant_name);
// For operations that do not fold this constant value in their codegen, we
// still need to materialize it into a buffer. Since buffer allocation is
// already done, annotate the get_global_memref with the information to get to
// the allocated buffer slice for this constant if need be.
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
assignment_.GetUniqueTopLevelSlice(instr));
get_global_memref.setAttr("lmhlo.alloc",
builder_.getIndexAttr(slice.index()));
get_global_memref.setAttr("lmhlo.slice_offset",
builder_.getI64IntegerAttr(slice.offset()));
get_global_memref.setAttr("lmhlo.slice_size",
builder_.getI64IntegerAttr(slice.size()));
// Update the cache to remember this value.
auto& cached_value = slices_[std::make_pair(instr, ::xla::ShapeIndex())];
TF_RET_CHECK(cached_value == nullptr);
cached_value = get_global_memref;
return get_global_memref;
}
StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp(
HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto reduce_op,
CreateOpWithoutAttrs<lmhlo::ReduceOp>(instr));
auto* reduce = ::xla::Cast<::xla::HloReduceInstruction>(instr);
std::vector<int64_t> dimensions(reduce->dimensions().begin(),
reduce->dimensions().end());
reduce_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions));
TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
*instr->called_computations()[0], &reduce_op.body(), &builder_));
return reduce_op;
} }
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView( StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape, const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
const ::xla::ShapeIndex& shape_index) { const ::xla::ShapeIndex& shape_index) {
// Cache generated ViewOp and StaticMemRefCastOp by (instruction,
// shape_index).
auto& cached_value = slices_[std::make_pair(instr, shape_index)];
if (cached_value) {
return cached_value;
}
if (instr->IsConstant() && shape_index.empty()) {
TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr));
return cached_value = constant_memref;
}
// If the shape happens to have dynamic dimensions, create the memref using // If the shape happens to have dynamic dimensions, create the memref using
// the underlying static shape. // the underlying static shape.
// TODO(jurahul): Revisit this when we can model memrefs with dynamic shape // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape
@ -518,7 +623,7 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
assignment_.GetUniqueSlice(instr, shape_index)); assignment_.GetUniqueSlice(instr, shape_index));
Value alloc = allocations_[slice.allocation()]; Value alloc = allocations_[slice.allocation()];
if (alloc.getType() == out_type && slice.offset() == 0) { if (alloc.getType() == out_type && slice.offset() == 0) {
return alloc; return cached_value = alloc;
} }
auto out_memref_type = out_type.dyn_cast<MemRefType>(); auto out_memref_type = out_type.dyn_cast<MemRefType>();
@ -527,13 +632,6 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
"Expected memref type when creating a view for leaf type of a " "Expected memref type when creating a view for leaf type of a "
"tuple."); "tuple.");
// Cache generated ViewOp and StaticMemRefCastOp by (instruction,
// shape_index).
auto& cached_value = slices_[std::make_pair(instr, shape_index)];
if (cached_value) {
return cached_value;
}
Value byte_shift = Value byte_shift =
builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset()); builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset());
@ -695,8 +793,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
Status HloToLhloModule(const BufferAssignment& assignment, Status HloToLhloModule(const BufferAssignment& assignment,
const HloModule& hlo_module, ModuleOp module) { const HloModule& hlo_module, ModuleOp module) {
module.getContext() module.getContext()
->loadDialect<StandardOpsDialect, mhlo::MhloDialect, ->loadDialect<StandardOpsDialect, mhlo::MhloDialect, lmhlo::LmhloDialect,
lmhlo::LmhloDialect>(); lmhlo_gpu::LmhloGpuDialect>();
HloComputation* computation = hlo_module.entry_computation(); HloComputation* computation = hlo_module.entry_computation();
LhloDialectEmitter emitter(assignment, *computation, module); LhloDialectEmitter emitter(assignment, *computation, module);
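
Taking a reference into slices_ up front is what lets every early return in GetOrCreateArrayView (constant, whole-allocation alias, or freshly built view) fill the cache through the same slot; the idiom in isolation, with placeholder key and value types:

#include <map>
#include <string>

// Sketch only: the reference-into-the-map idiom used by GetOrCreateArrayView.
static int GetOrCreateSketch(std::map<std::string, int>& cache,
                             const std::string& key) {
  int& slot = cache[key];      // default-constructs the entry if missing
  if (slot != 0) return slot;  // cache hit
  slot = static_cast<int>(key.size()) + 1;  // stands in for building the view
  return slot;                 // every path returns through the cached slot
}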


@ -16,13 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
namespace mlir { namespace mlir {
@ -55,8 +57,18 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(::xla::HloInstruction* instr); ::xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp( ::xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
::xla::HloInstruction* instr); ::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::CustomCallOp> EmitCustomCallOp(
::xla::HloInstruction* instr); ::xla::StatusOr<Operation*> EmitCustomCallOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
::xla::HloCustomCallInstruction* custom_call);
::xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(::xla::HloInstruction* instr);
::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
::xla::HloInstruction* instr) {
return EmitConstant(static_cast<const ::xla::HloInstruction*>(instr));
}
::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
const ::xla::HloInstruction* instr);
::xla::Status CreateOperands(::xla::HloInstruction* instr, ::xla::Status CreateOperands(::xla::HloInstruction* instr,
SmallVectorImpl<Value>& operands, SmallVectorImpl<Value>& operands,
@ -122,7 +134,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
OpBuilder* b, Location loc); OpBuilder* b, Location loc);
// Return an MLIR location for an HLO instruction. // Return an MLIR location for an HLO instruction.
Location getLocation(::xla::HloInstruction* inst) { Location getLocation(const ::xla::HloInstruction* inst) {
return NameLoc::get(builder_.getIdentifier(inst->name()), return NameLoc::get(builder_.getIdentifier(inst->name()),
builder_.getContext()); builder_.getContext());
} }
@ -180,7 +192,7 @@ tensorflow::Status HloToLhloModule(const ::xla::BufferAssignment& assignment,
ModuleOp module); ModuleOp module);
OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input, OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
mlir::MLIRContext* context); MLIRContext* context);
} // namespace mlir } // namespace mlir


@ -85,6 +85,9 @@ namespace tensorflow {
namespace tensorrt { namespace tensorrt {
namespace convert { namespace convert {
using absl::StrAppend;
using absl::StrCat;
bool IsEngineInput(absl::string_view name) { bool IsEngineInput(absl::string_view name) {
return absl::StartsWith(name, IONamePrefixes::kInputPHName); return absl::StartsWith(name, IONamePrefixes::kInputPHName);
} }
@ -92,47 +95,6 @@ bool IsEngineOutput(absl::string_view name) {
return absl::StartsWith(name, IONamePrefixes::kOutputPHName); return absl::StartsWith(name, IONamePrefixes::kOutputPHName);
} }
using absl::StrAppend;
using absl::StrCat;
inline Status TfDataTypeToTrt(DataType tf_dtype,
nvinfer1::DataType* trt_dtype) {
switch (tf_dtype) {
case DataType::DT_FLOAT:
*trt_dtype = nvinfer1::DataType::kFLOAT;
break;
case DataType::DT_HALF:
*trt_dtype = nvinfer1::DataType::kHALF;
break;
case DataType::DT_INT32:
*trt_dtype = nvinfer1::DataType::kINT32;
break;
default:
return errors::InvalidArgument("Unsupported data type ",
DataTypeString(tf_dtype));
}
return Status::OK();
}
inline Status TrtDataTypeToTf(nvinfer1::DataType trt_dtype,
DataType* tf_dtype) {
switch (trt_dtype) {
case nvinfer1::DataType::kFLOAT:
*tf_dtype = DataType::DT_FLOAT;
break;
case nvinfer1::DataType::kHALF:
*tf_dtype = DataType::DT_HALF;
break;
case nvinfer1::DataType::kINT32:
*tf_dtype = DataType::DT_INT32;
break;
default:
return errors::InvalidArgument("Unsupported data type ",
DebugString(trt_dtype));
}
return Status::OK();
}
class TFAttrs { class TFAttrs {
public: public:
explicit TFAttrs(const NodeDef& tf_node) { explicit TFAttrs(const NodeDef& tf_node) {
@ -182,7 +144,7 @@ std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
template <> template <>
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const { nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
TF_CHECK_OK(TfDataTypeToTrt(this->at(key)->type(), &trt_dtype)); TF_CHECK_OK(TfTypeToTrtType(this->at(key)->type(), &trt_dtype));
return trt_dtype; return trt_dtype;
} }
@ -271,7 +233,7 @@ Status ValidateTensorProperties(const string& producer_node_type,
nvinfer1::DataType* trt_dtype, nvinfer1::DataType* trt_dtype,
nvinfer1::Dims* trt_dims, int* batch_size) { nvinfer1::Dims* trt_dims, int* batch_size) {
// Convert data type. // Convert data type.
TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, trt_dtype)); TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, trt_dtype));
// Convert shape. // Convert shape.
if (shape.dims() < 0) { if (shape.dims() < 0) {
@ -512,7 +474,7 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
TFAttrs attrs(params->node_def); TFAttrs attrs(params->node_def);
if (attrs.count(dtype_attr_name)) { if (attrs.count(dtype_attr_name)) {
DataType dtype = attrs.get<DataType>(dtype_attr_name); DataType dtype = attrs.get<DataType>(dtype_attr_name);
TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_type)); TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, &trt_type));
} }
// In order to be broadcastable, the number of dims has to match. // In order to be broadcastable, the number of dims has to match.
@ -1091,7 +1053,7 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype,
DataType tf_dtype; DataType tf_dtype;
// TODO(laigd): make it return a status. // TODO(laigd): make it return a status.
TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape)); TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape));
TF_CHECK_OK(TrtDataTypeToTf(trt_dtype, &tf_dtype)); TF_CHECK_OK(TrtTypeToTfType(trt_dtype, &tf_dtype));
// TODO(jie): check weights size_bytes. 0 means type error // TODO(jie): check weights size_bytes. 0 means type error
Tensor tensor(tf_dtype, shape); Tensor tensor(tf_dtype, shape);
TRT_ShapedWeights weights(trt_dtype, dims, tensor); TRT_ShapedWeights weights(trt_dtype, dims, tensor);
@ -2621,6 +2583,7 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input,
nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation( nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation(
const_cast<nvinfer1::ITensor* const*>(concat_inputs.data()), const_cast<nvinfer1::ITensor* const*>(concat_inputs.data()),
concat_inputs.size()); concat_inputs.size());
SetLayerName(concat_layer, params->node_def, "concat", op_instance);
concat_layer->setAxis(0); concat_layer->setAxis(0);
nvinfer1::ITensor* new_shape = concat_layer->getOutput(0); nvinfer1::ITensor* new_shape = concat_layer->getOutput(0);
// Reshape input using new shape // Reshape input using new shape
@ -4219,7 +4182,7 @@ Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
// Verify that the dtype is supported by TensorRT. Otherwise, return an error. // Verify that the dtype is supported by TensorRT. Otherwise, return an error.
nvinfer1::DataType trt_dtype; nvinfer1::DataType trt_dtype;
TF_RETURN_IF_ERROR(TfDataTypeToTrt(converted_dtype, &trt_dtype)); TF_RETURN_IF_ERROR(TfTypeToTrtType(converted_dtype, &trt_dtype));
if (tensor.NumElements() == 0) { if (tensor.NumElements() == 0) {
// Return empty weights. // Return empty weights.
@ -6244,7 +6207,7 @@ Status ConvertGraphDefToEngine(
TFAttrs attrs(node_def); TFAttrs attrs(node_def);
DataType tf_dtype = attrs.get<DataType>("T"); DataType tf_dtype = attrs.get<DataType>("T");
nvinfer1::DataType trt_dtype; nvinfer1::DataType trt_dtype;
TF_RETURN_IF_ERROR(TfDataTypeToTrt(tf_dtype, &trt_dtype)); TF_RETURN_IF_ERROR(TfTypeToTrtType(tf_dtype, &trt_dtype));
if (output_tensors.size() <= slot_number) { if (output_tensors.size() <= slot_number) {
output_tensors.resize(slot_number + 1); output_tensors.resize(slot_number + 1);
} }
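
Every conversion above now goes through the shared TfTypeToTrtType/TrtTypeToTfType helpers instead of the deleted local wrappers; a minimal usage sketch, assuming the same headers and namespaces as this file (DT_FLOAT is only an example input):

// Round-trips a TensorFlow dtype through the TensorRT enum; unsupported
// dtypes now surface as InvalidArgument errors from the shared helpers.
Status DtypeRoundTripSketch() {
  nvinfer1::DataType trt_type;
  TF_RETURN_IF_ERROR(TfTypeToTrtType(DataType::DT_FLOAT, &trt_type));
  DataType tf_type;
  TF_RETURN_IF_ERROR(TrtTypeToTfType(trt_type, &tf_type));  // DT_FLOAT again
  return Status::OK();
}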


@ -135,20 +135,6 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
return os; return os;
} }
nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) {
nvinfer1::DataType trt_type;
Status status = TfTypeToTrtType(tf_type, &trt_type);
EXPECT_EQ(status, Status::OK());
return trt_type;
}
DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) {
DataType tf_type;
Status status = TrtTypeToTfType(trt_type, &tf_type);
EXPECT_EQ(status, Status::OK());
return tf_type;
}
NodeDef MakeNodeDef(const string& name, const string& op, NodeDef MakeNodeDef(const string& name, const string& op,
const std::vector<string>& inputs, const std::vector<string>& inputs,
const std::map<string, AttrValue> attrs = {}) { const std::map<string, AttrValue> attrs = {}) {
@ -1048,8 +1034,10 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) {
template <typename T> template <typename T>
void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) { void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) {
TRT_ShapedWeights weights = weight_store->GetTempWeights( nvinfer1::DataType trt_type;
TfDataTypeToTrt(DataTypeToEnum<T>::v()), GetTestDims({2, 3})); TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &trt_type));
TRT_ShapedWeights weights =
weight_store->GetTempWeights(trt_type, GetTestDims({2, 3}));
const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)}; const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)};
memcpy(weights.GetValues(), values.data(), weights.size_bytes()); memcpy(weights.GetValues(), values.data(), weights.size_bytes());
@ -1445,7 +1433,8 @@ class OpConverterTest : public ::testing::Test {
ASSERT_NE(-1, input_index); ASSERT_NE(-1, input_index);
const nvinfer1::DataType trt_dtype = const nvinfer1::DataType trt_dtype =
engine_->getBindingDataType(input_index); engine_->getBindingDataType(input_index);
const DataType tf_type = TrtDataTypeToTf(trt_dtype); DataType tf_type;
TF_ASSERT_OK(TrtTypeToTfType(trt_dtype, &tf_type));
ASSERT_EQ(data.tensor.dtype(), tf_type) ASSERT_EQ(data.tensor.dtype(), tf_type)
<< DataTypeString(data.tensor.dtype()) << " vs. " << DataTypeString(data.tensor.dtype()) << " vs. "
<< DataTypeString(tf_type); << DataTypeString(tf_type);
@ -1457,8 +1446,9 @@ class OpConverterTest : public ::testing::Test {
// Mark the output tensor as TRT engine output. // Mark the output tensor as TRT engine output.
std::vector<Converter::EngineOutputInfo> output_info; std::vector<Converter::EngineOutputInfo> output_info;
for (const auto& data : *output_data) { for (const auto& data : *output_data) {
output_info.push_back( nvinfer1::DataType trt_type;
{data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())}); TF_RETURN_IF_ERROR(TfTypeToTrtType(data.tensor.dtype(), &trt_type));
output_info.push_back({data.name, data.name, trt_type});
} }
TF_RETURN_IF_ERROR(converter_->RenameAndMarkOutputTensors(output_info)); TF_RETURN_IF_ERROR(converter_->RenameAndMarkOutputTensors(output_info));
@ -1519,7 +1509,8 @@ class OpConverterTest : public ::testing::Test {
const string& name, const std::vector<int32>& dims, const string& name, const std::vector<int32>& dims,
nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT, nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT,
Status add_input_status = Status::OK()) { Status add_input_status = Status::OK()) {
DataType tf_type = TrtDataTypeToTf(trt_type); DataType tf_type;
TF_ASSERT_OK(TrtTypeToTfType(trt_type, &tf_type));
ops::Placeholder::Attrs attrs; ops::Placeholder::Attrs attrs;
TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_)); TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_));
@ -1566,7 +1557,8 @@ class OpConverterTest : public ::testing::Test {
node_inputs_[name] = ops::Const(scope_.WithOpName(name), t); node_inputs_[name] = ops::Const(scope_.WithOpName(name), t);
// Add weights for conversion. // Add weights for conversion.
const nvinfer1::DataType dtype = TfDataTypeToTrt(DataTypeToEnum<T>::v()); nvinfer1::DataType dtype;
TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &dtype));
const nvinfer1::Dims trt_dims = GetTestDims(dims); const nvinfer1::Dims trt_dims = GetTestDims(dims);
const int64_t num_elements = TrtWeightDimsNumElements(trt_dims); const int64_t num_elements = TrtWeightDimsNumElements(trt_dims);
QCHECK_EQ(num_elements, values.size()) QCHECK_EQ(num_elements, values.size())
@ -1800,8 +1792,9 @@ class ParameterizedOpConverterTestBase
partial_shape = dims; partial_shape = dims;
} }
} }
AddTestTensorWithTFDims(name, partial_shape, TfDataTypeToTrt(tf_type), nvinfer1::DataType trt_type;
add_input_status); TF_ASSERT_OK(TfTypeToTrtType(tf_type, &trt_type));
AddTestTensorWithTFDims(name, partial_shape, trt_type, add_input_status);
if (!values.empty()) { if (!values.empty()) {
VLOG(2) << "Adding test tensor: " << name << " " VLOG(2) << "Adding test tensor: " << name << " "
<< DataTypeString(tf_type); << DataTypeString(tf_type);
@ -2032,7 +2025,7 @@ TEST_F(OpConverterTest, ConvertConst) {
Reset(); Reset();
NodeDef node_def = MakeConstNodeDef<double>("my_const", {}); NodeDef node_def = MakeConstNodeDef<double>("my_const", {});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Unsupported data type double"); "Unsupported tensorflow data type double");
} }
{ {
Reset(); Reset();
@ -2805,8 +2798,9 @@ void TestAddN(OpConverterTest* test) {
test->Reset(); test->Reset();
DataVec input_data; DataVec input_data;
for (const auto name : {"inp1", "inp2", "inp3"}) { for (const auto name : {"inp1", "inp2", "inp3"}) {
test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/2, nvinfer1::DataType trt_type;
TfDataTypeToTrt(dtype)); TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/2, trt_type);
input_data.push_back({name, test->AsTensor<CType>({CType(1), CType(2), input_data.push_back({name, test->AsTensor<CType>({CType(1), CType(2),
CType(3), CType(4)})}); CType(3), CType(4)})});
} }
@ -2828,8 +2822,9 @@ void TestAddN(OpConverterTest* test) {
test->Reset(); test->Reset();
DataVec input_data; DataVec input_data;
for (const auto name : {"inp1", "inp2"}) { for (const auto name : {"inp1", "inp2"}) {
test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/1, nvinfer1::DataType trt_type;
TfDataTypeToTrt(dtype)); TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/1, trt_type);
input_data.push_back({name, test->AsTensor<CType>({CType(1), CType(2)})}); input_data.push_back({name, test->AsTensor<CType>({CType(1), CType(2)})});
} }
test->AddTestWeights("inp3", /*dims=*/{1, 1, 2}, test->AddTestWeights("inp3", /*dims=*/{1, 1, 2},
@ -4252,8 +4247,9 @@ TEST_P(OpConverterTest1, ConvertConv2D) {
Reset(); Reset();
NodeDef node_def = get_conv2d_nodedef(); NodeDef node_def = get_conv2d_nodedef();
// Channel dim unknown, should fail. // Channel dim unknown, should fail.
AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, nvinfer1::DataType trt_type;
TfDataTypeToTrt(tf_type_)); TF_ASSERT_OK(TfTypeToTrtType(tf_type_, &trt_type));
AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, trt_type);
AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1}); AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1});
RunValidationAndConversion( RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT, node_def, error::INVALID_ARGUMENT,
@ -5018,8 +5014,9 @@ TEST_F(OpConverterTest, ConvertTopK) {
{ {
// K is a tensor, should fail. // K is a tensor, should fail.
Reset(); Reset();
AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1, nvinfer1::DataType trt_type;
/*trt_dtype=*/TfDataTypeToTrt(dtype)); TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1, trt_type);
AddTestTensor("weights", {2}); AddTestTensor("weights", {2});
RunValidationAndConversion( RunValidationAndConversion(
node_def, error::UNIMPLEMENTED, node_def, error::UNIMPLEMENTED,
@ -5590,8 +5587,10 @@ void TestConvertConcat(OpConverterTest* test) {
NodeDef node_def = get_concat_nodedef(dtype, num_inputs); NodeDef node_def = get_concat_nodedef(dtype, num_inputs);
// Create inputs. // Create inputs.
for (int j = 0; j < num_inputs; ++j) { for (int j = 0; j < num_inputs; ++j) {
nvinfer1::DataType trt_type;
TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
test->AddTestTensor(StrCat("values_", j), ok_params[i].input_shapes[j], 1, test->AddTestTensor(StrCat("values_", j), ok_params[i].input_shapes[j], 1,
TfDataTypeToTrt(dtype)); trt_type);
} }
test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis}); test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
test->RunValidationAndConversion(node_def); test->RunValidationAndConversion(node_def);
@ -5752,8 +5751,9 @@ void TestConvertSplit(OpConverterTest* test) {
NodeDef node_def = get_split_nodedef(dtype, ok_params[i].num_split); NodeDef node_def = get_split_nodedef(dtype, ok_params[i].num_split);
// Create inputs. // Create inputs.
test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis}); test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
test->AddTestTensor("value", ok_params[i].input_shape, 1, nvinfer1::DataType trt_type;
TfDataTypeToTrt(dtype)); TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
test->AddTestTensor("value", ok_params[i].input_shape, 1, trt_type);
// Convert. // Convert.
test->RunValidationAndConversion(node_def); test->RunValidationAndConversion(node_def);
@ -5929,8 +5929,9 @@ void TestConvertUnpack(OpConverterTest* test) {
NodeDef node_def = NodeDef node_def =
get_unpack_nodedef(dtype, ok_params[i].num, ok_params[i].axis); get_unpack_nodedef(dtype, ok_params[i].num, ok_params[i].axis);
// Create inputs. // Create inputs.
test->AddTestTensor("value", ok_params[i].input_shape, 1, nvinfer1::DataType trt_type;
TfDataTypeToTrt(dtype)); TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
test->AddTestTensor("value", ok_params[i].input_shape, 1, trt_type);
// Convert. // Convert.
test->RunValidationAndConversion(node_def); test->RunValidationAndConversion(node_def);
@ -6272,8 +6273,10 @@ void TestConvertArgMinMax(OpConverterTest* test) {
NodeDef node_def = GetArgMinMaxNodeDef<OpType>(dtype, DT_INT32); NodeDef node_def = GetArgMinMaxNodeDef<OpType>(dtype, DT_INT32);
// Create inputs. // Create inputs.
nvinfer1::DataType trt_type;
TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
test->AddTestTensor("input", params[i].input_shape, /*batch_size=*/1, test->AddTestTensor("input", params[i].input_shape, /*batch_size=*/1,
/*trt_dtype=*/TfDataTypeToTrt(dtype)); /*trt_dtype=*/trt_type);
test->AddTestWeights<int32>("dimension", {1}, {params[i].axis}); test->AddTestWeights<int32>("dimension", {1}, {params[i].axis});
test->RunValidationAndConversion(node_def); test->RunValidationAndConversion(node_def);
@ -6374,8 +6377,9 @@ void TestConvertDepthSpaceShuffle(
NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>( NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>(
dtype, params[i].block_size, params[i].data_format); dtype, params[i].block_size, params[i].data_format);
test->AddTestTensor("input", params[i].input_dims, 1, nvinfer1::DataType trt_type;
TfDataTypeToTrt(dtype)); TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
test->AddTestTensor("input", params[i].input_dims, 1, trt_type);
test->RunValidationAndConversion(node_def); test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output; TRT_TensorOrWeights output;
@ -6648,7 +6652,9 @@ void TestConvertClipByValue(OpConverterTest* test) {
test->Reset(); test->Reset();
NodeDef node_def = GetClipByValueNodeDef(dtype); NodeDef node_def = GetClipByValueNodeDef(dtype);
test->AddTestTensor("t", params[i].dims, 1, TfDataTypeToTrt(dtype)); nvinfer1::DataType trt_type;
TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
test->AddTestTensor("t", params[i].dims, 1, trt_type);
test->AddTestWeights<CType>("clip_value_min", {1}, test->AddTestWeights<CType>("clip_value_min", {1},
{params[i].clip_value_min}); {params[i].clip_value_min});
test->AddTestWeights<CType>("clip_value_max", {1}, test->AddTestWeights<CType>("clip_value_max", {1},
@ -6848,9 +6854,11 @@ void TestConvertResize(OpConverterTest* test) {
// Create resize node. // Create resize node.
NodeDef node_def = NodeDef node_def =
MakeResizeNodeDef<OpType>("my_resize", dtype, params[i].align_corners); MakeResizeNodeDef<OpType>("my_resize", dtype, params[i].align_corners);
nvinfer1::DataType trt_type;
TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
// Create input tensor // Create input tensor
test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1, test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1,
/*trt_dtype=*/TfDataTypeToTrt(dtype)); /*trt_dtype=*/trt_type);
// Create output size. // Create output size.
test->AddTestWeights<int32>("size", {2}, params[i].output_resize_dims); test->AddTestWeights<int32>("size", {2}, params[i].output_resize_dims);
@ -6949,8 +6957,10 @@ void TestConvertPad(OpConverterTest* test) {
// Create pad node. // Create pad node.
NodeDef node_def = MakePadNodeDef("my_pad", dtype); NodeDef node_def = MakePadNodeDef("my_pad", dtype);
// Create input tensor // Create input tensor
nvinfer1::DataType trt_type;
TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1, test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1,
/*trt_dtype=*/TfDataTypeToTrt(dtype)); /*trt_dtype=*/trt_type);
// Create output size. // Create output size.
test->AddTestWeights<int32>("padding", params[i].pad_dims, test->AddTestWeights<int32>("padding", params[i].pad_dims,
{0, 0, 1, 0, 0, 1, 0, 0}); {0, 0, 1, 0, 0, 1, 0, 0});

View File

@ -198,7 +198,8 @@ Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
*trt_type = nvinfer1::DataType::kINT32; *trt_type = nvinfer1::DataType::kINT32;
break; break;
default: default:
return errors::Internal("Unsupported tensorflow type"); return errors::InvalidArgument("Unsupported tensorflow data type ",
DataTypeString(tf_type));
} }
return Status::OK(); return Status::OK();
} }
@ -215,7 +216,7 @@ Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
*tf_type = DT_INT32; *tf_type = DT_INT32;
break; break;
default: default:
return errors::Internal("Invalid TRT type"); return errors::InvalidArgument("Invalid TRT data type");
} }
return Status::OK(); return Status::OK();
} }


@ -103,13 +103,16 @@ class TRTEngineOp : public AsyncOpKernel {
TRTEngineCacheResource* cache_res, TRTEngineCacheResource* cache_res,
AsyncHelper* helper); AsyncHelper* helper);
// Construct a function handle for executing native funcdef graph // Constructs a function handle for the segment of the TRTEngineOp.
// These are the exact same function. StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle(
FunctionLibraryRuntime* lib, const string& device_name,
bool allow_soft_placement = false, size_t num_inputs = 0,
size_t num_outputs = 0);
Status ConstructFunctionHandle(FunctionLibraryRuntime* lib, // Imports the GraphDef for the segment of the TRTEngineOp to
const string& device_name, // segment_graph_def_.
bool allow_soft_placement = false, Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
size_t num_inputs = 0, size_t num_outputs = 0); const string& device_name);
// Executes replaced native segment as function Op. // Executes replaced native segment as function Op.
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
@ -175,12 +178,16 @@ class TRTEngineOp : public AsyncOpKernel {
// Whether to build TensorRT engines at runtime. // Whether to build TensorRT engines at runtime.
bool allow_build_at_runtime_; bool allow_build_at_runtime_;
// Whether to allow soft placement when the graph is executed with native
// TensorFlow.
bool allow_soft_placement_;
// Maximum number of cached engines. // Maximum number of cached engines.
int max_cached_engines_; int max_cached_engines_;
int64 workspace_size_; int64 workspace_size_;
mutex engine_mutex_; mutex engine_mutex_;
FunctionLibraryRuntime::Handle func_handle_; FunctionLibraryRuntime::Handle native_execution_func_handle_;
// The finalized calibrator for inference. // The finalized calibrator for inference.
std::unique_ptr<TRTInt8Calibrator> calibrator_; std::unique_ptr<TRTInt8Calibrator> calibrator_;
@ -260,11 +267,9 @@ static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
return Status::OK(); return Status::OK();
} }
Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib, StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle(
const string& device_name, FunctionLibraryRuntime* lib, const string& device_name,
bool allow_soft_placement, bool allow_soft_placement, size_t num_inputs, size_t num_outputs) {
size_t num_inputs,
size_t num_outputs) {
VLOG(1) << "Constructing function handle"; VLOG(1) << "Constructing function handle";
if (lib == nullptr) { if (lib == nullptr) {
return errors::Internal("Context function library is null"); return errors::Internal("Context function library is null");
@ -298,8 +303,20 @@ Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
inst_ops.config_proto.set_allow_soft_placement(true); inst_ops.config_proto.set_allow_soft_placement(true);
} }
} }
return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops, FunctionLibraryRuntime::Handle func_handle;
&func_handle_); Status status = lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
inst_ops, &func_handle);
if (status.ok()) {
return func_handle;
}
return status;
}
Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
const string& device_name) {
TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle,
ConstructFunctionHandle(lib, device_name));
return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_);
} }
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
@ -335,14 +352,21 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
<< context->device()->name() << context->device()->name()
<< ", thus setting _allow_build_at_runtime=true"; << ", thus setting _allow_build_at_runtime=true";
allow_build_at_runtime_ = true; allow_build_at_runtime_ = true;
} else {
OP_REQUIRES_OK(context, status);
} }
func_handle_ = kInvalidHandle;
status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_);
if (status.code() == tensorflow::error::NOT_FOUND) {
allow_soft_placement_ = true;
} else {
OP_REQUIRES_OK(context, status);
}
native_execution_func_handle_ = kInvalidHandle;
if (!static_engine_) { if (!static_engine_) {
FunctionLibraryRuntime* lib = context->function_library(); OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(),
OP_REQUIRES_OK(context, context->device()->name()));
ConstructFunctionHandle(lib, context->device()->name()));
OP_REQUIRES_OK(
context, FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_));
} }
// TODO(laigd): calibration_data is used in TF v1.x and we keep it only for // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
// backward compatibility reasons. Remove it once all known users switch to // backward compatibility reasons. Remove it once all known users switch to
@ -411,13 +435,13 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
AsyncHelper* helper) { AsyncHelper* helper) {
std::vector<Tensor> inputs; std::vector<Tensor> inputs;
std::vector<Tensor>* outputs = new std::vector<Tensor>(); std::vector<Tensor>* outputs = new std::vector<Tensor>();
if (func_handle_ == kInvalidHandle) { if (native_execution_func_handle_ == kInvalidHandle) {
OP_REQUIRES_OK_ASYNC( StatusOr<FunctionLibraryRuntime::Handle> status_or_handle =
ctx,
ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(), ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
/*allow_soft_placement=*/true, allow_soft_placement_, ctx->num_inputs(),
ctx->num_inputs(), ctx->num_outputs()), ctx->num_outputs());
*helper); OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), *helper);
native_execution_func_handle_ = status_or_handle.ValueOrDie();
} }
auto lib = ctx->function_library(); auto lib = ctx->function_library();
FunctionLibraryRuntime::Options opts; FunctionLibraryRuntime::Options opts;
@ -430,7 +454,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
} }
helper->Ref(); // Increment count for calculating native graph helper->Ref(); // Increment count for calculating native graph
VLOG(1) << "Executing native segment: " << name(); VLOG(1) << "Executing native segment: " << name();
lib->Run(opts, func_handle_, inputs, outputs, lib->Run(opts, native_execution_func_handle_, inputs, outputs,
[this, ctx, outputs, helper](const Status& s) { [this, ctx, outputs, helper](const Status& s) {
core::ScopedUnref sc(helper); core::ScopedUnref sc(helper);
OP_REQUIRES_OK_ASYNC(ctx, s, *helper); OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
@ -854,12 +878,8 @@ StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
return std::pair<EngineContext*, int>(&empty_context, 0); return std::pair<EngineContext*, int>(&empty_context, 0);
} }
if (segment_graph_def_.node().empty()) { if (segment_graph_def_.node().empty()) {
FunctionLibraryRuntime* lib = ctx->function_library(); Status status = ImportSegmentGraphDef(ctx->function_library(),
auto status = ConstructFunctionHandle(lib, ctx->device()->name()); ctx->device()->name());
if (status.ok()) {
status =
FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_);
}
if (!status.ok()) { if (!status.ok()) {
LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for " LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for "
<< name() << " failed. " << name() << " failed. "

View File

@ -91,6 +91,13 @@ class TRTEngineOpTestBase : public OpsTestBase {
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device)); OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
NameAttrList function; NameAttrList function;
function.set_name(StrCat(op_name, "_native_segment")); function.set_name(StrCat(op_name, "_native_segment"));
// We disable allow_soft_placement when executing the native segment of the
// TRTEngineOp for the following reasons:
// - OpsTestBase allows only one device in the device manager.
// - We need to define a GPU device to test TRTEngineOp.
// - When allow_soft_placement is true, the TensorFlow runtime produces an
//   error if a CPU device is not defined
//   (see ProcessFunctionLibraryRuntime::InstantiateMultiDevice).
TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp") TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp")
.Input(FakeInput(1, dtype)) .Input(FakeInput(1, dtype))
.Attr("input_shapes", {shape}) .Attr("input_shapes", {shape})
@ -105,6 +112,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
.Attr("use_calibration", false) .Attr("use_calibration", false)
.Attr("_use_implicit_batch", use_implicit_batch) .Attr("_use_implicit_batch", use_implicit_batch)
.Attr("_allow_build_at_runtime", allow_build_at_runtime) .Attr("_allow_build_at_runtime", allow_build_at_runtime)
.Attr("_allow_soft_placement", false)
.Attr("OutT", {dtype}) .Attr("OutT", {dtype})
.Finalize(OpsTestBase::node_def())); .Finalize(OpsTestBase::node_def()));
TF_ASSERT_OK(InitOpWithFunctionLibrary()); TF_ASSERT_OK(InitOpWithFunctionLibrary());

View File

@ -52,7 +52,11 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
bool IsEnabled(const ConfigProto& config_proto, bool IsEnabled(const ConfigProto& config_proto,
const Graph& graph) const override { const Graph& graph) const override {
return IsMlirBridgePassEnabled(graph, config_proto); // Do not run the bridge when it is merely enabled by graph analysis;
// run it only when the user has enabled it explicitly.
MlirBridgeRolloutPolicy policy =
GetMlirBridgeRolloutPolicy(graph, config_proto);
return policy == MlirBridgeRolloutPolicy::kEnabledByUser;
} }
// This should be used as a thin mapper around mlir::ModulePass::runOnModule // This should be used as a thin mapper around mlir::ModulePass::runOnModule

View File

@ -361,6 +361,7 @@ tf_cc_test(
":util", ":util",
":xla_data_proto_cc", ":xla_data_proto_cc",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],

View File

@ -234,6 +234,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",

View File

@ -59,7 +59,6 @@ cc_library(
srcs = ["comparators.cc"], srcs = ["comparators.cc"],
hdrs = [ hdrs = [
"comparators.h", "comparators.h",
"//tensorflow/compiler/xla:literal_util",
], ],
deps = [ deps = [
":constants", ":constants",

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -46,6 +47,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla { namespace xla {
@ -1266,7 +1268,8 @@ StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
} }
XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs, XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
@ -1278,15 +1281,17 @@ XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
}); });
} }
XlaOp XlaBuilder::DotGeneral(XlaOp lhs, XlaOp rhs, XlaOp XlaBuilder::DotGeneral(
const DotDimensionNumbers& dimension_numbers, XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
TF_ASSIGN_OR_RETURN(Shape shape, TF_ASSIGN_OR_RETURN(
ShapeInference::InferDotOpShape(*lhs_shape, *rhs_shape, Shape shape,
dimension_numbers)); ShapeInference::InferDotOpShape(
*lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
return DotGeneralInternal(shape, lhs, rhs, dimension_numbers, return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
precision_config); precision_config);
}); });
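The new preferred_element_type argument threads through to ShapeInference::InferDotOpShape, so a caller can request a wider result type than the operands alone would produce. A minimal usage sketch against the client API follows; the builder name, shapes, and types are illustrative, and it mirrors the DotWithPreferredElementType test added further below.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaComputation BuildWideningDot() {
  xla::XlaBuilder b("widening_dot");
  auto lhs =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::U8, {2, 3}), "lhs");
  auto rhs =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::U16, {3, 2}), "rhs");
  xla::DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  // Request a U32 result; without preferred_element_type the inferred result
  // element type would follow the operand types instead.
  xla::DotGeneral(lhs, rhs, dnums, /*precision_config=*/nullptr,
                  /*preferred_element_type=*/xla::U32);
  return b.Build().ConsumeValueOrDie();
}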
@ -1353,28 +1358,33 @@ Status XlaBuilder::VerifyConvolution(
XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs, XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides, Padding padding, absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ConvWithGeneralDimensions( return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding, lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()), CreateDefaultConvDimensionNumbers(window_strides.size()),
feature_group_count, batch_group_count, precision_config); feature_group_count, batch_group_count, precision_config,
preferred_element_type);
} }
XlaOp XlaBuilder::ConvWithGeneralPadding( XlaOp XlaBuilder::ConvWithGeneralPadding(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ConvGeneral(lhs, rhs, window_strides, padding, return ConvGeneral(lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()), CreateDefaultConvDimensionNumbers(window_strides.size()),
feature_group_count, batch_group_count, precision_config); feature_group_count, batch_group_count, precision_config,
preferred_element_type);
} }
XlaOp XlaBuilder::ConvWithGeneralDimensions( XlaOp XlaBuilder::ConvWithGeneralDimensions(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
@ -1402,7 +1412,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
MakePadding(base_area_dimensions, window_dimensions, MakePadding(base_area_dimensions, window_dimensions,
window_strides, padding), window_strides, padding),
dimension_numbers, feature_group_count, dimension_numbers, feature_group_count,
batch_group_count, precision_config); batch_group_count, precision_config,
preferred_element_type);
}); });
} }
@ -1411,10 +1422,12 @@ XlaOp XlaBuilder::ConvGeneral(
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
dimension_numbers, feature_group_count, dimension_numbers, feature_group_count,
batch_group_count, precision_config); batch_group_count, precision_config,
preferred_element_type);
} }
XlaOp XlaBuilder::ConvGeneralDilated( XlaOp XlaBuilder::ConvGeneralDilated(
@ -1423,7 +1436,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
@ -1442,10 +1456,11 @@ XlaOp XlaBuilder::ConvGeneralDilated(
ShapeInference::InferWindowFromDimensions( ShapeInference::InferWindowFromDimensions(
window_dimensions, window_strides, padding, window_dimensions, window_strides, padding,
lhs_dilation, rhs_dilation)); lhs_dilation, rhs_dilation));
TF_ASSIGN_OR_RETURN(Shape shape, TF_ASSIGN_OR_RETURN(
ShapeInference::InferConvolveShape( Shape shape,
*lhs_shape, *rhs_shape, feature_group_count, ShapeInference::InferConvolveShape(
batch_group_count, window, dimension_numbers)); *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
window, dimension_numbers, preferred_element_type));
return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides, return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
padding, lhs_dilation, rhs_dilation, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, dimension_numbers, feature_group_count,
@ -1459,7 +1474,8 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) { const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
std::vector<int64> window_dimensions( std::vector<int64> window_dimensions(
@ -1472,10 +1488,11 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions( TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
window_dimensions, window_strides, window_dimensions, window_strides,
padding, lhs_dilation, rhs_dilation)); padding, lhs_dilation, rhs_dilation));
TF_ASSIGN_OR_RETURN(Shape shape, TF_ASSIGN_OR_RETURN(
ShapeInference::InferConvolveShape( Shape shape,
*lhs_shape, *rhs_shape, feature_group_count, ShapeInference::InferConvolveShape(
batch_group_count, window, dimension_numbers)); *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
window, dimension_numbers, preferred_element_type));
HloInstructionProto instr; HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto(); *instr.mutable_shape() = shape.ToProto();
@ -1499,14 +1516,15 @@ XlaOp XlaBuilder::DynamicConvInputGrad(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) { const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
HloInstructionProto instr, HloInstructionProto instr,
DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation, DynamicConvInstruction(
rhs_dilation, dimension_numbers, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
feature_group_count, batch_group_count, dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type)); precision_config, padding_type, preferred_element_type));
instr.set_custom_call_target("DynamicConvolutionInputGrad"); instr.set_custom_call_target("DynamicConvolutionInputGrad");
@ -1521,14 +1539,16 @@ XlaOp XlaBuilder::DynamicConvKernelGrad(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) { const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
HloInstructionProto instr, HloInstructionProto instr,
DynamicConvInstruction(activations, gradients, window_strides, padding, DynamicConvInstruction(activations, gradients, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers, lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count, feature_group_count, batch_group_count,
precision_config, padding_type)); precision_config, padding_type,
preferred_element_type));
instr.set_custom_call_target("DynamicConvolutionKernelGrad"); instr.set_custom_call_target("DynamicConvolutionKernelGrad");
// The gradient of kernel has kernel shape and shouldn't have any dynamic // The gradient of kernel has kernel shape and shouldn't have any dynamic
@ -1545,14 +1565,15 @@ XlaOp XlaBuilder::DynamicConvForward(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) { const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
HloInstructionProto instr, HloInstructionProto instr,
DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation, DynamicConvInstruction(
rhs_dilation, dimension_numbers, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
feature_group_count, batch_group_count, dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type)); precision_config, padding_type, preferred_element_type));
instr.set_custom_call_target("DynamicConvolutionForward"); instr.set_custom_call_target("DynamicConvolutionForward");
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs}); return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
@ -3331,6 +3352,11 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
} }
if (!need_rewrite) { if (!need_rewrite) {
if (opcode == HloOpcode::kCompare) {
CHECK(!instr_proto->comparison_type().empty());
new_instr->set_comparison_type(
ComparisonTypeToString(Comparison::DefaultComparisonType(PRED)));
}
*new_instr->mutable_name() = *new_instr->mutable_name() =
GetFullName(instr_proto->opcode(), kNameSeparator, id); GetFullName(instr_proto->opcode(), kNameSeparator, id);
return Status::OK(); return Status::OK();
@ -3990,11 +4016,26 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq); return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
} }
static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection comparison_direction) {
auto b = lhs.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
auto operand_element_type = operand_shape.element_type();
auto compare_type =
primitive_util::IsFloatingPointType(operand_element_type)
? Comparison::Type::kFloatTotalOrder
: Comparison::DefaultComparisonType(operand_element_type);
return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
compare_type);
});
}
XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs, XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) { absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder; return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq, ComparisonDirection::kEq);
compare_type);
} }
XlaOp Ne(const XlaOp lhs, const XlaOp rhs, XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
@ -4004,9 +4045,8 @@ XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs, XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) { absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder; return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe, ComparisonDirection::kNe);
compare_type);
} }
XlaOp Ge(const XlaOp lhs, const XlaOp rhs, XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
@ -4016,9 +4056,8 @@ XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs, XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) { absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder; return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe, ComparisonDirection::kGe);
compare_type);
} }
XlaOp Gt(const XlaOp lhs, const XlaOp rhs, XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
@ -4028,9 +4067,8 @@ XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs, XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) { absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder; return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt, ComparisonDirection::kGt);
compare_type);
} }
XlaOp Le(const XlaOp lhs, const XlaOp rhs, XlaOp Le(const XlaOp lhs, const XlaOp rhs,
@ -4040,10 +4078,10 @@ XlaOp Le(const XlaOp lhs, const XlaOp rhs,
XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs, XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) { absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder; return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe, ComparisonDirection::kLe);
compare_type);
} }
XlaOp Lt(const XlaOp lhs, const XlaOp rhs, XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) { absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt); return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
@ -4051,8 +4089,8 @@ XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs, XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) { absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt, return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
Comparison::Type::kFloatTotalOrder); ComparisonDirection::kLt);
} }
XlaOp Compare(const XlaOp lhs, const XlaOp rhs, XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
@ -4074,44 +4112,49 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
} }
XlaOp Dot(const XlaOp lhs, const XlaOp rhs, XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
return lhs.builder()->Dot(lhs, rhs, precision_config); absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
} }
XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs, XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
const DotDimensionNumbers& dimension_numbers, const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
precision_config); precision_config, preferred_element_type);
} }
XlaOp Conv(const XlaOp lhs, const XlaOp rhs, XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> window_strides, Padding padding, absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->Conv(lhs, rhs, window_strides, padding, return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
feature_group_count, batch_group_count, feature_group_count, batch_group_count,
precision_config); precision_config, preferred_element_type);
} }
XlaOp ConvWithGeneralPadding(const XlaOp lhs, const XlaOp rhs, XlaOp ConvWithGeneralPadding(
absl::Span<const int64> window_strides, const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->ConvWithGeneralPadding( return lhs.builder()->ConvWithGeneralPadding(
lhs, rhs, window_strides, padding, feature_group_count, batch_group_count, lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
precision_config); precision_config, preferred_element_type);
} }
XlaOp ConvWithGeneralDimensions( XlaOp ConvWithGeneralDimensions(
const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides, const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->ConvWithGeneralDimensions( return lhs.builder()->ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
batch_group_count, precision_config); batch_group_count, precision_config, preferred_element_type);
} }
XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs, XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
@ -4119,10 +4162,11 @@ XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, absl::optional<PrimitiveType> preferred_element_type) {
dimension_numbers, feature_group_count, return lhs.builder()->ConvGeneral(
batch_group_count, precision_config); lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
batch_group_count, precision_config, preferred_element_type);
} }
XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs, XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
@ -4132,26 +4176,27 @@ XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->ConvGeneralDilated( return lhs.builder()->ConvGeneralDilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count, dimension_numbers, feature_group_count, batch_group_count,
precision_config); precision_config, preferred_element_type);
} }
XlaOp DynamicConvInputGrad(XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs, XlaOp DynamicConvInputGrad(
absl::Span<const int64> window_strides, XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const int64> window_strides,
absl::Span<const int64> lhs_dilation, absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> rhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, const PrecisionConfig* precision_config, PaddingType padding_type,
PaddingType padding_type) { absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->DynamicConvInputGrad( return lhs.builder()->DynamicConvInputGrad(
input_sizes, lhs, rhs, window_strides, padding, lhs_dilation, input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type); precision_config, padding_type, preferred_element_type);
} }
XlaOp DynamicConvKernelGrad( XlaOp DynamicConvKernelGrad(
@ -4160,11 +4205,12 @@ XlaOp DynamicConvKernelGrad(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) { const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return activations.builder()->DynamicConvKernelGrad( return activations.builder()->DynamicConvKernelGrad(
activations, gradients, window_strides, padding, lhs_dilation, activations, gradients, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type); precision_config, padding_type, preferred_element_type);
} }
XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs, XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
@ -4175,11 +4221,12 @@ XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, const PrecisionConfig* precision_config,
PaddingType padding_type) { PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->DynamicConvForward( return lhs.builder()->DynamicConvForward(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count, dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type); precision_config, padding_type, preferred_element_type);
} }
XlaOp Fft(const XlaOp operand, FftType fft_type, XlaOp Fft(const XlaOp operand, FftType fft_type,
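With the CompareTotalOrder helper introduced above, the *TotalOrder entry points select kFloatTotalOrder only for floating-point operands and fall back to the operand type's default comparison otherwise. A small usage sketch; the builder name and shape are illustrative.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaComputation BuildTotalOrderLess() {
  xla::XlaBuilder b("total_order_less");
  auto x = xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {4}), "x");
  auto y = xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {4}), "y");
  // F32 operands lower to a kFloatTotalOrder compare, so the result is well
  // defined even in the presence of NaNs; integer operands would keep their
  // default signed/unsigned comparison type.
  xla::LtTotalOrder(x, y, /*broadcast_dimensions=*/{});
  return b.Build().ConsumeValueOrDie();
}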

View File

@ -521,56 +521,63 @@ class XlaBuilder {
XlaOp tuple_data, XlaOp tuple_data,
int64 index); int64 index);
XlaOp Dot(XlaOp lhs, XlaOp rhs, XlaOp Dot(
const PrecisionConfig* precision_config = nullptr); XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, XlaOp DotGeneral(
const DotDimensionNumbers& dimension_numbers, XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr); const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp Conv(
Padding padding, int64 feature_group_count = 1, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
int64 batch_group_count = 1, Padding padding, int64 feature_group_count = 1,
const PrecisionConfig* precision_config = nullptr); int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp ConvWithGeneralPadding( XlaOp ConvWithGeneralPadding(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count = 1, int64 batch_group_count = 1, int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr); const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp ConvWithGeneralDimensions( XlaOp ConvWithGeneralDimensions(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1, int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr); const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, XlaOp ConvGeneral(
absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1, int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr); const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs, XlaOp ConvGeneralDilated(
absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 feature_group_count = 1, int64 batch_group_count = 1,
int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr,
const PrecisionConfig* precision_config = nullptr); absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs, XlaOp DynamicConvForward(
absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, const PrecisionConfig* precision_config, PaddingType padding_type,
PaddingType padding_type); absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DynamicConvInputGrad( XlaOp DynamicConvInputGrad(
XlaOp input_sizes, XlaOp lhs, XlaOp rhs, XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
@ -580,7 +587,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type); const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DynamicConvKernelGrad( XlaOp DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients, XlaOp activations, XlaOp gradients,
@ -590,7 +598,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type); const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
StatusOr<HloInstructionProto> DynamicConvInstruction( StatusOr<HloInstructionProto> DynamicConvInstruction(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@ -599,7 +608,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type); const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
virtual StatusOr<XlaOp> ConvGeneralDilatedInternal( virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
@ -1098,10 +1108,12 @@ class XlaBuilder {
ComparisonDirection direction, ComparisonDirection direction,
Comparison::Type compare_type); Comparison::Type compare_type);
friend XlaOp Dot(XlaOp lhs, XlaOp rhs, friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config); const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
const DotDimensionNumbers& dimension_number, const DotDimensionNumbers& dimension_number,
const PrecisionConfig* precision_config); const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
virtual StatusOr<XlaOp> DotGeneralInternal( virtual StatusOr<XlaOp> DotGeneralInternal(
const Shape& shape, XlaOp lhs, XlaOp rhs, const Shape& shape, XlaOp lhs, XlaOp rhs,
const DotDimensionNumbers& dimension_number, const DotDimensionNumbers& dimension_number,
@ -1109,23 +1121,27 @@ class XlaBuilder {
friend XlaOp Conv(XlaOp lhs, XlaOp rhs, friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides, Padding padding, absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config); const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp ConvWithGeneralPadding( friend XlaOp ConvWithGeneralPadding(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config); const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp ConvWithGeneralDimensions( friend XlaOp ConvWithGeneralDimensions(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_confige); const PrecisionConfig* precision_config,
friend XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::optional<PrimitiveType> preferred_element_type);
absl::Span<const int64> window_strides, friend XlaOp ConvGeneral(
absl::Span<const std::pair<int64, int64>> padding, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
const ConvolutionDimensionNumbers& dimension_numbers, absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count, int64 batch_group_count, const ConvolutionDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config); int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp DynamicConvForward( friend XlaOp DynamicConvForward(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
@ -1133,7 +1149,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type); const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp DynamicConvKernelGrad( friend XlaOp DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients, XlaOp activations, XlaOp gradients,
absl::Span<const int64> window_strides, absl::Span<const int64> window_strides,
@ -1142,7 +1159,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type); const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp DynamicConvInputGrad( friend XlaOp DynamicConvInputGrad(
XlaOp input_sizes, XlaOp lhs, XlaOp rhs, XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides, absl::Span<const int64> window_strides,
@ -1151,7 +1169,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type); const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp ConvKernelGrad( friend XlaOp ConvKernelGrad(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@ -1160,7 +1179,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config); const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp ConvGeneralDilated( friend XlaOp ConvGeneralDilated(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@ -1169,7 +1189,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config); const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp Fft(XlaOp operand, FftType fft_type, friend XlaOp Fft(XlaOp operand, FftType fft_type,
absl::Span<const int64> fft_length); absl::Span<const int64> fft_length);
friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
@ -1813,28 +1834,31 @@ XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
// Enqueues a dot instruction onto the computation. // Enqueues a dot instruction onto the computation.
XlaOp Dot(XlaOp lhs, XlaOp rhs, XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config = nullptr); const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a general dot instruction onto the computation. // Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, XlaOp DotGeneral(
const DotDimensionNumbers& dimension_numbers, XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr); const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a convolution instruction onto the computation, which uses the // Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers. // default convolution dimension numbers.
XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp Conv(
Padding padding, int64 feature_group_count = 1, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
int64 batch_group_count = 1, Padding padding, int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr); const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a convolution instruction onto the computation, with the caller // Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding(). // provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(XlaOp lhs, XlaOp rhs, XlaOp ConvWithGeneralPadding(
absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count = 1, int64 feature_group_count = 1, int64 batch_group_count = 1,
int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr,
const PrecisionConfig* precision_config = nullptr); absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a convolution instruction onto the computation, with the caller // Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration. // provided dimension numbers configuration.
@ -1842,47 +1866,48 @@ XlaOp ConvWithGeneralDimensions(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1, int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr); const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a convolution instruction onto the computation, with the caller // Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers. // provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides, XlaOp ConvGeneral(
absl::Span<const std::pair<int64, int64>> padding, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
const ConvolutionDimensionNumbers& dimension_numbers, absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count = 1, int64 batch_group_count = 1, const ConvolutionDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr); int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a convolution instruction onto the computation, with the caller // Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers. // provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs, XlaOp ConvGeneralDilated(
absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers,
const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, int64 batch_group_count = 1,
int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr,
int64 batch_group_count = 1, absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
const PrecisionConfig* precision_config = nullptr);
XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs, XlaOp DynamicConvForward(
absl::Span<const int64> window_strides, XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
absl::Span<const int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers,
const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count,
int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config, PaddingType padding_type,
const PrecisionConfig* precision_config, absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
PaddingType padding_type);
XlaOp DynamicConvInputGrad(XlaOp input_sizes, XlaOp lhs, XlaOp rhs, XlaOp DynamicConvInputGrad(
absl::Span<const int64> window_strides, XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const int64> window_strides,
absl::Span<const int64> lhs_dilation, absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> rhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, const PrecisionConfig* precision_config, PaddingType padding_type,
PaddingType padding_type); absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DynamicConvKernelGrad( XlaOp DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides, XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
@ -1890,7 +1915,8 @@ XlaOp DynamicConvKernelGrad(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type); const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues an FFT instruction onto the computation, of the given type and // Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length. // with the given FFT length.
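Because preferred_element_type is defaulted to absl::nullopt throughout this header, existing call sites compile unchanged and only callers that want a wider accumulator opt in. A hedged sketch using the defaulted Conv overload follows; the shapes are illustrative and laid out for the client's default convolution dimension numbers (batch then feature then spatial dimensions for the input, output-feature then input-feature then spatial dimensions for the kernel).

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaComputation BuildWideningConv() {
  xla::XlaBuilder b("widening_conv");
  auto input = xla::Parameter(
      &b, 0, xla::ShapeUtil::MakeShape(xla::S16, {1, 128, 4, 4}), "input");
  auto kernel = xla::Parameter(
      &b, 1, xla::ShapeUtil::MakeShape(xla::S8, {8, 128, 2, 2}), "kernel");
  // Existing calls such as Conv(input, kernel, {1, 1}, Padding::kValid) still
  // compile; to request an S32 result the preceding defaulted arguments are
  // spelled out so the trailing preferred_element_type can be passed.
  xla::Conv(input, kernel, /*window_strides=*/{1, 1}, xla::Padding::kValid,
            /*feature_group_count=*/1, /*batch_group_count=*/1,
            /*precision_config=*/nullptr,
            /*preferred_element_type=*/xla::S32);
  return b.Build().ConsumeValueOrDie();
}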

View File

@@ -338,8 +338,7 @@ TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) {
                  /*broadcast_dimensions=*/{0, 1, 2});
   auto statusor = BuildHloModule(&b);
   ASSERT_FALSE(statusor.ok());
-  EXPECT_THAT(statusor.status().error_message(),
-              HasSubstr("shape's dimensions must not be < 0"));
+  EXPECT_THAT(statusor.status().error_message(), HasSubstr("invalid shape"));
 }
 
 TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
@@ -1066,6 +1065,56 @@ TEST_F(XlaBuilderTest, DynamicTranspose) {
       << result_shape;
 }
 
+TEST_F(XlaBuilderTest, DotWithPreferredElementType) {
+  XlaBuilder b(TestName());
+  Shape p0_shape = ShapeUtil::MakeShape(U8, {2, 3});
+  Shape p1_shape = ShapeUtil::MakeShape(U16, {3, 2});
+  auto p0 = Parameter(&b, 0, p0_shape, "p0");
+  auto p1 = Parameter(&b, 1, p1_shape, "p1");
+
+  DotDimensionNumbers dnums;
+  dnums.add_lhs_contracting_dimensions(1);
+  dnums.add_rhs_contracting_dimensions(0);
+  DotGeneral(p0, p1, dnums, /*precision_config=*/nullptr,
+             /*preferred_element_type=*/U32);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  const Shape& result_shape =
+      module->entry_computation()->root_instruction()->shape();
+  ASSERT_TRUE(
+      ShapeUtil::Equal(ShapeUtil::MakeShape(U32, {2, 2}), result_shape));
+}
+
+TEST_F(XlaBuilderTest, ConvolutionWithPreferredElementType) {
+  XlaBuilder b(TestName());
+  Shape p0_shape = ShapeUtil::MakeShape(S16, {1, 2, 2, 128});
+  Shape p1_shape = ShapeUtil::MakeShape(S8, {2, 2, 128, 8});
+  auto p0 = Parameter(&b, 0, p0_shape, "p0");
+  auto p1 = Parameter(&b, 1, p1_shape, "p1");
+
+  ConvolutionDimensionNumbers dnums;
+  dnums.set_input_batch_dimension(0);
+  dnums.set_output_batch_dimension(0);
+  dnums.add_input_spatial_dimensions(1);
+  dnums.add_output_spatial_dimensions(1);
+  dnums.add_input_spatial_dimensions(2);
+  dnums.add_output_spatial_dimensions(2);
+  dnums.set_input_feature_dimension(3);
+  dnums.set_output_feature_dimension(3);
+  dnums.add_kernel_spatial_dimensions(0);
+  dnums.add_kernel_spatial_dimensions(1);
+  dnums.set_kernel_input_feature_dimension(2);
+  dnums.set_kernel_output_feature_dimension(3);
+  ConvWithGeneralDimensions(p0, p1, {1, 1}, Padding::kValid, dnums,
+                            /*feature_group_count=*/1, /*batch_group_count=*/1,
+                            /*precision_config=*/nullptr,
+                            /*preferred_element_type=*/S32);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  const Shape& result_shape =
+      module->entry_computation()->root_instruction()->shape();
+  ASSERT_TRUE(
+      ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {1, 1, 1, 8}), result_shape));
+}
+
 TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
   XlaBuilder b(TestName());
   AfterAll(&b, {CreateToken(&b), ConstantR0<float>(&b, 1.0)});
@@ -243,6 +243,7 @@ cc_library(
     hdrs = ["nvidia_gpu_device.h"],
     deps = [
         ":pjrt_client",
+        "@com_google_absl//absl/container:flat_hash_map",
         "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
         "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:client_library",
@@ -15,12 +15,15 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
 
+#include "absl/container/flat_hash_map.h"
+
 #ifdef NCCL_ENABLED
 #include "third_party/nccl/nccl.h"
 #endif  // NCCL_ENABLED
 
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/common_runtime/device/device_host_allocator.h"
 #include "tensorflow/core/common_runtime/device/device_id.h"
@@ -166,56 +169,57 @@ std::unique_ptr<tensorflow::BFCAllocator> GetGpuHostAllocator(
 
 // A table mapping NcclCliqueKeys to ncclUniqueId values encoded as strings.
 // In a distributed setup the table of NCCL IDs is kept on the master node
-// (node 0). Currently node 0 is the only node that generates ncclUniqueIds;
-// see the TODO below.
+// (node 0). The node of the first participating device will create the unique
+// id.
 class NcclIdStore {
  public:
-  NcclIdStore(int node_id, std::shared_ptr<DistributedRuntimeClient> client)
-      : node_id_(node_id), client_(std::move(client)) {}
+  NcclIdStore(int node_id, std::shared_ptr<DistributedRuntimeClient> client,
+              absl::flat_hash_map<GlobalDeviceId, int> device_to_node)
+      : node_id_(node_id),
+        client_(std::move(client)),
+        device_to_node_(std::move(device_to_node)) {}
 
   StatusOr<std::string> GetNcclUniqueId(const NcclCliqueKey& key);
 
  private:
   const int node_id_;
   const std::shared_ptr<DistributedRuntimeClient> client_;
+  const absl::flat_hash_map<GlobalDeviceId, int> device_to_node_;
 
   absl::Mutex mu_;
-  absl::flat_hash_map<std::string, std::string> cache_ ABSL_GUARDED_BY(mu_);
+  absl::flat_hash_map<NcclCliqueKey, std::string> cache_ ABSL_GUARDED_BY(mu_);
 };
 
 StatusOr<std::string> NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) {
-  std::string key_string = GlobalDeviceIdsToString(key.devices());
+  // The caller must ensure that threads calling this method concurrently have
+  // unique keys, otherwise the global key-value store may hold the wrong value.
   {
     absl::MutexLock lock(&mu_);
-    auto it = cache_.find(key_string);
+    auto it = cache_.find(key);
     if (it != cache_.end()) {
       return it->second;
     }
   }
-  auto result = [&]() -> StatusOr<std::string> {
-    // TODO(phawkins): this will deadlock if node 0 is not involved in the
-    // computation. Add support for computations that only use a subset of
-    // replicas.
-    if (node_id_ == 0) {
+  std::string id_string;
+  int primary_node_id = device_to_node_.at(key.devices()[0]);
+  if (node_id_ == primary_node_id) {
 #ifdef NCCL_ENABLED
-      ncclUniqueId id;
-      ncclResult_t r = ncclGetUniqueId(&id);
-      TF_RET_CHECK(r == ncclSuccess);
-      std::string value(id.internal, NCCL_UNIQUE_ID_BYTES);
-      TF_RETURN_IF_ERROR(client_->KeyValueSet(key_string, value));
-      return value;
+    ncclUniqueId id;
+    ncclResult_t r = ncclGetUniqueId(&id);
+    TF_RET_CHECK(r == ncclSuccess);
+    id_string = std::string(id.internal, NCCL_UNIQUE_ID_BYTES);
+    TF_RETURN_IF_ERROR(client_->KeyValueSet(key.ToString(), id_string));
 #else
-      return FailedPrecondition("NCCL support was not built into XLA binary.");
+    return FailedPrecondition("NCCL support was not built into XLA binary.");
 #endif
-    } else {
-      return client_->BlockingKeyValueGet(key_string, absl::Minutes(5));
-    }
-  }();
-  if (!result.ok()) {
-    return result.status();
+  } else {
+    TF_ASSIGN_OR_RETURN(id_string, client_->BlockingKeyValueGet(
+                                       key.ToString(), absl::Minutes(5)));
   }
   absl::MutexLock lock(&mu_);
-  return cache_.emplace(key_string, result.ValueOrDie()).first->second;
+  auto result = cache_.emplace(key, std::move(id_string));
+  TF_RET_CHECK(result.second) << "Unique ID already in cache.";
+  return result.first->second;
 }
 
 std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
@@ -258,8 +262,11 @@ Status BuildDistributedDevices(
       distributed_client->EnumerateDevices(local_topology, &global_topology));
 
   std::vector<GlobalDeviceId> gpu_device_ids(local_device_states.size());
+  absl::flat_hash_map<GlobalDeviceId, int> device_to_node;
   for (const LocalTopologyProto& node : global_topology.nodes()) {
     for (const DeviceProto& device_proto : node.devices()) {
+      GlobalDeviceId global_device_id(device_proto.global_device_id());
+      device_to_node[global_device_id] = node.node_id();
       std::unique_ptr<LocalDeviceState> local_device;
       if (node.node_id() == node_id) {
         TF_RET_CHECK(device_proto.local_device_ordinal() >= 0 &&
@@ -269,8 +276,7 @@ Status BuildDistributedDevices(
                      nullptr);
         local_device =
             std::move(local_device_states[device_proto.local_device_ordinal()]);
-        gpu_device_ids[device_proto.local_device_ordinal()] =
-            GlobalDeviceId(device_proto.global_device_id());
+        gpu_device_ids[device_proto.local_device_ordinal()] = global_device_id;
       }
       auto device = absl::make_unique<GpuDevice>(
           device_proto.global_device_id(), std::move(local_device),
@@ -283,8 +289,8 @@ Status BuildDistributedDevices(
   }
   gpu_executable_run_options->set_gpu_global_device_ids(
       std::move(gpu_device_ids));
-  auto nccl_id_store =
-      std::make_shared<NcclIdStore>(node_id, distributed_client);
+  auto nccl_id_store = std::make_shared<NcclIdStore>(
+      node_id, distributed_client, device_to_node);
   gpu_executable_run_options->set_nccl_unique_id_callback(
       [nccl_id_store](const NcclCliqueKey& key) {
         return nccl_id_store->GetNcclUniqueId(key);
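The hunks above change how the NCCL unique id is created and shared: instead of always relying on node 0, the node that owns the clique's first device generates the id and publishes it through the distributed key-value store, while every other participant blocks until the value appears. A condensed sketch of that rendezvous pattern, using a hypothetical KeyValueStore interface in place of DistributedRuntimeClient (all names below are illustrative, not the XLA API):

    #include <string>

    // Hypothetical stand-in for the distributed key-value store used above.
    struct KeyValueStore {
      virtual ~KeyValueStore() = default;
      virtual void Set(const std::string& key, const std::string& value) = 0;
      // Blocks (with some timeout policy) until another node publishes the key.
      virtual std::string BlockingGet(const std::string& key) = 0;
    };

    // Only the node that owns the clique's first device creates the unique id;
    // every other node waits for it under the same clique key.
    std::string ExchangeNcclId(int my_node, int primary_node,
                               const std::string& clique_key,
                               KeyValueStore* kv) {
      if (my_node == primary_node) {
        // In the real code this value comes from ncclGetUniqueId().
        std::string id = "<serialized ncclUniqueId bytes>";
        kv->Set(clique_key, id);
        return id;
      }
      return kv->BlockingGet(clique_key);
    }

Each node then caches the result per clique key, which is what the NcclCliqueKey-keyed cache_ map above implements.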
@@ -597,19 +597,25 @@ struct TypeDescriptor<uint16> {
   static int Dtype() { return NPY_UINT16; }
 };
 
+// We register "int", "long", and "long long" types for portability across
+// Linux, where "int" and "long" are the same type, and Windows, where "long"
+// and "longlong" are the same type.
 template <>
-struct TypeDescriptor<uint32> {
-  typedef uint32 T;
-  static int Dtype() { return NPY_UINT32; }
+struct TypeDescriptor<unsigned int> {
+  typedef unsigned int T;
+  static int Dtype() { return NPY_UINT; }
 };
 
-template <typename Uint64Type>
-struct TypeDescriptor<
-    Uint64Type, typename std::enable_if<std::is_integral<Uint64Type>::value &&
-                                        !std::is_signed<Uint64Type>::value &&
-                                        sizeof(Uint64Type) == 8>::type> {
-  typedef Uint64Type T;
-  static int Dtype() { return NPY_UINT64; }
+template <>
+struct TypeDescriptor<unsigned long> {  // NOLINT
+  typedef unsigned long T;              // NOLINT
+  static int Dtype() { return NPY_ULONG; }
+};
+
+template <>
+struct TypeDescriptor<unsigned long long> {  // NOLINT
+  typedef unsigned long long T;              // NOLINT
+  static int Dtype() { return NPY_ULONGLONG; }
 };
 
 template <>
@@ -625,18 +631,21 @@ struct TypeDescriptor<int16> {
 };
 
 template <>
-struct TypeDescriptor<int32> {
-  typedef int32 T;
-  static int Dtype() { return NPY_INT32; }
+struct TypeDescriptor<int> {
+  typedef int T;
+  static int Dtype() { return NPY_INT; }
 };
 
-template <typename Int64Type>
-struct TypeDescriptor<
-    Int64Type, typename std::enable_if<std::is_integral<Int64Type>::value &&
-                                       std::is_signed<Int64Type>::value &&
-                                       sizeof(Int64Type) == 8>::type> {
-  typedef Int64Type T;
-  static int Dtype() { return NPY_INT64; }
+template <>
+struct TypeDescriptor<long> {  // NOLINT
+  typedef long T;              // NOLINT
+  static int Dtype() { return NPY_LONG; }
+};
+
+template <>
+struct TypeDescriptor<long long> {  // NOLINT
+  typedef long long T;              // NOLINT
+  static int Dtype() { return NPY_LONGLONG; }
 };
 
 template <>
@@ -1354,7 +1363,15 @@ bool Initialize() {
   if (!RegisterBfloat16Cast<uint16>(NPY_UINT16, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<uint32>(NPY_UINT32, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG,  // NOLINT
+                                           /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned long long>(  // NOLINT
+          NPY_ULONGLONG, /*cast_is_safe=*/false)) {
     return false;
   }
   if (!RegisterBfloat16Cast<uint64>(NPY_UINT64, /*cast_is_safe=*/false)) {
@@ -1366,14 +1383,15 @@ bool Initialize() {
   if (!RegisterBfloat16Cast<int16>(NPY_INT16, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<int32>(NPY_INT32, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<int>(NPY_INT, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<int64>(NPY_INT64, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<long>(NPY_LONG,  // NOLINT
+                                  /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<npy_longlong>(NPY_LONGLONG,
-                                          /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<long long>(  // NOLINT
+          NPY_LONGLONG, /*cast_is_safe=*/false)) {
     return false;
   }
   // Following the numpy convention. imag part is dropped when converting to
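The portability comment above is the motivation for registering `int`, `long`, and `long long` separately: which NumPy type each builtin integer maps to depends on the platform's data model. A standalone check (not part of the diff) that makes the difference visible:

    #include <cstdio>

    int main() {
      // Typical results: LP64 Linux/macOS prints 4/8/8 (long is the 64-bit
      // type there), while LLP64 Windows prints 4/4/8 (long is 32-bit).
      std::printf("sizeof(int)=%zu sizeof(long)=%zu sizeof(long long)=%zu\n",
                  sizeof(int), sizeof(long), sizeof(long long));
      return 0;
    }

Registering all three specializations keeps the cast table complete regardless of which pair of builtins happens to coincide on a given platform.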
@@ -293,7 +293,7 @@ class Bfloat16NumPyTest(parameterized.TestCase):
     for dtype in [
         np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
         np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
-        np.uint64
+        np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
     ]:
       x = np.array([[1, 2, 3]], dtype=dtype)
       y = x.astype(bfloat16)
@@ -595,7 +595,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
 
   static const auto* xla_module =
       new py::module(py::module::import("jax.interpreters.xla"));
-  const auto& device_array = xla_module->attr("DeviceArray");
+  const auto& device_array = xla_module->attr("_DeviceArray");
 
   static const auto* numpy_module = new py::module(py::module::import("numpy"));
   const auto& np_array = numpy_module->attr("array");
@@ -613,7 +613,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
   for (py::handle arg : arguments.flat_dynamic_args) {
     // We specically only deal with DeviceArray (not ShardedDeviceArray).
     // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
-    if (arg.get_type().is(device_array)) {
+    if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) {
       xla::PyBuffer* buffer;
       if (arg.attr("_device").is_none()) {  // Skip non-sticky devices.
         continue;
@@ -653,7 +653,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
   xla::PjRtClient* pjrt_client = data_device->client();
 
   for (py::handle arg : arguments.flat_dynamic_args) {
-    if (arg.get_type().is(device_array)) {
+    if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) {
       if (!HasTrivialLazyExpr(arg)) {
        return InvalidArgument(
            "Non-trivial lazy expression not supported in C++. "
@@ -108,7 +108,8 @@ void BuildOpsSubmodule(py::module* m) {
           py::arg("lhs_dilation"), py::arg("rhs_dilation"),
           py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
           py::arg("batch_group_count") = 1,
-          py::arg("precision_config") = nullptr);
+          py::arg("precision_config") = nullptr,
+          py::arg("preferred_element_type") = absl::nullopt);
   ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
           py::arg("new_element_type"));
   ops.def(
@@ -136,9 +137,11 @@ void BuildOpsSubmodule(py::module* m) {
           py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
           py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false);
   ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
-          py::arg("precision_config") = nullptr);
+          py::arg("precision_config") = nullptr,
+          py::arg("preferred_element_type") = absl::nullopt);
   ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
-          py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
+          py::arg("dimension_numbers"), py::arg("precision_config") = nullptr,
+          py::arg("preferred_element_type") = absl::nullopt);
   ops.def("DynamicSlice",
           static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
                                 absl::Span<const int64>)>(&DynamicSlice),
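For reference, the `py::arg(...) = absl::nullopt` pattern above is how an optional keyword argument with a "not set" default is exposed to Python. A minimal self-contained sketch using std::optional (the module and function names are illustrative; the XLA build binds absl::optional through its own type casters):

    #include <optional>

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>  // std::optional <-> None conversions

    namespace py = pybind11;

    // Returns the preferred type if the caller supplied one, else a fallback.
    int DotLikeResultType(int lhs_type, int rhs_type,
                          std::optional<int> preferred_element_type) {
      (void)rhs_type;
      return preferred_element_type.value_or(lhs_type);
    }

    PYBIND11_MODULE(ops_sketch, m) {
      m.def("dot_like_result_type", &DotLikeResultType, py::arg("lhs_type"),
            py::arg("rhs_type"), py::arg("preferred_element_type") = std::nullopt);
    }

Calling ops_sketch.dot_like_result_type(1, 2) uses the fallback, while passing preferred_element_type=32 overrides it, mirroring how the Dot, DotGeneral, and convolution bindings now accept an optional preferred_element_type keyword.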
Some files were not shown because too many files have changed in this diff.