Merge remote-tracking branch 'upstream/master' into detection_postprocess
commit 82eeda304b
Files changed:

ISSUES.md
RELEASE.md
tensorflow/
    BUILD
    c/eager
    compiler/
        jit
        mlir/
            hlo/
                include/mlir-hlo/Dialect/mhlo/IR
                lib
                tests
            lite/
                BUILD
                flatbuffer_import.cc
                ir
                python
                quantization
                tests
                transforms
            tensorflow/
                BUILD
                ir
                tests/
                    fold-broadcast.mlir
                    graphdef2mlir
                    layout_optimization_layout_assignment_to_nhwc.mlir
                    legalize_hlo.mlir
                    lower_tf.mlir
                    mlir2graphdef
                    outside_compiled_to_host_launch.mlir
                transforms/
                    fold_broadcast.cc
                    functional_control_flow_to_regions.cc
                    legalize_hlo.cc
                    lower_tf.cc
                    lower_tf.td
                    outside_compiled_to_host_launch.cc
                    passes.h
                    region_control_flow_to_functional.cc
                    tpu_extract_outside_compilation.cc
                translate
            tfr/examples/mnist
            tools/kernel_gen
            xla
        tf2tensorrt
        tf2xla
        xla
ISSUES.md: 10 changes
@@ -1,7 +1,9 @@
-If you open a GitHub Issue, here is our policy: 1. It must be a bug/performance
-issue or a feature request or a build issue or a documentation issue (for small
-doc fixes please send a PR instead). 2. Make sure the Issue Template is filled
-out. 3. The issue should be related to the repo it is created in.
+If you open a GitHub Issue, here is our policy:
+
+1. It must be a bug/performance issue or a feature request or a build issue or
+   a documentation issue (for small doc fixes please send a PR instead).
+1. Make sure the Issue Template is filled out.
+1. The issue should be related to the repo it is created in.
 
 **Here's why we have this policy:** We want to focus on the work that benefits
 the whole community, e.g., fixing bugs and adding features. Individual support
 
RELEASE.md: 14 changes
@@ -54,6 +54,7 @@
 *   Corrected higher-order gradients of control flow constructs (`tf.cond`,
     `tf.while_loop`, and compositions like `tf.foldl`) computed with
     `tf.GradientTape` inside a `tf.function`.
+*   Changed the default step size in `gradient_checker_v2.compute_gradients` to be exactly representable as a binary floating point number. This avoids polluting gradient approximations needlessly, which in some cases leads to false negatives in op gradient tests.
 *   `tf.summary`:
     *   New `tf.summary.graph` allows manual write of TensorFlow graph
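A minimal sketch of the step-size note above, assuming only NumPy (my illustration, not part of the commit):

    import numpy as np

    # 1e-3 has no exact binary representation, so it injects its own
    # rounding error into every finite-difference approximation:
    print(f"{np.float64(1e-3):.20f}")        # 0.00100000000000000002...
    # A power of two is exactly representable, so the step itself is exact:
    print(f"{np.float64(2.0 ** -10):.20f}")  # 0.00097656250000000000...

And a sketch of how the new `tf.summary.graph` call is typically driven (the log directory and traced function are hypothetical):

    import tensorflow as tf

    @tf.function
    def square(x):
        return x * x

    # Trace the function, then write its graph manually from eager mode.
    graph = square.get_concrete_function(tf.constant(1.0)).graph
    writer = tf.summary.create_file_writer("/tmp/logs")
    with writer.as_default():
        tf.summary.graph(graph)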
@@ -65,6 +66,19 @@
     supported MSVC version to 16.4 (current: 16.8).
     *   See: https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
 
+*   TensorRT
+    *   Removed the deprecated `session_config` parameter for the TF1-TRT
+        converter `TrtGraphConverter`. Previously, we issued a warning when the
+        value of the parameter is not None.
+    *   The TF2-TRT converter `TrtGraphConverterV2` takes an object of class
+        TrtConversionParams as a parameter. Removed three deprecated fields from
+        this class: `rewriter_config_template`, `is_dynamic_op`, and
+        `max_batch_size`. Previously, we issued a warning when the value of
+        `rewriter_config_template` is not None. We issued an error when the
+        value of `is_dynamic_op` is not True. We didn't use the value for
+        `max_batch_size` for building TensorRT engines.
+    *   Issue a warning when function get_tensorrt_rewriter_config is used.
+
 ## Thanks to our Contributors
 
 This release contains contributions from many people at Google, as well as:
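For the TF2-TRT notes above, a minimal conversion sketch (the SavedModel paths are hypothetical; only still-supported `TrtConversionParams` fields are set):

    from tensorflow.python.compiler.tensorrt import trt_convert as trt

    # The removed fields (rewriter_config_template, is_dynamic_op,
    # max_batch_size) no longer exist on TrtConversionParams.
    params = trt.TrtConversionParams(
        precision_mode=trt.TrtPrecisionMode.FP16)
    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir="/tmp/saved_model",
        conversion_params=params)
    converter.convert()
    converter.save("/tmp/saved_model_trt")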
@@ -72,6 +72,14 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+# Config setting that disables the default logger, only logging
+# to registered TFLogSinks
+config_setting(
+    name = "no_default_logger",
+    define_values = {"no_default_logger": "true"},
+    visibility = ["//visibility:public"],
+)
+
 # Config setting for determining if we are building for Android.
 config_setting(
     name = "android",

@@ -732,6 +740,7 @@ tf_cc_shared_object(
         "//tensorflow/c/experimental/filesystem:filesystem_interface",
         "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
         "//tensorflow/c:kernels_hdrs",
+        "//tensorflow/c:logging",
        "//tensorflow/c:ops_hdrs",
        "//tensorflow/cc/saved_model:loader_lite_impl",
        "//tensorflow/core/common_runtime:core_cpu_impl",
@@ -301,7 +301,7 @@ tf_cuda_cc_test(
     ],
     args = ["--heap_check=local"],
     linkstatic = tf_kernel_tests_linkstatic(),
-    tags = tf_cuda_tests_tags(),
+    tags = tf_cuda_tests_tags() + ["no_cuda_asan"],  # b/173654156
     deps = [
         ":c_api_experimental",
         ":c_api_unified_internal",

@@ -469,6 +469,7 @@ tf_cuda_cc_test(
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + [
         "nomac",
+        "no_cuda_asan",  # b/173825513
     ],
     deps = [
         ":abstract_tensor_handle",
@@ -61,6 +61,7 @@ limitations under the License.
 // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
+#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed.h"
 #endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
 
 #if !defined(IS_MOBILE_PLATFORM)

@@ -101,7 +102,14 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   if (opts->use_tfrt) {
 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
-    return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
+    tfrt::tf::ContextInterface* tfrt_context =
+        new tfrt::tf::ContextInterface(opts->async);
+#if !defined(IS_MOBILE_PLATFORM)
+    tfrt_context->SetDistributedManager(
+        std::make_unique<tfrt::tf::DistributedManagerContextInterface>(
+            tfrt_context->GetCoreRuntime()->GetHostContext()));
+#endif  // !IS_MOBILE_PLATFORM
+    return tensorflow::wrap(tfrt_context);
 #else
     status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
     return nullptr;
@@ -226,7 +226,7 @@ void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
 
 // Helper functions which delegate to `AbstractOperation`, update
 // the state of the ForwardOperation and call the tape as appropriate.
-// These APIs are mainly to faciliate testing and are subject to change.
+// These APIs are mainly to facilitate testing and are subject to change.
 namespace internal {
 Status Reset(AbstractOperation* op_, const char* op,
              const char* raw_device_name, ForwardOperation* forward_op_) {
@@ -39,7 +39,7 @@ struct XlaAutoJitFlag {
   int32 optimization_level_general;
 };
 
-// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax
+// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax
 // is:
 // <number>: sets general and single_gpu setting to the provided number.
 // single-gpu(<number>): sets the single_gpu setting to the provided number.
@@ -103,7 +103,9 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
   if (flr->config_proto()) {
     config_proto = *flr->config_proto();
   }
-  if (!IsMlirBridgePassEnabled(*fbody->graph, config_proto)) {
+  MlirBridgeRolloutPolicy policy =
+      GetMlirBridgeRolloutPolicy(*fbody->graph, config_proto);
+  if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
     RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
     if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
       std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
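The `kEnabledByUser` policy above corresponds to an explicit opt-in from Python; a one-line sketch using the public experimental toggle of this period (my illustration, not part of the commit):

    import tensorflow as tf

    # With the rollout policy check above, only an explicit request like
    # this routes function compilation through the MLIR bridge.
    tf.config.experimental.enable_mlir_bridge()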
@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
|
|||||||
|
|
||||||
let results = (outs HLO_StaticShapeTensor);
|
let results = (outs HLO_StaticShapeTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
@ -202,9 +202,9 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
|
|||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
|
||||||
Arg<I32Buffer, "", [MemWrite]>:$info,
|
Arg<I32Buffer, "", [MemWrite]>:$info,
|
||||||
BoolAttr:$is_upper);
|
BoolAttr:$is_lower);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // LHLO_GPU_OPS
|
#endif // LHLO_GPU_OPS
|
||||||
|
@@ -197,10 +197,11 @@ def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO
 //===----------------------------------------------------------------------===//
 
 // TODO(b/139813999): specify required function signature in a type-safe way.
-def LHLO_ReduceOp: LHLO_Op<"reduce", [
-    SameVariadicOperandSize,
-    SingleBlockImplicitTerminator<"TerminatorOp">
-  ]>, BASE_HLO_ReduceOp {
+//
+// The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are
+// moving towards mhlo.ReturnOp, but some code that needs cleanup still assumes lmhlo.TerminatorOp.
+// TODO(timshen): cleanup lmhlo.TerminatorOp.
+def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp {
   let arguments = (ins
     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
@@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
+                                            MLIRContext* context) {
+  results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
+      context);
+}
+
 //===----------------------------------------------------------------------===//
 // Case Op
 //===----------------------------------------------------------------------===//

@@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat<
 def ShapeOfDynamicReshape : Pat<
   (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
   (replaceWithValue $shape)>;
+
+def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
+
+def IdentityBroadcastReshape : Pat<
+  (HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
+  (replaceWithValue $input),
+  [(HasSameType $input, $op)]>;
+
+def IdentityBroadcastInDimReshape : Pat<
+  (HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)),
+  (replaceWithValue $input),
+  [(HasSameType $input, $op)]>;
@@ -61,12 +61,16 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
   // mapValues always takes a function returning APInt, even when the output
   // is actually float.
   using func_type = llvm::APInt(const llvm::APInt&);
+
+  // TODO(hinsu): Correctly handle unsigned element types.
+  bool is_bool = old_type.isInteger(1);
   if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
     // Int -> Float
     return elements.mapValues(
-        new_type, llvm::function_ref<func_type>([&newFloatType](
+        new_type, llvm::function_ref<func_type>([&newFloatType, &is_bool](
                       const llvm::APInt& intVal) {
-          llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
+          int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
+          llvm::APFloat newDouble(static_cast<double>(val));
           bool loses_info = false;
           newDouble.convert(newFloatType.getFloatSemantics(),
                             llvm::APFloat::rmNearestTiesToEven, &loses_info);

@@ -76,9 +80,10 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
   // new_type is Integer
   // Int -> Int
   return elements.mapValues(
-      new_type,
-      llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
-        return llvm::APInt(bit_width, intVal.getSExtValue());
+      new_type, llvm::function_ref<func_type>([&bit_width, &is_bool](
+                    const llvm::APInt& intVal) {
+        int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
+        return llvm::APInt(bit_width, val);
       }));
 }
 
@@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> {
   // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
   // CHECK-SAME: ]> : tensor<4x5xi32>
 }
+
+// CHECK-LABEL: @identity_broadcast_reshape
+func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
+  %0 = "mhlo.broadcast"(%arg0) {
+    broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
+  return %1 : tensor<128xf32>
+  // CHECK: return %arg0 : tensor<128xf32>
+}
+
+// CHECK-LABEL: @identity_broadcast_in_dim_reshape
+func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
+  %0 = "mhlo.broadcast_in_dim"(%arg0) {
+    broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32>
+  %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
+  return %1 : tensor<128xf32>
+  // CHECK: return %arg0 : tensor<128xf32>
+}
@@ -123,6 +123,17 @@ func @const_int_bf16() -> tensor<bf16> {
 
 // -----
 
+// CHECK-LABEL: func @const_bool_f32
+func @const_bool_f32() -> tensor<2xf32> {
+  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
+  %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
+  %0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xf32>
+  // CHECK-NEXT: return [[CST]]
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @const_bf16_int
 func @const_bf16_int() -> tensor<i16> {
   // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>

@@ -145,8 +156,8 @@ func @const_int_narrowing() -> tensor<i32> {
 
 // -----
 
-// CHECK-LABEL: func @const_int_widening
-func @const_int_widening() -> tensor<i64> {
+// CHECK-LABEL: func @const_bool_widening
+func @const_bool_widening() -> tensor<i64> {
   // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
   %cst = mhlo.constant dense<42> : tensor<i32>
   %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>

@@ -156,6 +167,17 @@ func @const_int_widening() -> tensor<i64> {
 
 // -----
 
+// CHECK-LABEL: func @const_int_widening
+func @const_int_widening() -> tensor<2xi32> {
+  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0, 1]> : tensor<2xi32>
+  %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
+  %0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xi32>
+  // CHECK-NEXT: return [[CST]]
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @const_negative_int_widening
 func @const_negative_int_widening() -> tensor<i64> {
   // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
@@ -93,7 +93,7 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
 func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
   %scratch = alloc() : memref<32xi8>
   %info = alloc() : memref<32xi32>
-  "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
+  "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
       : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
   return
 }
@@ -704,6 +704,7 @@ cc_library(
         ":convert_type",
         ":flatbuffer_tflite_operator_lib",
         ":tensorflow_lite",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",

@@ -64,6 +64,7 @@ limitations under the License.
 #include "mlir/Translation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@@ -149,7 +150,7 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
 
   int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
                             is_signed, storage_type.getWidth()) +
-                        is_weight_buffer;
+                        static_cast<int>(is_weight_buffer);
   int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
                             is_signed, storage_type.getWidth());
   uint32_t flags =
@@ -177,12 +178,25 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
       quant_params.zero_point.at(0), storage_min, storage_max);
 }
 
+// Import a float tensor with calibration value into calibrated quantized type.
+StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
+                                                   Builder builder) {
+  if (tensor.quantization == nullptr) {
+    return errors::InvalidArgument("The tensor is not quantized.");
+  }
+  auto raw_elem_type = ConvertElementType(tensor.type, builder);
+  float min = tensor.quantization->min[0];
+  float max = tensor.quantization->max[0];
+  return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
+}
+
 // TODO(b/138222071) Remove shapeless_are_scalars once we can reliably
 // make that distinction and don't have to rely on context
 // (input to main and constants must have static shape)
 StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
                                          bool shapeless_are_scalars = false,
-                                         bool is_constant = false) {
+                                         bool is_constant = false,
+                                         bool is_intermediate = false) {
   mlir::Type elem_type = ConvertElementType(tensor.type, builder);
   // TODO(b/139554398) Store min/max (even for non-quantized tensors) somewhere
   // if it's set
@@ -191,6 +205,13 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
                         GetQuantizedType(tensor, builder, is_constant));
   }
 
+  // Intermediate tensors with calibration value (but not scale and zero points)
+  // should return calibrated quantized type.
+  if (is_intermediate && tensor.quantization != nullptr &&
+      !IsQuantized(tensor)) {
+    TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
+  }
+
   if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) {
     return RankedTensorType::get({}, elem_type);
   }
@@ -1033,7 +1054,7 @@ StatusOr<FuncOp> ConvertSubgraph(
     }
   }
 
-  // Intermediate tensors for tfl.lstm are used to carry quantization range
+  // Intermediate tensors for LSTMs are used to carry quantization range
   // in their types, so we only need and extract their types.
   std::vector<mlir::TensorType> intermediate_types;
   intermediate_types.reserve(5);

@@ -1041,7 +1062,8 @@ StatusOr<FuncOp> ConvertSubgraph(
     TF_ASSIGN_OR_RETURN(
         auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
                                  /*shapeless_are_scalars=*/true,
-                                 /*is_constant=*/true));
+                                 /*is_constant=*/false,
+                                 /*is_intermediate=*/true));
     intermediate_types.emplace_back(type);
   }
 
@@ -1135,7 +1157,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
 
   auto builder = Builder(context);
 
-
   std::vector<std::string> func_names;
   for (auto& subgraph : model->subgraphs) {
     func_names.push_back(subgraph->name);
@@ -1978,6 +1978,10 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
 
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1);
+  if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
+    return input();
+  }
+
   // For now, only supports cast between integer types.
   auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (!elements_attr) {
@@ -483,9 +483,9 @@ value of each element in `x`. For example, if x is an input element and y is
 an output element, this operation computes \\(y = |x|\\).
   }];
 
-  let arguments = (ins TFL_FpTensor:$x);
+  let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x);
 
-  let results = (outs TFL_FpTensor:$y);
+  let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y);
 
   let hasFolder = 1;
 }

@@ -587,15 +587,15 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
 
   let arguments = (ins
     TFL_I32Tensor:$output_shape,
-    TFL_TensorOf<[F32, QI8, QUI8]>:$weights,
-    TFL_TensorOf<[F32, QI8, QUI8]>:$input,
-    TFL_TensorOfOrNone<[F32, QI32]>:$bias,
+    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$weights,
+    TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
+    TFL_TensorOfOrNone<[F32, QI32, I64]>:$bias,
     TFL_PaddingAttr:$padding,
     Confined<I32Attr, [IntPositive]>:$stride_h,
     Confined<I32Attr, [IntPositive]>:$stride_w
   );
 
-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
 
   let hasOptions = 1;
 
@@ -624,7 +624,7 @@ def TFL_AveragePool2DOp:
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
+    ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
     I32Attr:$filter_height,
     I32Attr:$filter_width,
     TFL_PaddingAttr:$padding,

@@ -633,7 +633,7 @@ def TFL_AveragePool2DOp:
     TFL_AFAttr:$fused_activation_function
   );
 
-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
 
   let hasOptions = 1;
   let customOption = "Pool2DOptions";
@ -947,7 +947,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
|||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input,
|
TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input,
|
||||||
TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter,
|
TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$filter,
|
||||||
TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,
|
TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,
|
||||||
|
|
||||||
TFL_AFAttr:$fused_activation_function,
|
TFL_AFAttr:$fused_activation_function,
|
||||||
@ -999,14 +999,14 @@ in the batch dimensions and broadcasting.
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, QI8]>:$x,
|
TFL_TensorOf<[F32, QI8, QI16]>:$x,
|
||||||
TFL_TensorOf<[F32, QI8]>:$y,
|
TFL_TensorOf<[F32, QI8, QI16]>:$y,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
|
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$adj_y
|
DefaultValuedAttr<BoolAttr, "false">:$adj_y
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, QI8]>:$output
|
TFL_TensorOf<[F32, QI8, QI16]>:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
@ -1026,7 +1026,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$params,
|
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$params,
|
||||||
TFL_TensorOf<[I32, I64]>:$indices,
|
TFL_TensorOf<[I32, I64]>:$indices,
|
||||||
I32Attr:$axis
|
I32Attr:$axis
|
||||||
);
|
);
|
||||||
@ -1038,7 +1038,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
|
|||||||
];
|
];
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$output
|
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
@ -1750,12 +1750,12 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (
|
let arguments = (
|
||||||
ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input,
|
ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$input,
|
||||||
// Slope of the activation function at x < 0.
|
// Slope of the activation function at x < 0.
|
||||||
F32Attr:$alpha
|
F32Attr:$alpha
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);
|
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$output);
|
||||||
|
|
||||||
let hasOptions = 0b1;
|
let hasOptions = 0b1;
|
||||||
}
|
}
|
||||||
@ -1977,12 +1977,12 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (
|
let arguments = (
|
||||||
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
|
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
|
||||||
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
|
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$max
|
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$max
|
||||||
);
|
);
|
||||||
|
|
||||||
let builders = [TFL_BroadcastableBinaryBuilder];
|
let builders = [TFL_BroadcastableBinaryBuilder];
|
||||||
@ -2005,13 +2005,13 @@ def TFL_MeanOp : TFL_Op<"mean", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$input,
|
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$input,
|
||||||
TFL_TensorOf<[I32, I64]>:$axis,
|
TFL_TensorOf<[I32, I64]>:$axis,
|
||||||
BoolAttr:$keep_dims
|
BoolAttr:$keep_dims
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$output);
|
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$output);
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
let customOption = "ReducerOptions";
|
let customOption = "ReducerOptions";
|
||||||
@@ -2090,13 +2090,13 @@ equivalent to setting:
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input,
     TFL_I32OrI64Tensor:$begin,
     TFL_I32OrI64Tensor:$size
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$output
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output
   );
 
   let verifier = [{ return Verify(*this); }];
@ -2116,12 +2116,12 @@ def TFL_SumOp: TFL_Op<"sum", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
|
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
|
||||||
TFL_I32Tensor:$axes,
|
TFL_I32Tensor:$axes,
|
||||||
BoolAttr:$keep_dims
|
BoolAttr:$keep_dims
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
|
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
let customOption = "ReducerOptions";
|
let customOption = "ReducerOptions";
|
||||||
@ -2139,13 +2139,13 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
|
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
|
||||||
TFL_I32Tensor:$axes,
|
TFL_I32Tensor:$axes,
|
||||||
BoolAttr:$keep_dims
|
BoolAttr:$keep_dims
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
|
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
let customOption = "ReducerOptions";
|
let customOption = "ReducerOptions";
|
||||||
@ -2163,13 +2163,13 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
|
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
|
||||||
TFL_I32Tensor:$axes,
|
TFL_I32Tensor:$axes,
|
||||||
BoolAttr:$keep_dims
|
BoolAttr:$keep_dims
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
|
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
let customOption = "ReducerOptions";
|
let customOption = "ReducerOptions";
|
||||||
@ -2186,13 +2186,13 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
|
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
|
||||||
TFL_I32Tensor:$axes,
|
TFL_I32Tensor:$axes,
|
||||||
BoolAttr:$keep_dims
|
BoolAttr:$keep_dims
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
|
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
let customOption = "ReducerOptions";
|
let customOption = "ReducerOptions";
|
||||||
@ -2210,12 +2210,12 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (
|
let arguments = (
|
||||||
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
|
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
|
||||||
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
|
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$min
|
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$min
|
||||||
);
|
);
|
||||||
|
|
||||||
let builders = [TFL_BroadcastableBinaryBuilder];
|
let builders = [TFL_BroadcastableBinaryBuilder];
|
||||||
@ -2364,10 +2364,10 @@ def TFL_PadOp : TFL_Op<"pad", [
|
|||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
|
let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
|
||||||
TFL_I32OrI64Tensor:$padding);
|
TFL_I32OrI64Tensor:$padding);
|
||||||
|
|
||||||
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
|
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
}
|
}
|
||||||
@ -2500,9 +2500,9 @@ def TFL_ReluOp: TFL_Op<"relu", [
|
|||||||
x -> max(0, x)
|
x -> max(0, x)
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
|
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$x);
|
||||||
|
|
||||||
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
|
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$y);
|
||||||
|
|
||||||
// This builder doesn't work with quantized type, so it can only be used by
|
// This builder doesn't work with quantized type, so it can only be used by
|
||||||
// non-quantization tablegen patterns. Currently, it is used by the
|
// non-quantization tablegen patterns. Currently, it is used by the
|
||||||
@ -2828,11 +2828,11 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (
|
let arguments = (
|
||||||
ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
|
ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$input,
|
||||||
F32Attr:$beta
|
F32Attr:$beta
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);
|
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$output);
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
|
|
||||||
@ -3058,12 +3058,12 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$input,
|
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$input,
|
||||||
TFL_TensorOf<[I32]>:$perm
|
TFL_TensorOf<[I32]>:$perm
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$output
|
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
let verifier = [{ return Verify(*this); }];
|
let verifier = [{ return Verify(*this); }];
|
||||||
@ -3330,14 +3330,14 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$input,
|
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$input,
|
||||||
TFL_I32Tensor:$size,
|
TFL_I32Tensor:$size,
|
||||||
BoolAttr:$align_corners,
|
BoolAttr:$align_corners,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
|
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$output
|
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
@@ -3830,6 +3830,9 @@ Ba et al. 'Layer Normalization'
 
     // Types of the optional intermediate tensors, which exist for fully
     // quantized LSTM op and hold the ranges of the intermediate tensors.
+    // The type for intermediate tensors is quant.calibrated when imported
+    // to only store calibrated min, max values. The proper quantization spec is
+    // determined while going through quantization passes.
     OptionalAttr<TypeAttr>:$input_to_input_intermediate,
     OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
     OptionalAttr<TypeAttr>:$input_to_cell_intermediate,

@@ -3945,6 +3948,9 @@ def TFL_UnidirectionalSequenceLSTMOp :
 
     // Types of the optional intermediate tensors, which exist for fully
     // quantized op and hold the ranges of the intermediate tensors.
+    // The type for intermediate tensors is quant.calibrated when imported
+    // to only store calibrated min, max values. The proper quantization spec is
+    // determined while going through quantization passes.
     OptionalAttr<TypeAttr>:$input_to_input_intermediate,
     OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
     OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
@@ -89,6 +89,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
   pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
   pass_config.lower_tensor_list_ops = true;
+  // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
+  // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
+  if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
+    pass_config.unfold_batch_matmul = false;
+  }
 
   return internal::ConvertMLIRToTFLiteFlatBuffer(
       toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{},
@@ -174,6 +174,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
   pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
   pass_config.lower_tensor_list_ops = true;
+  // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
+  // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
+  if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
+    pass_config.unfold_batch_matmul = false;
+  }
 
   // TODO(b/153507667): Pass the session object when importing logic is removed.
   auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
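The quantized-int16 inference type handled above is what users reach through the converter's experimental 16x8 quantization mode; a hedged Python sketch (paths and the calibration generator are placeholders, not from the commit):

    import tensorflow as tf

    def representative_dataset():
        # Placeholder calibration data; shapes must match the model input.
        for _ in range(100):
            yield [tf.random.normal([1, 5])]

    converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/saved_model")
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
    ]
    tflite_model = converter.convert()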
@@ -53,10 +53,11 @@ struct QuantizationSpecs {
   bool disable_per_channel = false;
 
   // When set to true, the fixed output ranges of the activation ops (tanh,
-  // sigmoid, etc.) are not enforced. Then, to quantize these ops, quantization
-  // emulation ops should be specified after the ops in the input graph. This
-  // flag should be set to false for post-training quantization.
-  bool disable_enforced_fixed_output_range = false;
+  // sigmoid, etc.) and the weight constants are not inferred. Then, to quantize
+  // these ops, quantization emulation ops should be placed after the ops in the
+  // input graph. This flag should be set to false for post-training
+  // quantization.
+  bool disable_infer_tensor_range = false;
 
   // The node type when the model is exported. Currently this is limited to
   // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
|
@ -100,13 +100,13 @@ class QuantizationDriver {
|
|||||||
explicit QuantizationDriver(FuncOp fn, bool is_signed,
|
explicit QuantizationDriver(FuncOp fn, bool is_signed,
|
||||||
bool disable_per_channel,
|
bool disable_per_channel,
|
||||||
OpQuantSpecGetter op_quant_spec_getter,
|
OpQuantSpecGetter op_quant_spec_getter,
|
||||||
bool enforce_fixed_output_range)
|
bool infer_tensor_range)
|
||||||
: fn_(fn),
|
: fn_(fn),
|
||||||
builder_(fn.getBody()),
|
builder_(fn.getBody()),
|
||||||
is_signed_(is_signed),
|
is_signed_(is_signed),
|
||||||
disable_per_channel_(disable_per_channel),
|
disable_per_channel_(disable_per_channel),
|
||||||
op_quant_spec_getter_(op_quant_spec_getter),
|
op_quant_spec_getter_(op_quant_spec_getter),
|
||||||
enforce_fixed_output_range_(enforce_fixed_output_range) {}
|
infer_tensor_range_(infer_tensor_range) {}
|
||||||
|
|
||||||
// The entry point of the quantization parameters propagation.
|
// The entry point of the quantization parameters propagation.
|
||||||
void Run();
|
void Run();
|
||||||
@@ -384,7 +384,9 @@ class QuantizationDriver {
 
   OpQuantSpecGetter op_quant_spec_getter_;
 
-  bool enforce_fixed_output_range_;
+  // Infer output ranges for activation ops and constants. This is usually
+  // required for post-training quantization.
+  bool infer_tensor_range_;
 };
 }  // namespace
 
@@ -670,33 +672,43 @@ void QuantizationDriver::PreprocessConstantOps() {
 
     Value value = cst.getResult();
     builder_.setInsertionPoint(cst);
-    for (auto indexed_use : llvm::enumerate(value.getUses())) {
-      auto &use = indexed_use.value();
-      auto spec = GetQuantSpec(use.getOwner());
-      auto biases = spec->biases_params;
-      Operation *user = use.getOwner();
-      int operand_num = use.getOperandNumber();
+
+    // The following loop will change the value uses, thus we cache all the
+    // uses that need to be changed.
+    llvm::SmallVector<std::pair<Operation *, int>, 4> uses;
+    for (auto &use : value.getUses()) {
+      uses.push_back({use.getOwner(), use.getOperandNumber()});
+    }
+    for (auto indexed_use : llvm::enumerate(uses)) {
+      Operation *user = indexed_use.value().first;
+      int operand_num = indexed_use.value().second;
+
+      auto spec = GetQuantSpec(user);
+      auto biases = spec->biases_params;
+
+      // The quantization parameters of a `weight` shouldn't be determined by
+      // other values. So any constants which are not bias, an operand of an
+      // op with same scale requirements, and haven't been quantized are
+      // weights.
       if (biases.find(operand_num) == biases.end() &&
           !llvm::dyn_cast<mlir::SameScalesOpInterface>(user) &&
           !llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
-        // Needs to scan the content to get the quantization parameters if there
-        // are no quantization parameters (FakeQuant ops).
-        // For this case, the weight isn't duplicated.
+        // Needs to scan the content of weights to get the quantization
+        // parameters if there are no quantization parameters (FakeQuant ops).
+        // For this case, the weight will not be duplicated.
         weights_.insert(cst);
         auto affine_user =
             llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
-        if (affine_user &&
-            affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
+        if (affine_user && affine_user.GetAffineOperandIndex() == operand_num &&
             affine_user.RequiredNarrowRangeAffineOperand()) {
           optimized_weights_.insert(
               {cst, affine_user.GetQuantizationDimIndex()});
         }
       } else {
-        // This is a bias, so the quantization parameter isn't determined by the
-        // local content. Same if the user can have quantization parameter
-        // propagated from other places.
-        // Duplicate this constant in case it is shared by different users.
+        // This is a bias or an operand of an op with same scale requirements,
+        // so the quantization parameters are propagated from or determined by
+        // other values. Duplicate this constant in case it is shared by
+        // different users.
         if (indexed_use.index() > 0) {
           cst = builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
         }
@@ -786,12 +798,14 @@ bool QuantizationDriver::PropagateParams() {
     quantized_.insert(op);
 
    if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
-      // If it isn't a weight or has been quantized, skip.
-      if (!IsWeight(cst) || IsQuantized(op)) continue;
-
-      // The quantization parameters are determined by the content of the
-      // constant.
-      changed |= SetConstantResultParams(op);
+      // If the workflow requires inferring ranges from the content
+      // (post-training quantization) and it is a weight (filter) and hasn't
+      // been quantized, we infer the quantization parameters from the content.
+      if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) {
+        // The quantization parameters are determined by the content of the
+        // constant.
+        changed |= SetConstantResultParams(op);
+      }
       continue;
     }
 
@@ -826,7 +840,9 @@ bool QuantizationDriver::PropagateParams() {
 
     // TODO(fengliuai): make the bit width configurable.
     auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
-    if (restricted && enforce_fixed_output_range_) {
+    if (restricted && infer_tensor_range_) {
+      // Infer ranges from the activation ops. This is usually required for
+      // the post-training quantization workflow.
       // TODO(fengliuai): different result can have different fixed range.
       auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
       for (auto i = 0; i < op->getNumResults(); ++i) {
@@ -903,9 +919,9 @@ void QuantizationDriver::Run() {
 void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                         bool disable_per_channel,
                                         OpQuantSpecGetter op_quant_spec_getter,
-                                        bool post_training_quantization) {
+                                        bool infer_tensor_ranges) {
   QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
-                     post_training_quantization)
+                     infer_tensor_ranges)
       .Run();
 }
 
|
@@ -490,13 +490,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
 // and the propagation results are materialized by inserting pairs of quantize
 // and dequantize ops to this function. Set `disable_per_channel` to true to not
 // use per channel quantization even the op supports it.
-// Setting `enforce_fixed_output_range` to true, to infer quantization
-// parameters from the fixed output range ops. This is only used for
-// post-training quantization.
+// Set `infer_tensor_range` to true to infer quantization parameters from
+// the activation ops and weight constants. This is only used for post-training
+// quantization.
 void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                         bool disable_per_channel,
                                         OpQuantSpecGetter op_quant_spec_getter,
-                                        bool enforce_fixed_output_range);
+                                        bool infer_tensor_ranges);

 // The function might contain more stats ops than required, and it will
 // introduce requantize if the calibration stats have conflicts. This method
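To make the renamed flag concrete, here is a hypothetical call site for the post-training workflow; the signature matches the declaration above, while the `GetOpQuantSpec` getter name is assumed for illustration:

    // Sketch: propagate parameters across `func`, inferring tensor ranges
    // from activation ops and weight constants (post-training quantization).
    ApplyQuantizationParamsPropagation(func, /*is_signed=*/true,
                                       /*disable_per_channel=*/false,
                                       GetOpQuantSpec,
                                       /*infer_tensor_ranges=*/true);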
@@ -638,4 +638,10 @@ func @cast_ui8_to_i32() -> tensor<4xi32> {
   // CHECK: return %[[CST]]
 }

+// CHECK-LABEL: @cast_identity
+func @cast_identity(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
+  %0 = "tfl.cast"(%arg0) : (tensor<7xf32>) -> tensor<7xf32>
+  return %0 : tensor<7xf32>
+  // CHECK: return %arg0 : tensor<7xf32>
+}

tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json (new file, 335 lines)
@@ -0,0 +1,335 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s

// CHECK: effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01, 5.000000e-01>>>
// CHECK: input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00, 4.000000e+00>>>
// CHECK: input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01, 1.600000e+01>>>
// CHECK: input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01, 3.200000e+01>>>
// CHECK: input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00, 1.000000e+00>>>

{
  "version": 3,
  "operator_codes": [
    {
      "builtin_code": "UNIDIRECTIONAL_SEQUENCE_LSTM"
    }
  ],
  "subgraphs": [
    {
      "tensors": [
        {
          "shape": [1, 5],
          "name": "input0"
        },
        {
          "shape": [2, 5],
          "buffer": 1,
          "name": "input2input_weights1"
        },
        {
          "shape": [2, 5],
          "buffer": 2,
          "name": "input2forget_weights2"
        },
        {
          "shape": [2, 5],
          "buffer": 3,
          "name": "input2cell_weights3"
        },
        {
          "shape": [2, 5],
          "buffer": 4,
          "name": "input2output_weights4"
        },
        {
          "shape": [2, 4],
          "buffer": 5,
          "name": "rec2input_weights5"
        },
        {
          "shape": [2, 4],
          "buffer": 6,
          "name": "rec2forget_weights6"
        },
        {
          "shape": [2, 4],
          "buffer": 7,
          "name": "rec2cell_weights7"
        },
        {
          "shape": [2, 4],
          "buffer": 8,
          "name": "rec2output_weights8"
        },
        {
          "shape": [2],
          "buffer": 9,
          "name": "cell2input_weights9"
        },
        {
          "shape": [2],
          "buffer": 10,
          "name": "cell2forget_weights10"
        },
        {
          "shape": [2],
          "buffer": 11,
          "name": "cell2output_weights11"
        },
        {
          "shape": [2],
          "buffer": 12,
          "name": "input_gate_bias12"
        },
        {
          "shape": [2],
          "buffer": 13,
          "name": "forget_gate_bias13"
        },
        {
          "shape": [2],
          "buffer": 14,
          "name": "cell_gate_bias14"
        },
        {
          "shape": [2],
          "buffer": 15,
          "name": "output_gate_bias15"
        },
        {
          "shape": [4, 2],
          "buffer": 16,
          "name": "proj_weights16"
        },
        {
          "shape": [4],
          "buffer": 17,
          "name": "proj_bias17"
        },
        {
          "shape": [1, 4],
          "name": "input_activation_state18",
          "is_variable": true,
          "quantization": {
            "min": [-0.9],
            "max": [0.9]
          }
        },
        {
          "shape": [1, 2],
          "name": "input_cell_state19",
          "is_variable": true,
          "quantization": {
            "min": [-0.8],
            "max": [0.8]
          }
        },
        {
          "shape": [2],
          "buffer": 18,
          "name": "input_norm20"
        },
        {
          "shape": [2],
          "buffer": 19,
          "name": "forget_norm21"
        },
        {
          "shape": [2],
          "buffer": 20,
          "name": "cell_norm22"
        },
        {
          "shape": [2],
          "buffer": 21,
          "name": "output_norm23"
        },
        {
          "shape": [],
          "name": "output24"
        },
        {
          "shape": [],
          "name": "intermediate_0",
          "is_variable": true,
          "quantization": {
            "min": [-32],
            "max": [32]
          }
        },
        {
          "shape": [],
          "name": "intermediate_1",
          "is_variable": true,
          "quantization": {
            "min": [-16],
            "max": [16]
          }
        },
        {
          "shape": [],
          "name": "intermediate_2",
          "is_variable": true,
          "quantization": {
            "min": [-4],
            "max": [4]
          }
        },
        {
          "shape": [],
          "name": "intermediate_3",
          "is_variable": true,
          "quantization": {
            "min": [-1.0],
            "max": [1.0]
          }
        },
        {
          "shape": [],
          "name": "intermediate_4",
          "is_variable": true,
          "quantization": {
            "min": [-0.5],
            "max": [0.5]
          }
        }
      ],
      "inputs": [0],
      "outputs": [24],
      "operators": [
        {
          "inputs": [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23
          ],
          "outputs": [24],
          "intermediates": [
            25, 26, 27, 28, 29
          ],
          "builtin_options_type": "UnidirectionalSequenceLSTMOptions",
          "builtin_options": {
            "fused_activation_function": "TANH",
            "cell_clip": 50.0
          },
          "mutating_variable_inputs": [
            false,
            false, false, false, false,
            false, false, false, false,
            false, false, false,
            false, false, false, false,
            true, true,
            false, false, false, false
          ]
        }
      ]
    }
  ],
  "buffers": [
    {
      "data": []
    },
    {
      "data": [
        36, 167, 168, 63, 0, 140, 72, 191, 120, 20, 147, 62, 20, 152, 196, 190, 121, 98, 82, 187, 95, 128, 213, 61, 189, 3, 138, 63, 54, 103, 13, 62, 46, 224, 66, 63, 157, 204, 180, 191
      ]
    },
    {
      "data": [
        223, 20, 21, 64, 246, 166, 31, 191, 6, 51, 157, 188, 114, 90, 167, 62, 118, 240, 59, 63, 49, 162, 255, 62, 17, 91, 160, 63, 32, 47, 26, 63, 40, 136, 178, 191, 243, 154, 236, 61
      ]
    },
    {
      "data": [
        137, 231, 86, 63, 41, 154, 16, 63, 239, 37, 77, 191, 55, 189, 24, 189, 86, 63, 18, 63, 42, 55, 13, 191, 110, 139, 138, 191, 219, 148, 181, 63, 71, 232, 108, 191, 66, 226, 145, 191
      ]
    },
    {
      "data": [
        245, 179, 225, 190, 51, 202, 176, 189, 132, 47, 53, 191, 155, 25, 50, 191, 197, 130, 240, 191, 98, 125, 45, 62, 243, 70, 83, 62, 85, 155, 139, 63, 113, 239, 11, 192, 35, 251, 139, 62
      ]
    },
    {
      "data": [
        248, 188, 211, 191, 142, 11, 73, 62, 36, 8, 84, 63, 186, 208, 11, 191, 76, 208, 190, 191, 223, 200, 210, 63, 183, 170, 103, 63, 116, 129, 145, 63
      ]
    },
    {
      "data": [
        235, 202, 222, 190, 159, 201, 112, 191, 217, 248, 166, 63, 165, 199, 131, 191, 130, 59, 47, 63, 179, 11, 186, 62, 55, 168, 18, 192, 152, 213, 26, 64
      ]
    },
    {
      "data": [
        245, 123, 138, 62, 213, 106, 231, 59, 211, 218, 250, 62, 25, 157, 134, 63, 147, 22, 164, 63, 25, 221, 139, 62, 1, 230, 247, 62, 210, 185, 142, 63
      ]
    },
    {
      "data": [
        197, 123, 23, 192, 45, 96, 178, 190, 174, 87, 165, 62, 213, 225, 200, 191, 119, 248, 15, 191, 128, 125, 171, 189, 90, 125, 222, 63, 4, 76, 95, 62
      ]
    },
    {
      "data": [
        210, 73, 183, 63, 248, 177, 13, 191
      ]
    },
    {
      "data": [
        78, 251, 212, 191, 169, 29, 147, 63
      ]
    },
    {
      "data": [
        178, 227, 203, 191, 247, 155, 103, 63
      ]
    },
    {
      "data": [
        206, 111, 165, 190, 153, 77, 227, 63
      ]
    },
    {
      "data": [
        255, 114, 132, 191, 253, 202, 140, 191
      ]
    },
    {
      "data": [
        90, 247, 1, 192, 125, 120, 209, 191
      ]
    },
    {
      "data": [
        65, 75, 243, 191, 58, 122, 146, 190
      ]
    },
    {
      "data": [
        40, 135, 20, 63, 109, 50, 220, 191, 56, 241, 189, 63, 65, 12, 92, 63, 61, 14, 162, 62, 157, 138, 81, 63, 125, 61, 191, 61, 102, 231, 20, 63
      ]
    },
    {
      "data": [
        145, 79, 49, 189, 175, 235, 220, 190, 182, 111, 157, 190, 144, 236, 97, 191
      ]
    },
    {
      "data": [
        76, 188, 109, 63, 228, 150, 201, 190
      ]
    },
    {
      "data": [
        6, 146, 66, 191, 122, 127, 100, 191
      ]
    },
    {
      "data": [
        216, 59, 169, 190, 161, 178, 215, 191
      ]
    },
    {
      "data": [
        208, 144, 101, 191, 127, 233, 195, 190
      ]
    }
  ]
}
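Note how the CHECK lines at the top of this new test tie back to the JSON: the importer turns each intermediate tensor's calibration `min`/`max` into a `!quant.calibrated` type on the matching op attribute, so `intermediate_0` with `"min": [-32], "max": [32]` surfaces as `input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01, 3.200000e+01>>>`, down to `intermediate_4`'s ±0.5 becoming the `effective_hidden_scale_intermediate` range.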
@@ -20,20 +20,20 @@ func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32

 func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
   %cst = constant unit
-  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
+  %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
   return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
   // CHECK-LABEL: testFullyQuantizedLSTM
   // CHECK: %cst = constant unit
   // CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18)
-  // CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
+  // CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
 }

 // -----

 // CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
 func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
-  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }

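For readers of the type notation: `!quant.uniform<i8<-127:127>:f32, ...>` restricts stored values to the narrow range [-127, 127], the convention for symmetric weights where -128 is excluded, whereas plain `i8` allows the full [-128, 127]. The updates in this hunk drop the narrow-range constraint from the `effective_hidden_scale_intermediate` types in both the input IR and the expected output.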
@@ -318,6 +318,15 @@ func @any(%arg0: tensor<2x2xi1>, %arg1: tensor<i32>) -> tensor<i1> {
   // CHECK: "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
 }

+func @any_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tensor<?xi1> {
+  %0 = "tf.Any"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi64>) -> tensor<?xi1>
+  return %0 : tensor<?xi1>
+
+  // CHECK-LABEL: any_i64axes
+  // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+  // CHECK: "tfl.reduce_any"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor<?xi1>
+}
+
 func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
   return %0 : tensor<8x16xf32>
@@ -972,6 +981,15 @@ func @sum_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32
   // CHECK: "tfl.sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
 }

+func @sum_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
+  %0 = "tf.Sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+
+  // CHECK-LABEL: sum_i64axes
+  // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+  // CHECK: "tfl.sum"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
+}
+
 func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
   %0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -988,6 +1006,15 @@ func @reduce_min_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
   // CHECK: "tfl.reduce_min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
 }

+func @reduce_min_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
+  %0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+
+  // CHECK-LABEL: reduce_min_i64axes
+  // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+  // CHECK: "tfl.reduce_min"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
+}
+
 func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
   %0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -1004,6 +1031,15 @@ func @reduce_max_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
   // CHECK: "tfl.reduce_max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
 }

+func @reduce_max_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
+  %0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+
+  // CHECK-LABEL: reduce_max_i64axes
+  // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+  // CHECK: "tfl.reduce_max"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
+}
+
 func @reduce_prod(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
   %0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -1020,6 +1056,15 @@ func @reduce_prod_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tens
   // CHECK: "tfl.reduce_prod"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
 }

+func @reduce_prod_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
+  %0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+
+  // CHECK-LABEL: reduce_prod_i64axes
+  // CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
+  // CHECK: "tfl.reduce_prod"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
+}
+
 func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
   %0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -1172,10 +1217,10 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
   %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
   return %0 : tensor<10x10xf32>
   // CHECK-LABEL: strided_slice_with_constant_attributes
-  // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32>
-  // CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32>
-  // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32>
-  // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
+  // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32>
+  // CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32>
+  // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32>
+  // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
 }

 func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
@@ -1185,6 +1230,39 @@ func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tenso
   // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
 }

+func @strided_slice_with_unranked_input_and_i64_parameters(%arg0: tensor<*xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<*xf32> {
+  %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+  // CHECK-LABEL: strided_slice_with_unranked_input_and_i64_parameters
+  // CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<*xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xf32>
+}
+
+func @strided_slice_with_i64_parameters(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<1x2x2x5xf32> {
+  %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1x2x2x5xf32>
+  return %0 : tensor<1x2x2x5xf32>
+  // CHECK-LABEL: strided_slice_with_i64_parameters
+  // CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
+  // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
+}
+
+func @strided_slice_with_i64_constant_attributes(%arg0: tensor<10x10x10xf32>) -> tensor<10x10xf32> {
+  %cst = constant dense<-1> : tensor<1xi64>
+  %cst_1 = constant dense<0> : tensor<1xi64>
+  %cst_2 = constant dense<1> : tensor<1xi64>
+  %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<10x10xf32>
+  return %0 : tensor<10x10xf32>
+  // CHECK-LABEL: strided_slice_with_i64_constant_attributes
+  // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32>
+  // CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32>
+  // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32>
+  // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
+}
+
 func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
   %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
   return %0 : tensor<?x3x5xf32>
@@ -625,6 +625,35 @@ func @QuantizeSharedBiases2(
   // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
 }

+// Make sure constants are duplicated for all users.
+// CHECK-LABEL: QuantizeSharedConstantsMultipleUsers
+func @QuantizeSharedConstantsMultipleUsers(
+    %arg0: tensor<32x!quant.uniform<u8:f32, 1.0>>,
+    %arg1: tensor<32x!quant.uniform<u8:f32, 2.0>>,
+    %arg2: tensor<32x!quant.uniform<u8:f32, 3.0>>,
+    %arg3: tensor<32x!quant.uniform<u8:f32, 4.0>>) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) {
+  %cst = constant dense<0.0> : tensor<32xf32>
+  %0 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform<u8:f32, 1.0>>) -> tensor<32xf32>
+  %1 = "tfl.dequantize"(%arg1) : (tensor<32x!quant.uniform<u8:f32, 2.0>>) -> tensor<32xf32>
+  %2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<u8:f32, 3.0>>) -> tensor<32xf32>
+  %3 = "tfl.dequantize"(%arg3) : (tensor<32x!quant.uniform<u8:f32, 4.0>>) -> tensor<32xf32>
+
+  %4 = "tfl.minimum"(%0, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  %5 = "tfl.minimum"(%1, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  %6 = "tfl.minimum"(%2, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  %7 = "tfl.minimum"(%3, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  return %4, %5, %6, %7 : tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>
+
+  // CHECK: %[[cst1:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<32xf32>
+  // CHECK: %[[cst2:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 3.000000e+00>>) -> tensor<32xf32>
+  // CHECK: %[[cst3:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 4.000000e+00>>) -> tensor<32xf32>
+  // CHECK: %[[cst4:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<32xf32>
+  // CHECK: "tfl.minimum"(%{{.*}}, %[[cst4]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  // CHECK: "tfl.minimum"(%{{.*}}, %[[cst1]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  // CHECK: "tfl.minimum"(%{{.*}}, %[[cst2]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+  // CHECK: "tfl.minimum"(%{{.*}}, %[[cst3]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
+}
+
 // Make sure quantization parameters are scanned from weight, but not from bias.
 // CHECK-LABEL: QuantizeWeight
 func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
@@ -551,6 +551,19 @@ func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf
   return %0 : tensor<8x4x16x1xf32>
 }

+func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<10x10xf32> {
+  %cst = constant dense<-1> : tensor<1xi32>
+  %cst_1 = constant dense<0> : tensor<1xi32>
+  %cst_2 = constant dense<1> : tensor<1xi32>
+  %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
+  return %0 : tensor<10x10xf32>
+  // CHECK-LABEL: strided_slice_with_constant_attributes
+  // CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32>
+  // CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32>
+  // CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32>
+  // CHECK-NEXT: "tf.StridedSlice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
+}
+
 func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
   %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
   return %0: tensor<3x3xf32>
@@ -328,23 +328,28 @@ def LegalizeMean : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2),
                        (TFL_MeanOp $arg0, $arg1, $arg2)>;

 def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2),
-                      (TFL_SumOp $arg, $axes, $arg2)>;
+                      (TFL_SumOp $arg, (CreateTFCastToInt32Op $axes), $arg2)>;

 // TopK in TFL is always sorted so we ignore that attribute here.
 def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted),
                          (TFL_TopKV2Op $input, $k)>;

-def LegalizeMin : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2),
-                      (TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
+def LegalizeMin : Pat<
+  (TF_MinOp $arg0, $axes, BoolAttr:$arg2),
+  (TFL_ReduceMinOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;

-def LegalizeMax : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2),
-                      (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
+def LegalizeMax : Pat<
+  (TF_MaxOp $arg0, $axes, BoolAttr:$arg2),
+  (TFL_ReduceMaxOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;

-def LegalizeProd : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2),
-                       (TFL_ReduceProdOp $arg0, $arg1, $arg2)>;
+def LegalizeProd : Pat<
+  (TF_ProdOp $arg0, $axes, BoolAttr:$arg2),
+  (TFL_ReduceProdOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;

-def LegalizeAny : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims),
-                      (TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>;
+def LegalizeAny : Pat<
+  (TF_AnyOp $input, $reduction_indices, $keep_dims),
+  (TFL_ReduceAnyOp $input, (CreateTFCastToInt32Op $reduction_indices),
+   $keep_dims)>;

 def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;

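These patterns all funnel the reduction axes through `CreateTFCastToInt32Op`, a TableGen NativeCodeCall. A plausible C++ expansion, sketched here purely for illustration (the real helper's name and exact signature may differ): it inserts a `tf.Cast` of the axes operand to a same-shaped i32 tensor before handing it to the TFL reduce op, which is why the tests above expect a `tfl.cast` for i64 axes.

    // Illustrative sketch of the cast helper; not the actual implementation.
    static Value CreateCastToInt32(Value val, Location loc,
                                   PatternRewriter& rewriter) {
      Type i32 = rewriter.getIntegerType(32);
      Type new_type;
      if (auto ranked = val.getType().dyn_cast<RankedTensorType>())
        new_type = RankedTensorType::get(ranked.getShape(), i32);
      else
        new_type = UnrankedTensorType::get(i32);
      // tf.Cast with Truncate=false keeps ordinary conversion semantics.
      return rewriter.create<TF::CastOp>(loc, new_type, val,
                                         rewriter.getBoolAttr(false));
    }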
@@ -471,3 +476,14 @@ def LegalizeCumsum : Pat<
 def LegalizeReshape : Pat<
   (TF_ReshapeOp $input, $shape),
   (TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>;

+def LegalizeStridedSlice : Pat<
+  (TF_StridedSliceOp
+    $input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask,
+    $new_axis_mask, $shrink_axis_mask),
+  (TFL_StridedSliceOp $input,
+   (CreateTFCastToInt32Op $begin), (CreateTFCastToInt32Op $end),
+   (CreateTFCastToInt32Op $strides), (convertIntAttrTo32Bit $begin_mask),
+   (convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
+   (convertIntAttrTo32Bit $new_axis_mask),
+   (convertIntAttrTo32Bit $shrink_axis_mask))>;
@@ -148,7 +148,6 @@ DECL_CONVERT_OP(MatrixDiagV3);
 DECL_CONVERT_OP(Pack);
 DECL_CONVERT_OP(Split);
 DECL_CONVERT_OP(SplitV);
-DECL_CONVERT_OP(StridedSlice);
 DECL_CONVERT_OP(Unpack);
 DECL_CONVERT_OP(RandomUniform);

@ -325,81 +324,6 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite(
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
|
|
||||||
Value attribute,
|
|
||||||
ArrayRef<int32_t> padding_val, int* mask) {
|
|
||||||
DenseIntElementsAttr dense_elem_attr;
|
|
||||||
SmallVector<int32_t, 8> padded_val;
|
|
||||||
|
|
||||||
auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
|
|
||||||
if (!ranked_attr_type ||
|
|
||||||
!matchPattern(attribute, m_Constant(&dense_elem_attr))) {
|
|
||||||
// If the input attribute is neither ranked type nor constant, we
|
|
||||||
// can't do any padding. Instead we just return it.
|
|
||||||
return attribute;
|
|
||||||
}
|
|
||||||
for (const auto& idx : dense_elem_attr.getIntValues()) {
|
|
||||||
padded_val.push_back(idx.getSExtValue());
|
|
||||||
}
|
|
||||||
auto attr_dim_count = ranked_attr_type.getShape()[0];
|
|
||||||
int full_dim_count = padding_val.size();
|
|
||||||
for (int i = attr_dim_count; i < full_dim_count; ++i) {
|
|
||||||
padded_val.push_back(padding_val[i]);
|
|
||||||
if (mask) *mask |= 1 << i;
|
|
||||||
}
|
|
||||||
auto type =
|
|
||||||
RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
|
|
||||||
auto attr = DenseElementsAttr::get<int32_t>(type, padded_val);
|
|
||||||
return rewriter.create<ConstantOp>(op->getLoc(), type, attr);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
  auto ranked_input_type =
      tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
  if (!ranked_input_type) {
    // If input is not a ranked tensor, we can't deduce the padding dimensions
    // from it, so we just do a plain conversion here.
    rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
        op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
        tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
        tf_strided_slice_op.strides(),
        rewriter.getI32IntegerAttr(tf_strided_slice_op.begin_mask()),
        rewriter.getI32IntegerAttr(tf_strided_slice_op.end_mask()),
        rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
        rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
        rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
    return success();
  }

  int num_input_dims = ranked_input_type.getRank();
  // Pad `begin` array with zero values and update the `begin_mask`.
  SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0);
  int begin_mask = tf_strided_slice_op.begin_mask();
  Value padded_begin = PadStridedSliceAttributeArray(
      op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask);
  // Pad `end` array with `input_shape` and update the `end_mask`.
  int end_mask = tf_strided_slice_op.end_mask();
  auto input_shape = ranked_input_type.getShape();
  SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end());
  Value padded_end = PadStridedSliceAttributeArray(
      op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask);
  // Pad `strides` array with ones.
  SmallVector<int32_t, 8> strides_pad_val(num_input_dims, 1);
  Value padded_strides = PadStridedSliceAttributeArray(
      op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
  rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
      op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
      padded_begin, padded_end, padded_strides,
      rewriter.getI32IntegerAttr(begin_mask),
      rewriter.getI32IntegerAttr(end_mask),
      rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
      rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
      rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
  return success();
}

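As a rough illustration (not part of the diff), the padding above is semantics-preserving because extending `begin` with zeros, `end` with the dimension sizes, and `strides` with ones selects the whole of each padded dimension. A minimal NumPy sketch; the helper name `pad_strided_slice_args` is hypothetical:

```python
import numpy as np

def pad_strided_slice_args(begin, end, strides, input_shape):
    """Hypothetical mirror of the C++ padding: extend begin/end/strides
    to the full input rank with 0 / dim size / 1 respectively."""
    rank = len(input_shape)
    begin = list(begin) + [0] * (rank - len(begin))
    end = list(end) + list(input_shape[len(end):])
    strides = list(strides) + [1] * (rank - len(strides))
    return begin, end, strides

x = np.arange(24).reshape(2, 3, 4)
# A 1-D slice spec [0:2] over a rank-3 input is padded to the full rank.
b, e, s = pad_strided_slice_args([0], [2], [1], x.shape)
assert (x[tuple(slice(*t) for t in zip(b, e, s))] == x[0:2]).all()
```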
LogicalResult ConvertTFUnpackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_unpack_op = cast<TF::UnpackOp>(op);
@@ -769,8 +693,8 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
   patterns
       .insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
               ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
-              ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
-              ConvertTFAssertOp, ConvertTFRandomUniformOp>(context);
+              ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
+              ConvertTFRandomUniformOp>(context);
 
   // Ophint python converter converted tf node pattern.
   patterns.insert<LegalizeUnidirectionalSequenceLstm,
@@ -376,14 +376,17 @@ void PrepareQuantizePass::runOnFunction() {
   OwningRewritePatternList patterns;
   bool is_signed = quant_specs_.IsSignedInferenceType();
   int bit_width = quant_specs_.GetQuantizationTypeWidth();
-  bool quantization_aware_training_mode = ContainsQuantizeOps(func);
-  // Enforce fixed output range for post-training quantization and
-  // when the model has quantization emulation ops, unless it was disabled
-  // explicitly by the flag.
-  bool enforced_output_range =
-      (quant_specs_.post_training_quantization ||
-       quantization_aware_training_mode) &&
-      !quant_specs_.disable_enforced_fixed_output_range;
+  // When this is true, the quantizer will try its best to extract the
+  // quantization parameters from the op quantization property and constant
+  // content. This is also set to true when the `quantize_allowlist` and
+  // `quantize_signed` test flags are enabled.
+  bool eager_quantize = ContainsQuantizeOps(func) ||
+                        (!quantize_allowlist.empty() || quantize_signed);
+  // Infer the tensor range for the activation ops and weight constants unless
+  // it is disabled explicitly.
+  bool infer_tensor_range =
+      (quant_specs_.post_training_quantization || eager_quantize) &&
+      !quant_specs_.disable_infer_tensor_range;
   if (is_signed) {
     patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
     // Convert quant stats to int8 quantization parameters.
@@ -403,7 +406,7 @@ void PrepareQuantizePass::runOnFunction() {
   // values (tensors).
   ApplyQuantizationParamsPropagation(
       func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
-      GetOpQuantSpec, enforced_output_range);
+      GetOpQuantSpec, infer_tensor_range);
 
   ConvertMlirQuantOpsToTFLQuantOps(func);
 }
@@ -712,6 +712,23 @@ struct ConvertTFStridedSlice : public RewritePattern {
     return success();
   }
 
+  void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
+                                     SmallVectorImpl<int32_t> &val,
+                                     SmallVectorImpl<int32_t> &padded_val,
+                                     ArrayRef<int32_t> padding_val,
+                                     int *mask) const {
+    for (const auto &idx : dense_elem_attr.getIntValues()) {
+      val.push_back(idx.getSExtValue());
+      padded_val.push_back(idx.getSExtValue());
+    }
+    int attr_dim_count = val.size();
+    int full_dim_count = padding_val.size();
+    for (int i = attr_dim_count; i < full_dim_count; ++i) {
+      padded_val.push_back(padding_val[i]);
+      if (mask) *mask |= 1 << i;
+    }
+  }
+
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
@@ -719,17 +736,102 @@ struct ConvertTFStridedSlice : public RewritePattern {
     // Handle new axis mask.
     if (strided_slice_op.new_axis_mask() != 0) {
       // We currently don't handle simultaneous shrink_ and new_axis masks.
-      if (strided_slice_op.shrink_axis_mask()) {
-        return failure();
+      if (!strided_slice_op.shrink_axis_mask()) {
+        return RewriteNewAxisMask(strided_slice_op, rewriter);
       }
-      return RewriteNewAxisMask(strided_slice_op, rewriter);
     }
 
     // Handle ellipsis mask.
     if (strided_slice_op.ellipsis_mask() != 0) {
       return RewriteEllipsisMask(strided_slice_op, rewriter);
     }
-    return failure();
+
+    auto ranked_input_type =
+        strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
+    if (!ranked_input_type) {
+      return failure();
+    }
+
+    auto begin_attr = strided_slice_op.begin();
+    auto end_attr = strided_slice_op.end();
+    auto strides_attr = strided_slice_op.strides();
+
+    auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
+    auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
+    auto strides_attr_type =
+        strides_attr.getType().dyn_cast<RankedTensorType>();
+
+    DenseIntElementsAttr begin_elem_attr;
+    DenseIntElementsAttr end_elem_attr;
+    DenseIntElementsAttr strides_elem_attr;
+
+    if (!begin_attr_type ||
+        !matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
+      return failure();
+    }
+    if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
+      return failure();
+    }
+    if (!strides_attr_type ||
+        !matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
+      return failure();
+    }
+
+    SmallVector<int32_t, 4> begin, end, strides;
+    SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;
+
+    int num_input_dims = ranked_input_type.getRank();
+    SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
+    auto input_shape = ranked_input_type.getShape();
+    SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
+    SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);
+
+    int begin_mask = strided_slice_op.begin_mask();
+    int end_mask = strided_slice_op.end_mask();
+
+    PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
+                                  padding_begin, &begin_mask);
+    PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
+                                  &end_mask);
+    PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
+                                  padding_strides, nullptr);
+
+    if (begin == padded_begin && end == padded_end &&
+        strides == padded_strides &&
+        begin_mask == strided_slice_op.begin_mask() &&
+        end_mask == strided_slice_op.end_mask()) {
+      return failure();
+    }
+
+    auto begin_end_type =
+        RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
+    auto new_begin_attr = rewriter.create<ConstantOp>(
+        op->getLoc(), begin_end_type,
+        DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
+    auto new_end_attr = rewriter.create<ConstantOp>(
+        op->getLoc(), begin_end_type,
+        DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
+    auto strides_type =
+        RankedTensorType::get({static_cast<long>(padded_strides.size())},
+                              rewriter.getIntegerType(32));
+    auto new_strides_attr = rewriter.create<ConstantOp>(
+        op->getLoc(), strides_type,
+        DenseElementsAttr::get<int32_t>(strides_type, padded_strides));
+
+    auto attribute_type = rewriter.getIntegerType(64);
+    rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
+        op, strided_slice_op.output().getType(), strided_slice_op.input(),
+        new_begin_attr, new_end_attr, new_strides_attr,
+        rewriter.getIntegerAttr(attribute_type, begin_mask),
+        rewriter.getIntegerAttr(attribute_type, end_mask),
+        rewriter.getIntegerAttr(attribute_type,
+                                strided_slice_op.ellipsis_mask()),
+        rewriter.getIntegerAttr(attribute_type,
+                                strided_slice_op.new_axis_mask()),
+        rewriter.getIntegerAttr(attribute_type,
+                                strided_slice_op.shrink_axis_mask()));
+
+    return success();
   }
 };
 
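As a rough illustration (not part of the diff), a hypothetical Python mirror of the `PadStridedSliceAttributeArray` helper added above, showing how the mask bit for each padded dimension is set:

```python
def pad_attr_array(vals, padding_vals, mask):
    """Hypothetical Python mirror of PadStridedSliceAttributeArray: extend
    `vals` with `padding_vals` and mark every padded dimension in `mask`
    so the padded entry is ignored by strided-slice semantics."""
    padded = list(vals)
    for i in range(len(vals), len(padding_vals)):
        padded.append(padding_vals[i])
        mask |= 1 << i  # bit i set => use the default bound for dim i
    return padded, mask

# Padding begin = [1] of a rank-3 input: dims 1 and 2 get mask bits set.
padded, mask = pad_attr_array([1], [0, 0, 0], 0)
assert padded == [1, 0, 0] and mask == 0b110
```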
@@ -337,6 +337,7 @@ cc_library(
         ":tensorflow",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:Dialect",
@@ -859,6 +860,7 @@ cc_library(
         "transforms/mark_ops_for_outside_compilation.cc",
         "transforms/materialize_mlir_passthrough_op.cc",
         "transforms/optimize.cc",
+        "transforms/outside_compiled_to_host_launch.cc",
         "transforms/parallel_execute_to_islands.cc",
         "transforms/parallelize_embedding_params_ops_pass.cc",
         "transforms/promote_resources_to_args.cc",
@@ -31,7 +31,7 @@ limitations under the License.
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 
-def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_AbsOp : TF_Op<"Abs", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes the absolute value of a tensor.";
 
   let description = [{
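The `Idempotent` trait added to several ops in this file asserts that applying the op twice gives the same result as applying it once, which lets the canonicalizer drop repeated applications. A quick NumPy check (not part of the diff) of that property for the ops tagged in this commit:

```python
import numpy as np

x = np.array([-1.7, -0.2, 0.0, 2.3])
# Idempotent: applying the op twice equals applying it once.
for f in (np.abs, np.ceil, np.floor, np.rint, np.sign,
          lambda v: np.maximum(v, 0.0),    # Relu
          lambda v: np.clip(v, 0.0, 6.0),  # Relu6
          np.ones_like, np.zeros_like):
    assert np.array_equal(f(f(x)), f(x))
```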
@@ -1002,13 +1002,13 @@ reverse of SpaceToBatch. See below for a precise description.
     TF_Tensor:$output
   );
 
-  let verifier = [{
-    return Verify(*this);
-  }];
-
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
   TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>;
   TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
 }
 
 def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> {
@@ -1486,7 +1486,7 @@ def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
   let hasFolder = 1;
 }
 
-def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_CeilOp : TF_Op<"Ceil", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns element-wise smallest integer not less than x.";
 
   let arguments = (ins
@@ -3502,8 +3502,8 @@ tf.math.equal(x, y) ==> array([True, True])
   }];
 
   let arguments = (ins
-    TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
-    TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
+    TF_Tensor:$x,
+    TF_Tensor:$y,
 
     DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
   );
@@ -3843,8 +3843,8 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
   let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation.";
 
   let arguments = (ins
-    F32Tensor:$gradients,
-    F32Tensor:$inputs,
+    TF_Float32Tensor:$gradients,
+    TF_Float32Tensor:$inputs,
 
     DefaultValuedAttr<F32Attr, "-6.0f">:$min,
     DefaultValuedAttr<F32Attr, "6.0f">:$max,
@@ -3853,7 +3853,7 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
   );
 
   let results = (outs
-    F32Tensor:$backprops
+    TF_Float32Tensor:$backprops
   );
 }
 
@@ -3911,19 +3911,19 @@ def TF_FakeQuantWithMinMaxVarsGradientOp : TF_Op<"FakeQuantWithMinMaxVarsGradien
   let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation.";
 
   let arguments = (ins
-    F32Tensor:$gradients,
-    F32Tensor:$inputs,
-    F32Tensor:$min,
-    F32Tensor:$max,
+    TF_Float32Tensor:$gradients,
+    TF_Float32Tensor:$inputs,
+    TF_Float32Tensor:$min,
+    TF_Float32Tensor:$max,
 
     DefaultValuedAttr<I64Attr, "8">:$num_bits,
     DefaultValuedAttr<BoolAttr, "false">:$narrow_range
   );
 
   let results = (outs
-    F32Tensor:$backprops_wrt_input,
-    F32Tensor:$backprop_wrt_min,
-    F32Tensor:$backprop_wrt_max
+    TF_Float32Tensor:$backprops_wrt_input,
+    TF_Float32Tensor:$backprop_wrt_min,
+    TF_Float32Tensor:$backprop_wrt_max
   );
 }
 
@@ -4026,7 +4026,7 @@ fill([2, 3], 9) ==> [[9, 9, 9]
   ];
 }
 
-def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_FloorOp : TF_Op<"Floor", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns element-wise largest integer not greater than x.";
 
   let arguments = (ins
@@ -4977,13 +4977,13 @@ $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
   }];
 
   let arguments = (ins
-    F32Tensor:$predictions,
+    TF_Float32Tensor:$predictions,
     TF_I32OrI64Tensor:$targets,
     TF_I32OrI64Tensor:$k
   );
 
   let results = (outs
-    I1Tensor:$precision
+    TF_BoolTensor:$precision
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
@@ -6912,7 +6912,7 @@ retained with length 1.
   ];
 }
 
-def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
+def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
   let summary = "Performs max pooling on the input.";
 
   let arguments = (ins
@@ -6936,6 +6936,9 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInter
     SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
     SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
     LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
+
+    // TF_LayoutSensitiveInterface:
+    StringRef GetOptimalLayout(const RuntimeDevices& devices);
+    LogicalResult UpdateDataFormat(StringRef data_format);
   }];
 }
 
@@ -7852,8 +7855,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> {
   }];
 
   let arguments = (ins
-    TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
-    TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
+    TF_Tensor:$x,
+    TF_Tensor:$y,
 
     DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
   );
@@ -8031,7 +8034,7 @@ times by rerunning "MakeIterator".
   );
 }
 
-def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns a tensor of ones with the same shape and type as x.";
 
   let arguments = (ins
@@ -8429,6 +8432,10 @@ def TF_QrOp : TF_Op<"Qr", [NoSideEffect]> {
 Computes the QR decomposition of each inner matrix in `tensor` such that
 `tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
 
+Currently, the gradient for the QR decomposition is well-defined only when
+the first `P` columns of the inner matrix are linearly independent, where
+`P` is the minimum of `M` and `N`, the 2 inner-most dimensions of `tensor`.
+
 ```python
 # a is a tensor.
 # q is a tensor of orthonormal matrices.
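A small NumPy check (not part of the diff) of the reconstruction property the description states, `tensor[..., :, :] = q @ r` with orthonormal `q`:

```python
import numpy as np

a = np.random.default_rng(0).standard_normal((4, 4))
q, r = np.linalg.qr(a)
assert np.allclose(q @ r, a)             # tensor = q * r, as documented
assert np.allclose(q.T @ q, np.eye(4))   # q has orthonormal columns
```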
@@ -9114,7 +9121,7 @@ most one RecvTPUEmbeddingActivations op in the TPU graph.
   TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>;
 }
 
-def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
+def TF_ReluOp : TF_Op<"Relu", [Idempotent, NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
   let summary = "Computes rectified linear: `max(features, 0)`.";
 
   let description = [{
@@ -9140,7 +9147,7 @@ array([ 0., 0., -0., 3.], dtype=float32)
   }];
 }
 
-def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_Relu6Op : TF_Op<"Relu6", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`.";
 
   let arguments = (ins
@@ -10538,7 +10545,7 @@ bitwise_ops.right_shift(lhs, rhs)
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_RintOp : TF_Op<"Rint", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns element-wise integer closest to x.";
 
   let description = [{
@@ -10605,7 +10612,7 @@ roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
   TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>;
 }
 
-def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_RoundOp : TF_Op<"Round", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = [{
 Rounds the values of a tensor to the nearest integer, element-wise.
   }];
@@ -11335,7 +11342,7 @@ Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_SignOp : TF_Op<"Sign", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns an element-wise indication of the sign of a number.";
 
   let description = [{
@@ -12345,14 +12352,14 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
 }
 
 def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> {
-  let summary = [{
-Picks the best counter-based RNG algorithm based on device.
-  }];
+  let summary = "Picks the best counter-based RNG algorithm based on device.";
 
   let description = [{
 This op picks the best counter-based RNG algorithm based on device.
   }];
 
+  let arguments = (ins);
+
   let results = (outs
     TF_Int32Tensor:$alg
   );
@@ -14088,73 +14095,35 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are
 scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
 for the existing tensor cannot be re-used, a copy is made and updated.
 
-If `indices` contains duplicates, then their updates are accumulated (summed).
+If `indices` contains duplicates, then we pick the last update for the index.
 
-**WARNING**: The order in which updates are applied is nondeterministic, so the
-output will be nondeterministic if `indices` contains duplicates -- because
-of some numerical approximation issues, numbers summed in different order
-may yield different results.
+If an out of bound index is found on CPU, an error is returned.
+
+**WARNING**: There are some GPU specific semantics for this operation.
+- If an out of bound index is found, the index is ignored.
+- The order in which updates are applied is nondeterministic, so the output
+will be nondeterministic if `indices` contains duplicates.
 
 `indices` is an integer tensor containing indices into a new tensor of shape
-`shape`. The last dimension of `indices` can be at most the rank of `shape`:
+`shape`.
 
-    indices.shape[-1] <= shape.rank
+* `indices` must have at least 2 axes: `(num_updates, index_depth)`.
+* The last axis of `indices` is how deep to index into `tensor` so this index
+  depth must be less than the rank of `tensor`: `indices.shape[-1] <= tensor.ndim`
 
-The last dimension of `indices` corresponds to indices into elements
-(if `indices.shape[-1] = shape.rank`) or slices
-(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of
-`shape`. `updates` is a tensor with shape
+if `indices.shape[-1] = tensor.rank` this Op indexes and updates scalar elements.
+if `indices.shape[-1] < tensor.rank` it indexes and updates slices of the input
+`tensor`.
 
-    indices.shape[:-1] + shape[indices.shape[-1]:]
+Each `update` has a rank of `tensor.rank - indices.shape[-1]`.
+The overall shape of `updates` is:
 
-The simplest form of scatter is to insert individual elements in a tensor by
-index. For example, say we want to insert 4 scattered elements in a rank-1
-tensor with 8 elements.
+```
+indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
+```
 
-<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd1.png" alt>
-</div>
-
-In Python, this scatter operation would look like this:
-
->>> indices = tf.constant([[4], [3], [1], [7]])
->>> updates = tf.constant([9, 10, 11, 12])
->>> tensor = tf.ones([8], dtype=tf.int32)
->>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
-tf.Tensor([ 1 11  1 10  9  1  1 12], shape=(8,), dtype=int32)
-
-We can also, insert entire slices of a higher rank tensor all at once. For
-example, if we wanted to insert two slices in the first dimension of a
-rank-3 tensor with two matrices of new values.
-
-In Python, this scatter operation would look like this:
-
->>> indices = tf.constant([[0], [2]])
->>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
-...                         [7, 7, 7, 7], [8, 8, 8, 8]],
-...                        [[5, 5, 5, 5], [6, 6, 6, 6],
-...                         [7, 7, 7, 7], [8, 8, 8, 8]]])
->>> tensor = tf.ones([4, 4, 4], dtype=tf.int32)
->>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy())
-[[[5 5 5 5]
-  [6 6 6 6]
-  [7 7 7 7]
-  [8 8 8 8]]
- [[1 1 1 1]
-  [1 1 1 1]
-  [1 1 1 1]
-  [1 1 1 1]]
- [[5 5 5 5]
-  [6 6 6 6]
-  [7 7 7 7]
-  [8 8 8 8]]
- [[1 1 1 1]
-  [1 1 1 1]
-  [1 1 1 1]
-  [1 1 1 1]]]
-
-Note that on CPU, if an out of bound index is found, an error is returned.
-On GPU, if an out of bound index is found, the index is ignored.
+For usage examples see the python [tf.tensor_scatter_nd_update](
+https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
 }];
 
 let arguments = (ins
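A short TensorFlow snippet (not part of the diff) exercising the rewritten semantics, assuming a CPU device where duplicate indices keep the last update:

```python
import tensorflow as tf

tensor = tf.ones([8], dtype=tf.int32)
indices = tf.constant([[1], [3], [3]])   # index 3 appears twice
updates = tf.constant([10, 20, 30])
result = tf.tensor_scatter_nd_update(tensor, indices, updates)
print(result.numpy())  # [ 1 10  1 30  1  1  1  1] -- last update (30) wins
```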
@@ -15080,7 +15049,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
   }];
 
   let arguments = (ins
-    Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
+    Arg<TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
     Arg<TF_I32OrI64Tensor, [{Array containing the starting indices of the slices we gather.}]>:$start_indices,
     Arg<TF_I32OrI64Tensor, [{slice_sizes[i] is the bounds for the slice on dimension i.}]>:$slice_sizes,
 
@@ -15089,7 +15058,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
   );
 
   let results = (outs
-    TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+    TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
   );
 
   TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
@@ -15457,7 +15426,7 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_ZerosLikeOp : TF_Op<"ZerosLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns a tensor of zeros with the same shape and type as x.";
 
   let arguments = (ins
@@ -684,12 +684,23 @@ body: A function that takes a list of tensors and returns another
 
     FlatSymbolRefAttr:$cond,
     FlatSymbolRefAttr:$body,
-    DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
     DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
 
     // Used to map StatelessWhile and While op defined in TensorFlow to a common
     // op.
-    BoolAttr:$is_stateless
+    BoolAttr:$is_stateless,
+
+    // In TensorFlow, While has a special behavior where if `output_shapes`
+    // attribute is not empty, those shapes are used in its shape function
+    // as result shapes instead of propagating operand shapes as result shapes.
+    // This allows for different result shapes from operand shapes. While these
+    // shapes are imported and set as a part of the result type, there is no
+    // indicator differentiating between having no output shapes compared to
+    // having all unranked shapes. Thus this attribute is set to determine
+    // which shape function behavior to use for this op, specifically
+    // propagating operand shapes as result shapes when this attribute is not
+    // set, or preserving result shapes as is when this attribute is set.
+    UnitAttr:$shape_invariant
   );
 
   let results = (outs
@@ -697,6 +708,7 @@ body: A function that takes a list of tensors and returns another
   );
 
   TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
+  TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
 
   let verifier = [{
     return Verify(*this);
@@ -752,12 +764,23 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
   let arguments = (ins
     Variadic<AnyTensor>:$input,
 
-    DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
    DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
 
     // Used to map StatelessWhile and While op defined in TensorFlow to a common
     // op.
-    BoolAttr:$is_stateless
+    BoolAttr:$is_stateless,
+
+    // In TensorFlow, While has a special behavior where if `output_shapes`
+    // attribute is not empty, those shapes are used in its shape function
+    // as result shapes instead of propagating operand shapes as result shapes.
+    // This allows for different result shapes from operand shapes. While these
+    // shapes are imported and set as a part of the result type, there is no
+    // indicator differentiating between having no output shapes compared to
+    // having all unranked shapes. Thus this attribute is set to determine
+    // which shape function behavior to use for this op, specifically
+    // propagating operand shapes as result shapes when this attribute is not
+    // set, or preserving result shapes as is when this attribute is set.
+    UnitAttr:$shape_invariant
   );
   let results = (outs Variadic<AnyTensor>:$output);
 
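The new `shape_invariant` attribute corresponds to the user-facing `shape_invariants` argument of `tf.while_loop`; a sketch (not part of the diff) of the behavior it has to preserve, where a loop-carried value changes shape across iterations:

```python
import tensorflow as tf

x = tf.constant([1.0])
# The shape of `v` changes across iterations, so the result shape cannot be
# propagated from the operand shape; shape_invariants opts into that.
result = tf.while_loop(
    cond=lambda v: tf.size(v) < 4,
    body=lambda v: [tf.concat([v, v], axis=0)],
    loop_vars=[x],
    shape_invariants=[tf.TensorShape([None])])
print(result[0].numpy())  # [1. 1. 1. 1.]
```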
@@ -2467,6 +2467,33 @@ LogicalResult MaxPoolOp::FoldOperandsPermutation(
       permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
 }
 
+LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) {
+  StringRef src_data_format = data_format();
+
+  auto perm = GetDataFormatPermutation(src_data_format, new_data_format);
+  if (perm.empty()) return failure();
+
+  // Update data_format attribute and result types.
+  if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this)))
+    return failure();
+
+  stridesAttr(ShuffleArrayAttr(strides(), perm));
+  explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2));
+  ksizeAttr(ShuffleArrayAttr(ksize(), perm));
+
+  return success();
+}
+
+StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) {
+  // Keep current data format if no GPUs are available or if explicit placement
+  // does not allow to use GPU for this operation.
+  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
+    return data_format();
+
+  // Defaults to NCHW.
+  return "NCHW";
+}
+
 //===----------------------------------------------------------------------===//
 // MaxPoolGradOp
 //===----------------------------------------------------------------------===//
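A plain-Python sketch (not part of the diff) of the attribute shuffling `UpdateDataFormat` performs; the NHWC-to-NCHW permutation `[0, 3, 1, 2]` matches what `GetDataFormatPermutation` is expected to return for these formats:

```python
# NHWC -> NCHW moves the channel axis in front of the spatial axes.
perm = [0, 3, 1, 2]
ksize_nhwc = [1, 3, 3, 1]
strides_nhwc = [1, 2, 2, 1]
ksize_nchw = [ksize_nhwc[p] for p in perm]      # [1, 1, 3, 3]
strides_nchw = [strides_nhwc[p] for p in perm]  # [1, 1, 2, 2]
assert ksize_nchw == [1, 1, 3, 3] and strides_nchw == [1, 1, 2, 2]
```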
@@ -41,3 +41,34 @@ func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32
   // CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32>
   // CHECK: %[[V1]] : tensor<3x5x7xf32>
 }
+
+// CHECK-LABEL: @broadcast_eq
+func @broadcast_eq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> {
+  %cst = constant dense<[5, 7]> : tensor<2xi32>
+  %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
+  %1 = "tf.Equal"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
+  return %1 : tensor<5x7xi1>
+  // CHECK: %[[V0:.*]] = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
+  // CHECK: %[[V0]] : tensor<5x7xi1>
+}
+
+// CHECK-LABEL: @broadcast_neq
+func @broadcast_neq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> {
+  %cst = constant dense<[5, 7]> : tensor<2xi32>
+  %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
+  %1 = "tf.NotEqual"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
+  return %1 : tensor<5x7xi1>
+  // CHECK: %[[V0:.*]] = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
+  // CHECK: %[[V0]] : tensor<5x7xi1>
+}
+
+// CHECK-LABEL: @broadcast_both_operand
+func @broadcast_both_operand(%arg0: tensor<7xf32>, %arg1: tensor<5x1xf32>) -> tensor<5x7xf32> {
+  %cst = constant dense<[5, 7]> : tensor<2xi64>
+  %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi64>) -> tensor<5x7xf32>
+  %1 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5x1xf32>, tensor<2xi64>) -> tensor<5x7xf32>
+  %2 = "tf.Add"(%0, %1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
+  return %2 : tensor<5x7xf32>
+  // CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
+  // CHECK: %[[V0]] : tensor<5x7xf32>
+}
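A TensorFlow check (not part of the test file) that the folding tested above is sound: elementwise ops such as `tf.Equal` broadcast implicitly, so the explicit `tf.BroadcastTo` is redundant:

```python
import tensorflow as tf

x = tf.random.normal([5, 7])
y = tf.random.normal([7])
explicit = tf.equal(x, tf.broadcast_to(y, [5, 7]))
implicit = tf.equal(x, y)  # tf.Equal broadcasts its operands itself
assert bool(tf.reduce_all(explicit == implicit))
```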
@@ -1,4 +1,4 @@
-# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
+# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1,WhileWithOutputShapes:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
 
 # Verify that TensorFlow While and StatelessWhile ops are mapped to the
 # composite While op in MLIR with is_stateless attribute set accordingly to
@@ -6,6 +6,7 @@
 
 # CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile")
 # CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile")
+# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} shape_invariant{{.*}} -> (tensor<i32>, tensor<*xf32>) loc("WhileWithOutputShapes")
 
 node {
   name: "StatefulWhile"
@@ -73,6 +74,51 @@ node {
   experimental_debug_info {
   }
 }
+node {
+  name: "WhileWithOutputShapes"
+  op: "While"
+  input: "iter"
+  input: "val"
+  attr {
+    key: "T"
+    value {
+      list {
+        type: DT_INT32
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "body"
+    value {
+      func {
+        name: "body"
+      }
+    }
+  }
+  attr {
+    key: "cond"
+    value {
+      func {
+        name: "cond"
+      }
+    }
+  }
+  attr {
+    key: "output_shapes"
+    value {
+      list {
+        shape {
+        }
+        shape {
+          unknown_rank: true
+        }
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
 node {
   name: "main"
   op: "_Retval"
@@ -107,6 +153,23 @@ node {
       }
     }
   }
+node {
+  name: "main2"
+  op: "_Retval"
+  input: "WhileWithOutputShapes:1"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "index"
+    value {
+      i: 2
+    }
+  }
+}
 node {
   name: "iter"
   op: "Placeholder"
@@ -82,3 +82,20 @@ func @bias_add_nchw(%arg0: tensor<1x256x150x150xf32>, %arg1: tensor<256xf32>) ->
   %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW", device = ""} : (tensor<1x256x150x150xf32>, tensor<256xf32>) -> tensor<1x256x150x150xf32>
   return %0 : tensor<1x256x150x150xf32>
 }
+
+// CHECK-LABEL: maxpool_nchw
+func @maxpool_nchw(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> {
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK: %[[R0:.*]] = "tf.Transpose"(%arg0, %[[CST]])
+  // CHECK: %[[R1:.*]] = "tf.MaxPool"(%[[R0]]) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}
+  // CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK: "tf.Transpose"(%[[R1]], %[[CST_0]])
+  %0 = "tf.MaxPool"(%arg0)
+       {
+         data_format = "NCHW",
+         ksize = [1, 1, 3, 3],
+         padding = "SAME",
+         strides = [1, 1, 2, 2]
+       } : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
+  return %0 : tensor<1x64x56x56xf32>
+}
@@ -1703,3 +1703,110 @@ func @convert_iota_3d() -> tensor<5x7x9xi32> {
   return %0 : tensor<5x7x9xi32>
 }
+
+// CHECK-LABEL: func @convert_avgpool_valid(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK: return %[[VAL_1]] : tensor<4x7x7x8xf32>
+// CHECK: }
+func @convert_avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+  %0 = mhlo.constant dense<0.0> : tensor<f32>
+  %1 = mhlo.constant dense<9.0> : tensor<4x7x7x8xf32>
+  %2 = "mhlo.reduce_window"(%arg0, %0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %5 = mhlo.add %arg1, %arg2 : tensor<f32>
+    "mhlo.return"(%5) : (tensor<f32>) -> ()
+  }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<0> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %3 = mhlo.divide %2, %1 : tensor<4x7x7x8xf32>
+  return %3 : tensor<4x7x7x8xf32>
+}
+
+// CHECK-LABEL: func @convert_avgpool_valid_rw(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK: return %[[VAL_1]] : tensor<4x7x7x8xf32>
+// CHECK: }
+func @convert_avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
+  %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
+  %1 = mhlo.constant dense<0.0> : tensor<f32>
+  %2 = "mhlo.reduce_window"(%arg0, %1) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+    "mhlo.return"(%6) : (tensor<f32>) -> ()
+  }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %3 = "mhlo.reduce_window"(%0, %1) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+    "mhlo.return"(%6) : (tensor<f32>) -> ()
+  }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
+  %4 = mhlo.divide %2, %3 : tensor<4x7x7x8xf32>
+  return %4 : tensor<4x7x7x8xf32>
+}
+
+// CHECK-LABEL: func @convert_avgpool_valid_3d(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
+// CHECK: %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) {data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]} : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
+// CHECK: return %[[VAL_1]] : tensor<4x7x7x7x8xf32>
+// CHECK: }
+func @convert_avgpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
+  %0 = mhlo.constant dense<0.0> : tensor<f32>
+  %1 = mhlo.constant dense<27.0> : tensor<4x7x7x7x8xf32>
+  %2 = "mhlo.reduce_window"(%arg0, %0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %5 = mhlo.add %arg1, %arg2 : tensor<f32>
+    "mhlo.return"(%5) : (tensor<f32>) -> ()
+  }) {
+    base_dilations = dense<1> : tensor<5xi64>,
+    padding = dense<0> : tensor<5x2xi64>,
+    window_dilations = dense<1> : tensor<5xi64>,
+    window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
+    window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<4x16x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x7x8xf32>
+  %3 = mhlo.divide %2, %1 : tensor<4x7x7x7x8xf32>
+  return %3 : tensor<4x7x7x7x8xf32>
+}
+
+// CHECK-LABEL: func @convert_avgpool_same(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
+// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+// CHECK: return %[[VAL_1]] : tensor<4x8x8x8xf32>
+// CHECK: }
+func @convert_avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
+  %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
+  %1 = mhlo.constant dense<0.0> : tensor<f32>
+  %2 = "mhlo.reduce_window"(%arg0, %1) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+    "mhlo.return"(%6) : (tensor<f32>) -> ()
+  }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+  %3 = "mhlo.reduce_window"(%0, %1) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %6 = mhlo.add %arg1, %arg2 : tensor<f32>
+    "mhlo.return"(%6) : (tensor<f32>) -> ()
+  }) {
+    base_dilations = dense<1> : tensor<4xi64>,
+    padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+    window_dilations = dense<1> : tensor<4xi64>,
+    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
+  %4 = mhlo.divide %2, %3 : tensor<4x8x8x8xf32>
+  return %4 : tensor<4x8x8x8xf32>
+}
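A NumPy sketch (not part of the test file) of the equivalence these tests encode: a sum `reduce_window` divided by the window element count is average pooling, and with SAME-style padding the divisor must itself be a second `reduce_window` over ones, matching the two-`reduce_window` pattern above. The helper `reduce_window_sum` is a simplified 1-D stand-in:

```python
import numpy as np

def reduce_window_sum(x, size, stride, pad):
    """1-D stand-in for mhlo.reduce_window with an add reducer."""
    x = np.pad(x, pad)  # zero-pad, like init value 0.0
    return np.array([x[i:i + size].sum()
                     for i in range(0, len(x) - size + 1, stride)])

x = np.arange(10.0)
# VALID: every window is full, so dividing by the window size suffices.
avg_valid = reduce_window_sum(x, 3, 2, (0, 0)) / 3.0
# SAME-style padding: edge windows are partial, so divide by a matching
# reduce_window over ones that counts the real elements per window.
counts = reduce_window_sum(np.ones_like(x), 3, 2, (0, 1))
avg_same = reduce_window_sum(x, 3, 2, (0, 1)) / counts
assert np.allclose(avg_same[-1], x[8:10].mean())  # edge window averages 2 values
```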
@@ -244,9 +244,9 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
   // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
   // CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
   // CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64}
-  // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1)
+  // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.AddV2"([[PADDINGS]]#0, [[PADDINGS]]#1)
   // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
-  // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
+  // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.AddV2"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
   // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
   // CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:2 = "tf.Split"([[ZERO_I32]], %arg1)
   // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0)
@@ -338,10 +338,10 @@ func @fake_quant_with_min_max_args(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK-DAG: [[VAL5:%.+]] = "tf.ClipByValue"(%arg0, [[VAL2]], [[VAL1]])
   // CHECK-DAG: [[VAL6:%.+]] = "tf.Sub"([[VAL5]], [[VAL2]])
   // CHECK-DAG: [[VAL7:%.+]] = "tf.Mul"([[VAL6]], [[VAL0]])
-  // CHECK-DAG: [[VAL8:%.+]] = "tf.Add"([[VAL7]], [[VAL4]])
+  // CHECK-DAG: [[VAL8:%.+]] = "tf.AddV2"([[VAL7]], [[VAL4]])
   // CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]])
   // CHECK-DAG: [[VAL10:%.+]] = "tf.Mul"([[VAL9]], [[VAL3]])
-  // CHECK-DAG: [[VAL11:%.+]] = "tf.Add"([[VAL10]], [[VAL2]])
+  // CHECK-DAG: [[VAL11:%.+]] = "tf.AddV2"([[VAL10]], [[VAL2]])
   %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {max = 1.0 : f32, min = -1.0 : f32, narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32>
 
   // CHECK: return [[VAL11]]
@@ -361,7 +361,7 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>,
   // CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]])
   // CHECK-DAG: [[VAL10:%.+]] = "tf.Sub"([[VAL8]], [[VAL9]])
   // CHECK-DAG: [[VAL11:%.+]] = "tf.Less"([[VAL10]], [[VAL3]])
-  // CHECK-DAG: [[VAL12:%.+]] = "tf.Add"([[VAL2]], [[VAL9]])
+  // CHECK-DAG: [[VAL12:%.+]] = "tf.AddV2"([[VAL9]], [[VAL2]])
   // CHECK-DAG: [[VAL13:%.+]] = "tf.Select"([[VAL11]], [[VAL9]], [[VAL12]])
   // CHECK-DAG: [[VAL14:%.+]] = "tf.ClipByValue"([[VAL13]], [[VAL0]], [[VAL1]]) :
   // CHECK-DAG: [[VAL15:%.+]] = "tf.Sub"([[VAL0]], [[VAL14]])
@@ -371,10 +371,10 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>,
   // CHECK-DAG: [[VAL19:%.+]] = "tf.ClipByValue"(%arg0, [[VAL17]], [[VAL18]])
   // CHECK-DAG: [[VAL20:%.+]] = "tf.Sub"([[VAL19]], [[VAL17]])
   // CHECK-DAG: [[VAL21:%.+]] = "tf.Mul"([[VAL20]], [[VAL6]])
-  // CHECK-DAG: [[VAL22:%.+]] = "tf.Add"([[VAL21]], [[VAL3]])
+  // CHECK-DAG: [[VAL22:%.+]] = "tf.AddV2"([[VAL21]], [[VAL3]])
   // CHECK-DAG: [[VAL23:%.+]] = "tf.Floor"([[VAL22]])
   // CHECK-DAG: [[VAL24:%.+]] = "tf.Mul"([[VAL23]], [[VAL5]])
-  // CHECK-DAG: [[VAL25:%.+]] = "tf.Add"([[VAL24]], [[VAL17]])
+  // CHECK-DAG: [[VAL25:%.+]] = "tf.AddV2"([[VAL24]], [[VAL17]])
   %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
 
   // CHECK: return [[VAL25]]
@ -746,7 +746,7 @@ func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
|||||||
// CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
|
// CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
|
||||||
// CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
|
// CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
|
||||||
// CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
// CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||||
// CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]])
|
// CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]])
|
||||||
// CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
|
// CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
|
||||||
%0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
%0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
|
||||||
@ -761,7 +761,7 @@ func @round_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
|||||||
// CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
|
// CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
|
||||||
// CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
|
// CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
|
||||||
// CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
// CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||||
// CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]])
|
// CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]])
|
||||||
// CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
|
// CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
|
||||||
%0 = "tf.Round"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
%0 = "tf.Round"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
|
||||||
|
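
Taken together, the `round` CHECK-DAG lines above describe the whole decomposition these tests exercise. A minimal sketch of the lowered IR, assembled from those checks (value names are illustrative):

```mlir
%floor  = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%frac   = "tf.Sub"(%arg0, %floor) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
%half   = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
%cmp    = "tf.Less"(%frac, %half) : (tensor<2xf32>, tensor<f32>) -> tensor<2xi1>
%one    = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%add    = "tf.AddV2"(%floor, %one) : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
%result = "tf.Select"(%cmp, %floor, %add) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
```

The only functional change in this test file is the switch from `tf.Add` to the preferred `tf.AddV2`; the `lower_tf.cc` and `lower_tf.td` hunks later in this diff make the matching substitution in the pattern definitions.
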
@@ -1,12 +1,13 @@
 // RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s

-func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
-  %0:2 = tf_executor.graph {
-    %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
-    %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
-    tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32>
+func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) {
+  %0:3 = tf_executor.graph {
+    %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
+    %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
+    %outputs_6:2, %control_7 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, shape_invariant} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("WhileWithOutputShapes")
+    tf_executor.fetch %outputs_2#1, %outputs_4#1, %outputs_6#1 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
   }
-  return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32>
+  return %0#0, %0#1, %0#2 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
 }

 func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
@@ -36,6 +37,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
 // CHECK-NOT: name:
 // CHECK: op: "While"
 // CHECK-NOT: is_stateless
+// CHECK-NOT: shape_invariant
 // CHECK: attr {
 // CHECK: key: "output_shapes"
 // CHECK: value {
@@ -54,6 +56,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
 // CHECK-NOT: name:
 // CHECK: op: "StatelessWhile"
 // CHECK-NOT: is_stateless
+// CHECK-NOT: shape_invariant
 // CHECK: attr {
 // CHECK: key: "output_shapes"
 // CHECK: value {
@@ -67,3 +70,20 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
 // CHECK: }
 // CHECK: }

+// CHECK: name: "WhileWithOutputShapes"
+// CHECK-NOT: name:
+// CHECK: op: "While"
+// CHECK-NOT: is_stateless
+// CHECK-NOT: shape_invariant
+// CHECK: attr {
+// CHECK: key: "output_shapes"
+// CHECK: value {
+// CHECK: list {
+// CHECK: shape {
+// CHECK: dim {
+// CHECK: size: 5
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
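
The new `WhileWithOutputShapes` case pins down the intended round trip for shape-invariant loops: the functional `tf.While` now carries a unit `shape_invariant` attribute instead of an explicit `output_shapes` array, the exporter drops `shape_invariant` from the emitted node (hence the `CHECK-NOT` lines), and the GraphDef still ends up with a populated `output_shapes` attribute, evidently reconstructed from the op's result types, as the `size: 5` checks verify.
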
@@ -0,0 +1,153 @@
+// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-outside-compiled-to-host-launch | FILECHECK_OPTS="" FileCheck %s
+
+// expected-error@+1 {{'module' op bad 'tf.devices' attribute at index 0, not a string}}
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = [1]} {
+  // Tests that a non-string `tf.devices` attribute value results in an error.
+  func @invalid_device_attribute() -> tensor<?xi32> {
+    %0 = "tf_device.cluster"() ( {
+      %1 = "tf.A"() : () -> tensor<?xi32>
+      %2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
+  // Tests that an empty `_xla_outside_compilation` attribute value results in an error.
+  func @empty_outside_compilation_attribute() -> tensor<?xi32> {
+    %0 = "tf_device.cluster"() ( {
+      %1 = "tf.A"() : () -> tensor<?xi32>
+      // expected-error@+1 {{'tf.B' op requires non empty '_xla_outside_compilation' string attribute}}
+      %2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
+
+  // Tests that a TPU cluster with no outside compilation does not generate a launch op.
+
+  // CHECK-LABEL: func @no_outside_compilation
+  // CHECK-NOT: "tf_device.launch"
+  func @no_outside_compilation() -> tensor<?xi32> {
+    %0 = "tf_device.cluster"() ( {
+      %1 = "tf.A"() : () -> tensor<?xi32>
+      %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+
+  // Tests the launch wrap of a single outside compiled cluster with no input or output dependencies.
+
+  // CHECK-LABEL: func @nodep_single_outside_compilation
+  func @nodep_single_outside_compilation() -> () {
+    // CHECK: "tf.A"
+    // CHECK: "tf_device.launch"
+    // CHECK-NEXT: "tf.B"
+    // CHECK-NOT: _xla_outside_compilation
+    // CHECK-NEXT: tf_device.return
+    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
+    // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
+    "tf_device.cluster"() ( {
+      "tf.A"() : () -> ()
+      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+      "tf.C"() : () -> ()
+      tf_device.return
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
+    return
+  }
+
+  // Tests the launch wrap of a single outside compiled cluster with data parallelism.
+
+  // CHECK-LABEL: func @single_outside_compilation_with_replicate
+  func @single_outside_compilation_with_replicate(%arg0: tensor<?xi32>) -> () {
+    // CHECK: "tf.A"
+    // CHECK: tf_device.replicate
+    // CHECK-NEXT: "tf_device.cluster"
+    // CHECK-NEXT: "tf.B"
+    // CHECK-NEXT: "tf_device.launch"
+    // CHECK-NEXT: "tf.C"
+    // CHECK-NOT: _xla_outside_compilation
+    // CHECK: tf_device.return
+    // CHECK-NEXT: device = "TPU_REPLICATED_HOST"
+    // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      "tf_device.cluster"() ( {
+        "tf.B"() : () -> ()
+        "tf.C"(%ri_0) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
+        "tf.D"() : () -> ()
+        tf_device.return
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
+      tf_device.return
+    }
+    return
+  }
+
+  // Tests launch wrap of a single outside compiled cluster with input/output.
+
+  // CHECK-LABEL: func @single_outside_compilation_input_output
+  func @single_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    // CHECK: "tf_device.cluster"
+    // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    // CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
+    // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
+    // CHECK: tf_device.return %[[B_OUTPUT]]
+    // CHECK: "tf.C"(%[[LAUNCH_OUTPUT]])
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      %2 = "tf_device.cluster"() ( {
+        %3 = "tf.A"() : () -> (tensor<?xi32>)
+        %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32>
+        %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
+        tf_device.return %5 : tensor<?xi32>
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }
+
+    return %1#0 : tensor<?xi32>
+  }
+
+  // Tests launch wrap of multiple outside compiled clusters with input/output.
+
+  // CHECK-LABEL: func @multiple_outside_compilation_input_output
+  func @multiple_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    // CHECK: "tf_device.cluster"
+    // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    // CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
+    // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
+    // CHECK: tf_device.return %[[B_OUTPUT]]
+    // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[LAUNCH_OUTPUT]])
+    // CHECK-NEXT: %[[LAUNCH_OUTPUT2:[0-9]*]] = "tf_device.launch"
+    // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]])
+    // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[D_OUTPUT]])
+    // CHECK: tf_device.return %[[E_OUTPUT]]
+    // CHECK: "tf.F"(%[[LAUNCH_OUTPUT2]])
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      %2 = "tf_device.cluster"() ( {
+        %3 = "tf.A"() : () -> (tensor<?xi32>)
+        %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32>
+        %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
+        %6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32>
+        %7 = "tf.E"(%6) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32>
+        %8 = "tf.F"(%7) : (tensor<?xi32>) -> tensor<?xi32>
+        tf_device.return %8 : tensor<?xi32>
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }
+
+    return %1#0 : tensor<?xi32>
+  }
+}
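
Two details of the expected device assignment are worth calling out: a launch wrapping a non-replicated cluster is pinned to the concrete host device (`/job:worker/replica:0/task:0/device:CPU:0` in the `nodep` test), while inside `tf_device.replicate` the launch gets the `TPU_REPLICATED_HOST` alias, presumably so each replica resolves its own host later. Note also that the per-op `_xla_outside_compilation` markers are gone after the rewrite (the `CHECK-NOT` lines), since the launch op itself now carries the placement.
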
@@ -22,6 +22,7 @@ limitations under the License.
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
@@ -38,6 +39,12 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern {

   LogicalResult matchAndRewrite(Operation* op,
                                 PatternRewriter& rewriter) const override;
+
+ private:
+  template <typename Op>
+  LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
+
+  LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter) const;
 };

 class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
@@ -47,7 +54,27 @@ class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {

 LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
-  if (!op->hasTrait<OpTrait::ResultsBroadcastableShape>()) return failure();
+  if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
+    return RewriteOp(op, rewriter);
+
+  // tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
+  // incompatible_shape_error is `true` (which is also checked by the verifier).
+  if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
+  if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
+
+  return failure();
+}
+
+template <typename Op>
+LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto eq_op = llvm::dyn_cast_or_null<Op>(op);
+  if (eq_op && eq_op.incompatible_shape_error()) return RewriteOp(op, rewriter);
+  return failure();
+}
+
+LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
+    Operation* op, PatternRewriter& rewriter) const {
   if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
     return failure();

@@ -56,6 +83,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
       op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
   if (!result_type || !result_type.hasStaticShape()) return failure();

+  bool changed = false;
   for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
     // Check that the i'th operand is a broadcast.
     auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
@@ -89,10 +117,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
     // Update the operand of the op to be the operand of the broadcast.
     rewriter.updateRootInPlace(
         op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
-    return success();
+    changed = true;
   }
-  return failure();
+  return success(changed);
 }

 void BroadcastFoldPass::runOnFunction() {
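
To make the fold concrete, here is a hedged sketch of the kind of IR the updated pass handles (ops and shapes are illustrative, not copied from the test suite): an explicit `tf.BroadcastTo` feeding an op with the `ResultsBroadcastableShape` trait is redundant when the op's result already has the fully broadcasted static shape, so the operand is rewired to the broadcast's input.

```mlir
// Before: %arg0 is explicitly broadcast from tensor<5xf32> to tensor<5x5xf32>.
%shape = "tf.Const"() {value = dense<5> : tensor<2xi64>} : () -> tensor<2xi64>
%bcast = "tf.BroadcastTo"(%arg0, %shape) : (tensor<5xf32>, tensor<2xi64>) -> tensor<5x5xf32>
%mul = "tf.Mul"(%bcast, %arg1) : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>

// After BroadcastFoldPass: tf.Mul broadcasts implicitly.
%mul = "tf.Mul"(%arg0, %arg1) : (tensor<5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
```

The new `changed` flag lets a single application fold broadcasts on both operands; previously the pattern returned after the first operand and had to be re-driven. With `RewriteEqOp`, `tf.Equal` and `tf.NotEqual` get the same treatment when `incompatible_shape_error = true`, the only configuration in which their results are guaranteed broadcast-compatible.
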
@@ -112,8 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
 LogicalResult ConvertWhileOp(WhileOp while_op) {
   auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
       while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
-      while_op.output_shapes(), while_op.parallel_iterations(),
-      while_op.is_stateless());
+      while_op.parallel_iterations(), while_op.is_stateless(),
+      while_op.shape_invariant());
   CopyDeviceAndUnderscoredAttributes(while_op, while_region);

   YieldOp cond_yield =
@@ -21,15 +21,18 @@ limitations under the License.
 #include <numeric>
 #include <vector>

+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
@@ -44,6 +47,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/core/framework/kernel_shape_util.h"
+#include "tensorflow/core/lib/math/math_util.h"

 namespace mlir {
 namespace TF {
@@ -479,18 +483,16 @@ template <typename ReductionOp>
 LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
   Block &body = function.front();
   if (body.getNumArguments() != 2) return failure();
-  if (body.getOperations().size() != 2) return failure();
-
-  ReductionOp reduce_op = dyn_cast<ReductionOp>(body.front());
-  if (!reduce_op) return failure();
-  if (reduce_op.lhs() != body.getArgument(0) ||
-      reduce_op.rhs() != body.getArgument(1))
-    return failure();

   mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
   if (!return_op) return failure();
-  if (return_op.getNumOperands() != 1 ||
-      return_op.results().front() != reduce_op)
+  if (return_op.getNumOperands() != 1) return failure();
+
+  ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
+      return_op.getOperands().front().getDefiningOp());
+  if (!reduce_op) return failure();
+  if (reduce_op.lhs() != body.getArgument(0) ||
+      reduce_op.rhs() != body.getArgument(1))
     return failure();

   return success();
@@ -654,6 +656,190 @@ class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
   }
 };

+// Maps the following representations of AvgPool in MHLO into a tf.AvgPool{3D}
+// operation when they cleanly map to 2D or 3D average pool with VALID or SAME
+// padding:
+// * div(reduce_sum_window(x), constant(sizeof(window)))
+// * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
+class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
+ public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::DivOp div_op, ArrayRef<Value> args,
+      ConversionPatternRewriter &rewriter) const final {
+    auto rw =
+        dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
+    if (!rw) return failure();
+
+    // Check that the reduce-window is a sum-reduce-window.
+    if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
+      return failure();
+
+    // Check that this is a floating point reduce window with a rank of 4 or 5.
+    RankedTensorType rw_type = rw.getType().dyn_cast<RankedTensorType>();
+    if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
+        rw_type.getRank() <= 3 || rw_type.getRank() > 5)
+      return failure();
+
+    // Check that the Div op doesn't do broadcasting on the output of the
+    // reduce window.
+    if (div_op.getType() != rw.getType()) return failure();
+
+    // tf.avg_pool needs at least 3 dimensions (batch, spatial, channel).
+    const uint64_t rank = rw.window_dimensions().size();
+    if (rank <= 2) return failure();
+
+    // If the init value isn't zero then it can't be an average pool.
+    if (!isFloatZero(rw.init_value())) return failure();
+
+    llvm::SmallVector<int64_t, 5> window_strides;
+    if (rw.window_strides().hasValue()) {
+      window_strides.insert(window_strides.end(),
+                            rw.window_strides()->getValues<int64_t>().begin(),
+                            rw.window_strides()->getValues<int64_t>().end());
+    } else {
+      window_strides.resize(rank, 1);
+    }
+
+    llvm::SmallVector<int64_t, 10> padding;
+    if (rw.padding().hasValue()) {
+      padding.insert(padding.begin(),
+                     rw.padding()->getValues<int64_t>().begin(),
+                     rw.padding()->getValues<int64_t>().end());
+    } else {
+      padding.resize(2 * rank, 0);
+    }
+
+    // Check that we don't do any reduction along the batch (first) and channel
+    // (last) dimensions.
+    const uint64_t batch_dim = 0;
+    const uint64_t channel_dim = rank - 1;
+    if (rw.window_dimensions().getValue<int64_t>({batch_dim}) != 1 ||
+        rw.window_dimensions().getValue<int64_t>({channel_dim}) != 1 ||
+        window_strides[batch_dim] != 1 || window_strides[channel_dim] != 1 ||
+        padding[2 * batch_dim] != 0 || padding[2 * batch_dim + 1] != 0 ||
+        padding[2 * channel_dim] != 0 || padding[2 * channel_dim + 1] != 0)
+      return failure();
+
+    if (rw.window_dilations().hasValue() &&
+        !(rw.window_dilations()->isSplat() &&
+          rw.window_dilations()->getSplatValue<APInt>() == 1))
+      return failure();
+
+    if (rw.base_dilations().hasValue() &&
+        !(rw.base_dilations()->isSplat() &&
+          rw.base_dilations()->getSplatValue<APInt>() == 1))
+      return failure();
+
+    DenseFPElementsAttr divisor;
+    if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
+      // If the divisor is a constant then check that it matches with the
+      // number of elements inside the window, which is required for a VALID
+      // AvgPool.
+      if (!divisor.isSplat()) return failure();
+      int64_t window_size = 1;
+      for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
+        window_size *= w;
+      }
+      if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
+        return failure();
+
+      // Check that we have no padding.
+      if (!llvm::all_of(padding, [](int64_t i) { return i == 0; }))
+        return failure();
+
+      return replaceWithAvgPool(
+          div_op, rw.operand(),
+          llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
+          window_strides, "VALID", rewriter);
+    }
+
+    auto rw_rhs =
+        dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
+    if (rw_rhs) {
+      // Check that RHS is a sum-reduce-window.
+      if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
+        return failure();
+
+      // Check that the RHS is a reduce_window over a constant 1 input with 0
+      // as the init value.
+      DenseFPElementsAttr rhs_input;
+      if (!isFloatZero(rw_rhs.init_value()) ||
+          !matchPattern(rw_rhs.operand(), m_Constant(&rhs_input)) ||
+          !rhs_input.isSplat() ||
+          !rhs_input.getSplatValue<APFloat>().isExactlyValue(1.0))
+        return failure();
+
+      // Check that the two reduce windows have the same window configuration.
+      if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
+          rw.window_strides() != rw_rhs.window_strides() ||
+          rw.window_dilations() != rw_rhs.window_dilations() ||
+          rw.base_dilations() != rw_rhs.base_dilations() ||
+          rw.padding() != rw_rhs.padding())
+        return failure();
+
+      if (llvm::all_of(padding, [](int64_t i) { return i == 0; }))
+        return replaceWithAvgPool(
+            div_op, rw.operand(),
+            llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
+            window_strides, "VALID", rewriter);
+
+      RankedTensorType input_type =
+          rw.operand().getType().dyn_cast<RankedTensorType>();
+      RankedTensorType output_type = rw.getType().dyn_cast<RankedTensorType>();
+      if (!input_type || !output_type) return failure();
+
+      // Check that the individual padding values correspond to SAME padding
+      // from TensorFlow.
+      for (uint64_t i = 1; i < rank - 1; ++i) {
+        int64_t padding_size =
+            (output_type.getShape()[i] - 1) * window_strides[i] +
+            rw.window_dimensions().getValue<int64_t>({i}) -
+            input_type.getShape()[i];
+        if (padding[2 * i] !=
+                tensorflow::MathUtil::FloorOfRatio(padding_size, int64_t(2)) ||
+            padding[2 * i + 1] !=
+                tensorflow::MathUtil::CeilOfRatio(padding_size, int64_t(2)))
+          return failure();
+      }
+      return replaceWithAvgPool(
+          div_op, rw.operand(),
+          llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
+          window_strides, "SAME", rewriter);
+    }
+    return failure();
+  }
+
+ private:
+  bool isFloatZero(Value value) const {
+    DenseFPElementsAttr initial_value;
+    return matchPattern(value, m_Constant(&initial_value)) &&
+           initial_value.getNumElements() == 1 &&
+           initial_value.getValue<APFloat>({}).isZero();
+  }
+
+  LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
+                                   llvm::ArrayRef<int64_t> ksizes,
+                                   llvm::ArrayRef<int64_t> kstrides,
+                                   llvm::StringRef padding,
+                                   ConversionPatternRewriter &rewriter) const {
+    if (ksizes.size() == 4) {
+      rewriter.replaceOpWithNewOp<AvgPoolOp>(
+          op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
+          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
+          rewriter.getStringAttr("NHWC"));
+      return success();
+    } else if (ksizes.size() == 5) {
+      rewriter.replaceOpWithNewOp<AvgPool3DOp>(
+          op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
+          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
+          rewriter.getStringAttr("NDHWC"));
+      return success();
+    }
+    return failure();
+  }
+};
+
 class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<TF::TensorFlowDialect>();
@@ -794,10 +980,10 @@ static PassRegistration<LegalizeHloToTf> pass(

 void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
                                      MLIRContext *context) {
+  patterns->insert<ConvertAvgPoolOp, ConvertConvOp, ConvertSliceOp,
+                   ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
+                   ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context);
   populateWithGenerated(context, *patterns);
-  patterns->insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
-                   ConvertReduceOpToTfMin, ConvertReduceOpToTfSum,
-                   ConvertIotaOpToTfRange>(context);
 }

 std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
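
To see what the first bullet of the new `ConvertAvgPoolOp` comment matches, here is a hedged MLIR sketch (shapes, value names, and attribute spellings are illustrative, not taken from this commit's tests): a 2x2, stride-2 sum reduce-window over an NHWC tensor divided by the constant window size maps to a VALID `tf.AvgPool`.

```mlir
%zero = mhlo.constant dense<0.000000e+00> : tensor<f32>
// Sum-reduce-window: the region is the shape MatchBinaryReduceFunction<mhlo::AddOp>
// accepts (two block arguments, an add, and a return of the add's result).
%sum = "mhlo.reduce_window"(%input, %zero) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>):
  %s = mhlo.add %a, %b : tensor<f32>
  "mhlo.return"(%s) : (tensor<f32>) -> ()
}) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}
   : (tensor<1x8x8x1xf32>, tensor<f32>) -> tensor<1x4x4x1xf32>
// Divisor 4 = the number of elements in the 2x2 window, so this is an average.
%cst = mhlo.constant dense<4.000000e+00> : tensor<1x4x4x1xf32>
%avg = mhlo.divide %sum, %cst : tensor<1x4x4x1xf32>
```

The pattern would rewrite `%avg` into `"tf.AvgPool"(%input) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 2, 2, 1]}`. The second bullet, a division by a `reduce_window` over a constant-1 tensor, is the form SAME average pooling lowers to (each output divided by the count of in-bounds window elements), handled by the SAME-padding arithmetic at the end of the match.
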
@@ -332,12 +332,13 @@ class LowerDynamicStitchOp : public RewritePattern {
 class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
  public:
   explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
-      : RewritePattern(FakeQuantWithMinMaxVarsOp::getOperationName(),
-                       {SubOp::getOperationName(), ConstOp::getOperationName(),
-                        MulOp::getOperationName(), FloorOp::getOperationName(),
-                        ClipByValueOp::getOperationName(),
-                        DivOp::getOperationName(), RoundOp::getOperationName()},
-                       1, context) {}
+      : RewritePattern(
+            FakeQuantWithMinMaxVarsOp::getOperationName(),
+            {AddV2Op::getOperationName(), SubOp::getOperationName(),
+             ConstOp::getOperationName(), MulOp::getOperationName(),
+             FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
+             DivOp::getOperationName(), RoundOp::getOperationName()},
+            1, context) {}

   LogicalResult matchAndRewrite(Operation *src_op,
                                 PatternRewriter &rewriter) const override {
@@ -419,8 +420,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
         op.getLoc(),
         DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));

-    quantized_input = rewriter.create<AddOp>(op.getLoc(), input_ty,
-                                             quantized_input, half_val);
+    quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
+                                               quantized_input, half_val);

     quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);

@@ -428,8 +429,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
     Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
                                           quantized_input, quant_to_float);

-    output =
-        rewriter.create<AddOp>(op.getLoc(), input_ty, output, nudged_float_min);
+    output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
+                                      nudged_float_min);

     rewriter.replaceOp(op, {output});
     return success();
@@ -811,7 +812,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
                        CastOp::getOperationName(),
                        ConstOp::getOperationName(),
                        ConcatV2Op::getOperationName(),
-                       AddOp::getOperationName(),
+                       AddV2Op::getOperationName(),
                        PadOp::getOperationName(),
                        SplitOp::getOperationName(),
                        UnpackOp::getOperationName(),
@@ -907,8 +908,8 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
     auto paddings_split = rewriter.create<UnpackOp>(
         loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
         rewriter.getI64IntegerAttr(1));
-    auto paddings_sum = rewriter.create<AddOp>(loc, paddings_split.getResult(0),
-                                               paddings_split.getResult(1));
+    auto paddings_sum = rewriter.create<AddV2Op>(
+        loc, paddings_split.getResult(0), paddings_split.getResult(1));

     auto input_shape_tensor = rewriter.create<ConstOp>(
         loc,
@@ -918,7 +919,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {

     // padded_shape_tensor is the shape of padded.
     auto padded_shape_tensor =
-        rewriter.create<AddOp>(loc, paddings_sum, input_shape_tensor);
+        rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);

     auto zero_i32 = rewriter.create<ConstOp>(
         loc, GetScalarOfType(rewriter.getIntegerType(32), 0));
@@ -237,7 +237,7 @@ def : Pat<(TF_RoundOp:$res TF_FloatTensor:$input),
           (TF_SubOp $input, (TF_FloorOp:$floor $input)),
           (TF_ConstOp (GetScalarOfFloatType<"0.5"> $input))),
         $floor,
-        (TF_AddOp
+        (TF_AddV2Op
           (TF_ConstOp (GetScalarOfType<1> $input)), $floor))>;

@@ -0,0 +1,184 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
+
+namespace mlir {
+namespace TFTPU {
+
+namespace {
+
+constexpr char kDeviceAttr[] = "device";
+constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
+
+// Mapping for `_xla_outside_compilation` attribute to ops of a cluster.
+using OutsideClusterMap =
+    llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<Operation*, 8>, 8>;
+
+// This pass wraps ops with the same `_xla_outside_compilation`
+// attribute value in a tf_device.launch op with host device assignment.
+//
+// A simple example:
+//   "tf_device.cluster"() ( {
+//     "tf.A"()
+//     "tf.B"() {_xla_outside_compilation = "cluster1"}
+//     "tf.C"()
+//     tf_device.return
+//   }) {num_cores_per_replica = 1, topology = "", device_assignment = []}
+//
+// Would become the following ops (unimportant attributes and types are
+// omitted):
+//   "tf_device.cluster"() ( {
+//     "tf.A"()
+//     "tf_device.launch"() {
+//       "tf.B"() {_xla_outside_compilation = "cluster1"}
+//       tf_device.return
+//     } {device = "TPU_REPLICATED_HOST"} : () -> ()
+//     "tf.C"()
+//     tf_device.return
+//   }) {num_cores_per_replica = 1, topology = "", device_assignment = []}
+
+struct OutsideCompiledToHostLaunch
+    : public PassWrapper<OutsideCompiledToHostLaunch, OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+};
+
+// Collects and clusters ops in `block` with the same
+// `_xla_outside_compilation` attribute into `clusters`. This returns an error
+// if a `_xla_outside_compilation` attribute of an op is empty.
+LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
+                                               OutsideClusterMap* clusters) {
+  auto walk_result = block->walk([&](Operation* op) {
+    if (auto attr = op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
+      if (attr.getValue().empty()) {
+        op->emitOpError() << "requires non empty '"
+                          << kXlaOutsideCompilationAttr << "' string attribute";
+        return WalkResult::interrupt();
+      }
+
+      auto it = clusters->try_emplace(attr.getValue());
+      it.first->getSecond().push_back(op);
+    }
+    return WalkResult::advance();
+  });
+
+  return failure(walk_result.wasInterrupted());
+}
+
+// Extracts all externally used outputs of `cluster_ops`.
+llvm::SmallSetVector<Value, 4> GetExternalOutputs(
+    llvm::ArrayRef<Operation*> cluster_ops) {
+  llvm::SmallSetVector<Value, 4> external_outputs;
+  llvm::SmallPtrSet<Operation*, 4> host_cluster_ops_set;
+  for (auto op : cluster_ops) {
+    op->walk([&](Operation* host_cluster_op) {
+      host_cluster_ops_set.insert(host_cluster_op);
+    });
+  }
+
+  for (Operation* op : cluster_ops) {
+    for (Operation* user : op->getUsers()) {
+      bool is_external = llvm::none_of(
+          host_cluster_ops_set,
+          [&](Operation* cluster_op) { return user == cluster_op; });
+      if (!is_external) continue;
+      for (Value v : user->getOperands()) {
+        if (v.getDefiningOp() == op) external_outputs.insert(v);
+      }
+    }
+  }
+
+  return external_outputs;
+}
+
+void WrapClusterInLaunch(llvm::ArrayRef<Operation*> cluster_ops,
+                         llvm::StringRef host_device) {
+  auto* last_cluster_op = cluster_ops.back();
+  OpBuilder builder(last_cluster_op);
+  llvm::SmallVector<Type, 4> launch_output_types;
+  auto external_outputs = GetExternalOutputs(cluster_ops);
+  for (const auto& external_output : external_outputs)
+    launch_output_types.push_back(external_output.getType());
+
+  auto launch_op = builder.create<tf_device::LaunchOp>(
+      last_cluster_op->getLoc(), builder.getStringAttr(host_device),
+      /*result_types=*/launch_output_types);
+  for (auto result : llvm::zip(external_outputs, launch_op.getResults())) {
+    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
+  }
+
+  launch_op.body().push_back(new Block);
+  builder.setInsertionPointToEnd(&launch_op.GetBody());
+  auto* return_op =
+      builder
+          .create<tf_device::ReturnOp>(last_cluster_op->getLoc(),
+                                       external_outputs.getArrayRef())
+          .getOperation();
+  MLIRContext* context = launch_op.getContext();
+  for (Operation* cluster_op : cluster_ops) {
+    cluster_op->removeAttr(
+        Identifier::get(kXlaOutsideCompilationAttr, context));
+    cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
+    cluster_op->moveBefore(return_op);
+  }
+}
+
+void OutsideCompiledToHostLaunch::runOnOperation() {
+  // Get runtime devices information from the closest parent module.
+  auto module = getOperation();
+  mlir::TF::RuntimeDevices devices;
+  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
+    return signalPassFailure();
+
+  auto result = module.walk([&](tf_device::ClusterOp tpu_cluster) {
+    OutsideClusterMap clusters;
+    if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
+                                                &clusters)))
+      return WalkResult::interrupt();
+
+    if (clusters.empty()) return WalkResult::advance();
+
+    std::string host_device;
+    tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster,
+                                                &host_device);
+    for (const auto& cluster : clusters) {
+      WrapClusterInLaunch(cluster.getSecond(), host_device);
+    }
+    return WalkResult::advance();
+  });
+  if (result.wasInterrupted()) return signalPassFailure();
+}
+
+}  // anonymous namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateOutsideCompiledToHostLaunchPass() {
+  return std::make_unique<OutsideCompiledToHostLaunch>();
+}
+
+static PassRegistration<OutsideCompiledToHostLaunch> pass(
+    "tf-outside-compiled-to-host-launch",
+    "Wraps ops with the same _xla_outside_compilation attribute in "
+    "tf_device.launch on replicated host device.");
+
+}  // namespace TFTPU
+}  // namespace mlir
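
A note on wiring: the pass registers under `tf-outside-compiled-to-host-launch`, which is how the RUN line of the new test file above invokes it via `tf-opt`, and `passes.h` below gains the matching `CreateOutsideCompiledToHostLaunchPass()` declaration next to the existing outside-compilation clustering passes, presumably so the TPU bridge can run it ahead of outside-compilation extraction.
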
@@ -345,6 +345,11 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUColocateCompositeResourceOps();
 // run-time according to compilation result.
 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass();

+// Creates a pass that wraps ops with the same `_xla_outside_compilation`
+// attribute value in a tf_device.launch op with host device assignment.
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateOutsideCompiledToHostLaunchPass();
+
 // Creates a pass that groups outside compiled operations (CPU ops inside TPU
 // cluster) into clusters that can be extracted and run on the CPU.
 std::unique_ptr<OperationPass<ModuleOp>>
@@ -398,8 +398,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
   OpBuilder builder(while_region);
   auto while_op = builder.create<WhileOp>(
       while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
-      while_region.output_shapes(), while_region.parallel_iterations(),
-      while_region.is_stateless());
+      while_region.parallel_iterations(), while_region.is_stateless(),
+      while_region.shape_invariant());
   CopyDeviceAndUnderscoredAttributes(while_region, while_op);

   // Redirect old results to new results.
|
|||||||
OpBuilder& builder) {
|
OpBuilder& builder) {
|
||||||
auto host_side_while = builder.create<TF::WhileRegionOp>(
|
auto host_side_while = builder.create<TF::WhileRegionOp>(
|
||||||
loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
|
loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
|
||||||
/*output_shapes=*/builder.getArrayAttr({}), parallel_iterations,
|
parallel_iterations, is_stateless, /*shape_invariant=*/false);
|
||||||
is_stateless);
|
|
||||||
|
|
||||||
// Create empty else branch region.
|
// Create empty else branch region.
|
||||||
auto& body = host_side_while.body();
|
auto& body = host_side_while.body();
|
||||||
|
@@ -155,6 +155,9 @@ StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore(
   if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst))
     attrs_to_ignore.insert("is_stateless");

+  if (llvm::isa<mlir::TF::WhileOp>(inst))
+    attrs_to_ignore.insert("shape_invariant");
+
   return attrs_to_ignore;
 }

|
|||||||
etype.getContext()));
|
etype.getContext()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (node.IsWhileNode()) {
|
||||||
|
auto* output_shapes = node.attrs().Find("output_shapes");
|
||||||
|
auto* element_types = node.attrs().Find("T");
|
||||||
|
if (output_shapes && !output_shapes->list().shape().empty()) {
|
||||||
|
const auto& output_shape = output_shapes->list().shape(idx);
|
||||||
|
const auto& element_type = element_types->list().type(idx);
|
||||||
|
return ConvertToMlirTensorType(output_shape, element_type, &builder);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Returns a simple, more conservative unranked tensor type.
|
// Returns a simple, more conservative unranked tensor type.
|
||||||
auto default_type = [&]() -> StatusOr<mlir::Type> {
|
auto default_type = [&]() -> StatusOr<mlir::Type> {
|
||||||
mlir::Type element_type;
|
mlir::Type element_type;
|
||||||
@ -1907,7 +1917,13 @@ Status ImporterBase::ConvertNode(const Node& node) {
|
|||||||
// Case/If/While op in MLIR and add the differentiating attribute.
|
// Case/If/While op in MLIR and add the differentiating attribute.
|
||||||
if (node.IsCaseNode()) composite_control_flow_op("Case");
|
if (node.IsCaseNode()) composite_control_flow_op("Case");
|
||||||
if (node.IsIfNode()) composite_control_flow_op("If");
|
if (node.IsIfNode()) composite_control_flow_op("If");
|
||||||
if (node.IsWhileNode()) composite_control_flow_op("While");
|
if (node.IsWhileNode()) {
|
||||||
|
composite_control_flow_op("While");
|
||||||
|
auto* output_shapes = node.attrs().Find("output_shapes");
|
||||||
|
if (output_shapes && !output_shapes->list().shape().empty())
|
||||||
|
result.attributes.push_back(
|
||||||
|
builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr()));
|
||||||
|
}
|
||||||
|
|
||||||
// Register the mapping between the TF node and the newly created operation.
|
// Register the mapping between the TF node and the newly created operation.
|
||||||
node_values_[node.id()] =
|
node_values_[node.id()] =
|
||||||
@ -3420,6 +3436,8 @@ SavedModelSignatureDefImporterLite::ConvertGraph(
|
|||||||
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
|
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
|
||||||
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
|
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
|
||||||
const std::vector<std::string> control_outputs) {
|
const std::vector<std::string> control_outputs) {
|
||||||
|
VLOG(1) << "Importing Signature: " << name;
|
||||||
|
|
||||||
GraphImportConfig specs;
|
GraphImportConfig specs;
|
||||||
specs.prune_unused_nodes = true;
|
specs.prune_unused_nodes = true;
|
||||||
specs.inputs = ParseInputArrays(inputs);
|
specs.inputs = ParseInputArrays(inputs);
|
||||||
@ -3491,6 +3509,9 @@ SavedModelSignatureDefImporterLite::ParseInputArrays(
|
|||||||
// Only dense tensor is supported.
|
// Only dense tensor is supported.
|
||||||
DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName);
|
DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName);
|
||||||
|
|
||||||
|
VLOG(1) << "Importing Signature Input: input_name = " << iter.first
|
||||||
|
<< ", tensor_info = " << tensor_info.DebugString();
|
||||||
|
|
||||||
ArrayInfo array_info;
|
ArrayInfo array_info;
|
||||||
array_info.imported_dtype = tensor_info.dtype();
|
array_info.imported_dtype = tensor_info.dtype();
|
||||||
array_info.shape = tensor_info.tensor_shape();
|
array_info.shape = tensor_info.tensor_shape();
|
||||||
|
@ -68,6 +68,7 @@ distribute_py_test(
|
|||||||
tags = [
|
tags = [
|
||||||
"no_cuda_asan", # b/173431253
|
"no_cuda_asan", # b/173431253
|
||||||
"no_oss",
|
"no_oss",
|
||||||
|
"notap", # b/173661843
|
||||||
"notsan", # b/173246447
|
"notsan", # b/173246447
|
||||||
],
|
],
|
||||||
xla_enable_strict_auto_jit = False, # b/173254861
|
xla_enable_strict_auto_jit = False, # b/173254861
|
||||||
@@ -90,8 +90,17 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
     pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   }

-  // Legalize only hlo operations to lhlo, keep the rest as tensors.
-  pm.addPass(mlir::kernel_gen::transforms::CreateHloBufferizePass());
+  // Partial bufferization: Transforms in particular HLO operations to their
+  // corresponding LHLO operations and converts the function signature. Leaves
+  // shape operations untouched.
+  pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass(
+      /*allow_partial_bufferization=*/true));
+  // Run CSE to ensure that loads and stores to the same location get
+  // recognized as such.
+  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
+  // Forward stores to buffers to loads.
+  pm.addNestedPass<mlir::FuncOp>(xla::mlir_gpu::createStoreForwardingPass());
   // Clean up the IR for further processing.
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
@@ -100,18 +109,22 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
   // needed.
   llvm::SmallVector<unsigned, 4> tiling_for_unrolling;
   llvm::SmallVector<int64_t, 4> as_int64;
-  if (!unroll_factors.empty()) {
-    tiling_for_unrolling.reserve(tile_sizes.size());
-    for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
-      tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
-      as_int64.push_back(std::get<1>(pair));
-    }
-  } else {
-    tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
+  tiling_for_unrolling.reserve(tile_sizes.size());
+  for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
+    tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
+    as_int64.push_back(std::get<1>(pair));
   }
+  tiling_for_unrolling.append(
+      tile_sizes.drop_front(unroll_factors.size()).begin(), tile_sizes.end());
   // Transform LHLO operations to LinAlg.
   pm.addNestedPass<mlir::FuncOp>(
       ::mlir::lmhlo::createLegalizeLhloToLinalgPass());
+  if (!gpu_binary_only) {
+    // Find candidates for buffer reuse. This is only successful if buffer size
+    // equality can be determined based on `linalg.generic` operations.
+    pm.addNestedPass<mlir::FuncOp>(
+        mlir::kernel_gen::transforms::CreateBufferReusePass());
+  }
   // Fuse linalg operations.
   pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLhloFuseLinalgPass(
       /*use_parallel_loops=*/true, tiling_for_unrolling));
@ -141,11 +154,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
|
|||||||
// Some basic cleanup.
|
// Some basic cleanup.
|
||||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
||||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
||||||
if (!gpu_binary_only) {
|
|
||||||
// Find candidates for buffer reuse.
|
|
||||||
pm.addNestedPass<mlir::FuncOp>(
|
|
||||||
mlir::kernel_gen::transforms::CreateBufferReusePass());
|
|
||||||
}
|
|
||||||
// Greedily map the remaining loop to GPU hardware dimensions.
|
// Greedily map the remaining loop to GPU hardware dimensions.
|
||||||
pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass());
|
pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass());
|
||||||
|
|
||||||
@@ -157,7 +165,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
   pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
-  pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass());
+  pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createPromoteBuffersToStackPass(64));
   // TODO(herhut): Enable this to avoid leaks once fixed.
   // pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferDeallocationPass());
@@ -189,10 +197,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
   pm.addPass(::mlir::createLowerAffinePass());
   // Constraints are removed as late as possible and before lowering to CFG.
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createConvertShapeConstraintsPass());
-  if (embed_memref_prints) {
-    pm.addNestedPass<::mlir::FuncOp>(
-        mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass());
-  }
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
   // TODO(herhut): Remove this pass once the LowerToCFG pass can handle it.
   pm.addNestedPass<mlir::FuncOp>(
@@ -200,6 +204,10 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
   pm.addPass(::mlir::createLowerToCFGPass());
   // Map allocs, asserts, etc. to the tensorflow framework.
   pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass());
+  if (embed_memref_prints) {
+    pm.addNestedPass<::mlir::FuncOp>(
+        mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass());
+  }
   if (failed(pm.run(module))) {
     return InternalError("Lowering to GPU kernels failed.");
   }
@@ -49,12 +49,58 @@ func @local_reuse_with_memref_maps(
     iterator_types = ["parallel"]
   } ins(%arg : memref<?xi64, offset: 2, strides: [3]>)
     outs(%result : memref<?xi64, offset: 2, strides: [3]>) {
-  ^bb0(%a: i64, %b: i64):
+  ^bb0(%a : i64, %b : i64):
     linalg.yield %a : i64
   }
   return %result : memref<?xi64, offset: 2, strides: [3]>
 }

+// CHECK-LABEL: @memref_reinterpret_cast_alias
+func @memref_reinterpret_cast_alias(%arg : memref<f32>, %n : index)
+    -> memref<?xf32> attributes {tf_entry} {
+  %c0 = constant 0 : index
+  %reinterpreted = memref_reinterpret_cast %arg to
+    offset: [0],
+    sizes: [%n],
+    strides: [%c0]: memref<f32> to memref<?xf32>
+
+  // CHECK: alloc
+  // CHECK-SAME: reuse_input_candidates = [0 : index]
+  %result = alloc(%n) : memref<?xf32>
+
+  // reinterpreted (arg) and result are of same size.
+  linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%reinterpreted : memref<?xf32>) outs(%result : memref<?xf32>) {
+  ^bb0(%a : f32, %b : f32):
+    linalg.yield %a : f32
+  }
+
+  return %result : memref<?xf32>
+}
+
+// CHECK-LABEL: @memref_cast_alias
+func @memref_cast_alias(%arg : memref<*xf32>, %n : index)
+    -> memref<?xf32> attributes {tf_entry} {
+  %casted = memref_cast %arg : memref<*xf32> to memref<?xf32>
+
+  // CHECK: alloc
+  // CHECK-SAME: reuse_input_candidates = [0 : index]
+  %result = alloc(%n) : memref<?xf32>
+
+  // casted (arg) and result are of same size.
+  linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%casted : memref<?xf32>) outs(%result : memref<?xf32>) {
+  ^bb0(%a : f32, %b : f32):
+    linalg.yield %a : f32
+  }
+
+  return %result : memref<?xf32>
+}
+
 // CHECK-LABEL: @indirect_size_equality
 func @indirect_size_equality(%arg0 : memref<?xi64>,
                              %arg1 : memref<?xi64>,
@@ -65,7 +111,7 @@ func @indirect_size_equality(%arg0 : memref<?xi64>,
     indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]
   } ins(%arg0 : memref<?xi64>) outs(%arg1 : memref<?xi64>) {
-  ^bb0(%a: i64, %b: i64):
+  ^bb0(%a : i64, %b : i64):
     linalg.yield %a : i64
   }

@@ -78,7 +124,7 @@ func @indirect_size_equality(%arg0 : memref<?xi64>,
     indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]
   } ins(%arg0 : memref<?xi64>) outs(%result : memref<?xi64>) {
-  ^bb0(%a: i64, %b: i64):
+  ^bb0(%a : i64, %b : i64):
     linalg.yield %a : i64
   }

@@ -317,7 +363,7 @@ func @abs_unranked_i64(%arg : memref<*xi64>,
     indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]
   } ins(%flat_arg : memref<?xi64>) outs(%flat_result : memref<?xi64>) {
-  ^bb0(%a: i64, %b: i64):
+  ^bb0(%a : i64, %b : i64):
     %c0 = constant 0 : i64
     %a_pos = cmpi "sge", %a, %c0 : i64
     %a_neg = subi %c0, %a : i64
@@ -360,3 +406,41 @@ func @index_element_type(%arg : memref<2x3xindex>) -> memref<2x3xindex>
   %result = alloc() : memref<2x3xindex>
   return %result : memref<2x3xindex>
 }
+
+// Example as it occurs in the `tf.Abs` kernel for `f32`.
+// CHECK-LABEL: @abs_f32
+func @abs_f32(%arg0: memref<*xf32>) -> memref<*xf32>
+    attributes {llvm.emit_c_interface, tf_entry} {
+  %c0 = constant 0 : index
+  %0 = shape.shape_of %arg0 : memref<*xf32> -> tensor<?xindex>
+  %1 = shape.num_elements %0 : tensor<?xindex> -> index
+  // CHECK-LABEL: alloc
+  // CHECK-SAME: reuse_input_candidates = []
+  %2 = alloc() : memref<1xindex>
+  store %1, %2[%c0] : memref<1xindex>
+  %3 = memref_reshape %arg0(%2)
+      : (memref<*xf32>, memref<1xindex>) -> memref<?xf32>
+  %4 = dim %3, %c0 : memref<?xf32>
+  %5 = index_cast %4 : index to i64
+  // CHECK-LABEL: alloc
+  // CHECK-SAME: reuse_input_candidates = []
+  %6 = alloc() : memref<1xi64>
+  store %5, %6[%c0] : memref<1xi64>
+  %7 = load %6[%c0] : memref<1xi64>
+  %8 = index_cast %7 : i64 to index
+  // CHECK-LABEL: alloc
+  // CHECK-SAME: reuse_input_candidates = [0 : index]
+  %9 = alloc(%8) : memref<?xf32>
+  linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%3 : memref<?xf32>) outs(%9 : memref<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32): // no predecessors
+    %12 = absf %arg1 : f32
+    linalg.yield %12 : f32
+  }
+  %10 = tensor_to_memref %0 : memref<?xindex>
+  %11 = memref_reshape %9(%10)
+      : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
+  return %11 : memref<*xf32>
+}
@@ -1,4 +1,4 @@
-// RUN: kernel-gen-opt %s --final-bufferize | FileCheck %s
+// RUN: kernel-gen-opt %s --bufferize | FileCheck %s

 // CHECK-LABEL: @extract_element
 // CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>) -> f32
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --canonicalize --shape-to-descriptors --canonicalize --hlo-bufferize --std-bufferize --func-bufferize | FileCheck %s
+// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --canonicalize --shape-to-descriptors --canonicalize --bufferize | FileCheck %s

 // Test whether all shape computations required for isinf can be lowered to
 // the standard dialect, scf and descriptors.
|
@ -24,7 +24,7 @@ func @print_memrefs(
|
|||||||
return %output : memref<*xf16>
|
return %output : memref<*xf16>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: func private @print_memref_index(memref<*xindex>)
|
// CHECK: func private @print_memref_i64(memref<*xi64>)
|
||||||
|
|
||||||
// CHECK-LABEL: func @print_memrefs
|
// CHECK-LABEL: func @print_memrefs
|
||||||
|
|
||||||
@ -33,12 +33,16 @@ func @print_memrefs(
|
|||||||
// CHECK: [[NUM_ELEM:%.*]] = alloca() : memref<1xindex>
|
// CHECK: [[NUM_ELEM:%.*]] = alloca() : memref<1xindex>
|
||||||
// CHECK: store {{%.*}}, [[NUM_ELEM]]
|
// CHECK: store {{%.*}}, [[NUM_ELEM]]
|
||||||
|
|
||||||
// CHECK: [[UNRANKED_NUM_ELEM:%.*]] = memref_cast [[NUM_ELEM]]
|
// CHECK: [[NUM_ELEM_I64:%.*]] = index_cast [[NUM_ELEM]]
|
||||||
// CHECK-NEXT: call @print_memref_index([[UNRANKED_NUM_ELEM]])
|
// CHECK-SAME: : memref<1xindex> to memref<1xi64>
|
||||||
|
// CHECK-NEXT: [[UNRANKED_NUM_ELEM:%.*]] = memref_cast [[NUM_ELEM_I64]]
|
||||||
|
// CHECK-NEXT: call @print_memref_i64([[UNRANKED_NUM_ELEM]])
|
||||||
|
|
||||||
// CHECK: memref_reshape
|
// CHECK: memref_reshape
|
||||||
// CHECK: tf_framework.alloc
|
// CHECK: tf_framework.alloc
|
||||||
|
|
||||||
// CHECK: [[UNRANKED_SHAPE:%.*]] = memref_cast [[SHAPE]]
|
// CHECK: [[SHAPE_I64:%.*]] = index_cast [[SHAPE]]
|
||||||
// CHECK-NEXT: call @print_memref_index([[UNRANKED_SHAPE]])
|
// CHECK-SAME: : memref<?xindex> to memref<?xi64>
|
||||||
|
// CHECK-NEXT: [[UNRANKED_SHAPE:%.*]] = memref_cast [[SHAPE_I64]]
|
||||||
|
// CHECK-NEXT: call @print_memref_i64([[UNRANKED_SHAPE]])
|
||||||
// CHECK: memref_reshape
|
// CHECK: memref_reshape
|
||||||
|
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --shape-to-descriptors --canonicalize --hlo-bufferize --std-bufferize --scf-bufferize --func-bufferize | FileCheck %s
+// RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --shape-to-descriptors --canonicalize --bufferize | FileCheck %s

 // Test whether all shape computations required for tanh can be lowered to
 // the standard dialect, scf and descriptors. We check for a sparse pattern here,
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --hlo-bufferize
+// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --bufferize

 func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   %0 = "tf.Acos"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
@@ -75,7 +75,7 @@ extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) {
 extern "C" void _mlir_ciface_tf_report_error(void* op_kernel_ctx,
                                              int32_t error_code, char* msg) {
   Optional<ErrorCode> symbol = symbolizeErrorCode(error_code);
-  if (symbol.hasValue()) {
+  if (!symbol.hasValue()) {
     LOG(ERROR) << "No valid conversion from integer value = " << error_code
                << "to ErrorCode attribute";
     return;
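The one-character change above is a real bug fix: the old code logged a conversion failure (and returned early) exactly when symbolizeErrorCode *succeeded*. A self-contained sketch of the corrected guard shape, using std::optional and a hypothetical Symbolize() helper in place of llvm::Optional and the generated symbolizeErrorCode:

    #include <cstdint>
    #include <iostream>
    #include <optional>

    enum class ErrorCode { kOk, kInvalidArgument };

    // Hypothetical stand-in for symbolizeErrorCode: maps a raw integer onto
    // the enum, or returns std::nullopt when there is no matching enumerator.
    std::optional<ErrorCode> Symbolize(int32_t raw) {
      switch (raw) {
        case 0: return ErrorCode::kOk;
        case 1: return ErrorCode::kInvalidArgument;
        default: return std::nullopt;
      }
    }

    void ReportError(int32_t error_code) {
      std::optional<ErrorCode> symbol = Symbolize(error_code);
      if (!symbol.has_value()) {  // guard on *failure*, then bail out early
        std::cerr << "No valid conversion from integer value = " << error_code
                  << " to ErrorCode attribute\n";
        return;
      }
      // Only reached with a valid symbol; forward it to the real error channel.
    }

    int main() {
      ReportError(42);  // logs: no valid conversion
      ReportError(1);   // silent: conversion succeeded
    }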
@@ -55,12 +55,14 @@ namespace {
 /// A temporary buffer size analysis that is correct but may be incomplete.
 class BufferSizeAnalysis {
  public:
-  explicit BufferSizeAnalysis(FuncOp f) { build(f); }
+  BufferSizeAnalysis(FuncOp f, const BufferAliasAnalysis &aliases) {
+    build(f, aliases);
+  }

   bool is_same_size(Value a, Value b) { return ecs_.isEquivalent(a, b); }

  private:
-  void build(FuncOp &f) {
+  void build(FuncOp &f, const BufferAliasAnalysis &aliases) {
     auto buffers = find_buffer_values(f);

     // Memrefs with statically known same shape and same symbol-free affine maps
@@ -102,10 +104,16 @@ class BufferSizeAnalysis {
       }
     });

-    // Operand and result of `reshape_memref_cast` must be of same size.
-    f.walk([&](MemRefReshapeOp reshapeOp) {
-      ecs_.unionSets(reshapeOp.result(), reshapeOp.source());
-    });
+    // All aliases of a memref must be of the same underlying buffer size.
+    for (auto e : aliases) {
+      Value value = e.getFirst();
+      if (!value.getType().isa<BaseMemRefType>()) continue;
+      for (Value alias : e.getSecond()) {
+        assert(alias.getType().isa<BaseMemRefType>() &&
+               "Expected aliases of memref to be memrefs.");
+        ecs_.unionSets(value, alias);
+      }
+    }
   }

   bool affine_maps_symbol_free_and_equal(ArrayRef<AffineMap> as,
@@ -181,7 +189,7 @@ class BufferReuseAnalysis {

   void find_reuse_candiates(FuncOp &f, BufferAliasAnalysis &aliases) {
     Liveness liveness(f);
-    BufferSizeAnalysis size_equivalences(f);
+    BufferSizeAnalysis size_equivalences(f, aliases);
     f.walk([&](Block *block) {
       find_reuse_candiates(block, aliases, liveness.getLiveness(block),
                            size_equivalences, f.getArguments());
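BufferSizeAnalysis uses llvm::EquivalenceClasses as a union-find: each unionSets call above merges two size classes, and is_same_size just asks isEquivalent. A small self-contained illustration of those semantics, over plain ints rather than mlir::Value:

    #include <cassert>

    #include "llvm/ADT/EquivalenceClasses.h"

    int main() {
      llvm::EquivalenceClasses<int> ecs;

      // Each unionSets call records "these two buffers have the same size".
      ecs.unionSets(1, 2);  // class {1, 2}
      ecs.unionSets(2, 3);  // class {1, 2, 3}; transitivity comes for free
      ecs.unionSets(4, 5);  // class {4, 5}

      assert(ecs.isEquivalent(1, 3));   // merged through 2
      assert(!ecs.isEquivalent(1, 4));  // separate classes
      return 0;
    }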
@@ -204,50 +212,53 @@ class BufferReuseAnalysis {

     // Find reuse candidates for the regarded allocation.
     SmallVector<int64_t, 2> local_reuse_candidates;
-    for (auto it : llvm::enumerate(arguments)) {
-      int64_t old_buffer_index = it.index();
-      Value old_buffer = it.value();
+    for (BlockArgument old_buffer : arguments) {
       if (!old_buffer.getType().isa<BaseMemRefType>()) continue;

-      // Will not reuse buffers of different size as they may be too small.
+      // Size criterion: Do not reuse buffers of different size as they may be
+      // too small.
       if (!size_equivalences.is_same_size(new_buffer, old_buffer)) continue;

-      // Only reuse buffers that are no longer used on first reuse, i.e. they
-      // are no longer alive.
-      bool livetimes_compatible = true;
+      // Lifetime criterion: Only reuse buffers that are no longer used on
+      // first reuse, i.e. they are no longer alive.
+      bool lifetimes_compatible = true;
       for (Value old_buffer_alias : aliases.resolve(old_buffer)) {
         if (first_reuse == nullptr) {
           // If the first use is beyond the end of this block we look at the
           // block end. An argument buffer that is already reusable there is
-          // certainly reusable at any later actual use.
+          // certainly reusable at any later actual use. Otherwise, lifetimes
+          // are incompatible.
           if (liveness->isLiveOut(old_buffer_alias)) {
-            livetimes_compatible = false;
+            lifetimes_compatible = false;
             break;
           }
         } else {
-          // A buffer is *not* reusable if
-          //   i)  its last use is after the point of reuse, or
-          //   ii) its last use is also its first reuse but the operation
-          //       does not allow for local reuse.
+          // A buffer is reusable if
+          //   i)  its last use is before the point of reuse, or
+          //   ii) its last use is also its first reuse and the operation
+          //       allows for local reuse.
+          // Otherwise, lifetimes are incompatible.
           Operation *last_use =
               liveness->getEndOperation(old_buffer_alias, &block->front());
           assert(last_use != nullptr && last_use->getBlock() == block &&
                  "Expected last use in same block.");
           if (first_reuse->isBeforeInBlock(last_use)) {
-            livetimes_compatible = false;
+            lifetimes_compatible = false;
            break;
           }
           if (first_reuse == last_use &&
               !can_reuse_locally(first_reuse, old_buffer_alias, new_buffer)) {
-            livetimes_compatible = false;
+            lifetimes_compatible = false;
             break;
           }
         }
       }

-      // All criteria are fulfilled 🙂.
-      if (livetimes_compatible)
+      if (lifetimes_compatible) {
+        // All criteria are fulfilled 🙂.
+        int64_t old_buffer_index = old_buffer.getArgNumber();
         local_reuse_candidates.push_back(old_buffer_index);
+      }
     }

     reuse_candidates_[&op] = local_reuse_candidates;
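Distilled, the loop above keeps an argument buffer as a reuse candidate only when both checks pass. A schematic sketch with the two criteria abstracted into predicates (hypothetical names; the real checks are BufferSizeAnalysis::is_same_size and the alias/liveness walk shown in the hunk):

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <vector>

    struct Buffer {
      int64_t index;
    };

    // Schematic: the candidate filter is a conjunction of the two criteria
    // named in the comments above.
    std::vector<int64_t> ReuseCandidates(
        const std::vector<Buffer>& args,
        const std::function<bool(const Buffer&)>& same_size,
        const std::function<bool(const Buffer&)>& lifetime_ok) {
      std::vector<int64_t> candidates;
      for (const Buffer& old_buffer : args) {
        if (!same_size(old_buffer)) continue;    // size criterion
        if (!lifetime_ok(old_buffer)) continue;  // lifetime criterion
        candidates.push_back(old_buffer.index);
      }
      return candidates;
    }

    int main() {
      std::vector<Buffer> args = {{0}, {1}, {2}};
      auto candidates = ReuseCandidates(
          args,
          [](const Buffer& b) { return b.index != 1; },  // pretend #1 is smaller
          [](const Buffer&) { return true; });           // all dead by first reuse
      for (int64_t i : candidates) std::cout << i << "\n";  // prints 0 and 2
    }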
@@ -48,40 +48,6 @@ namespace kernel_gen {
 namespace transforms {
 namespace {

-#define GEN_PASS_CLASSES
-#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
-
-struct HloBufferizePass : public HloBufferizePassBase<HloBufferizePass> {
-  // TODO(b/173201243): Move to tablegen.
-  void getDependentDialects(DialectRegistry& registry) const override {
-    registry.insert<lmhlo::LmhloDialect>();
-  }
-
- public:
-  void runOnOperation() override {
-    OwningRewritePatternList patterns;
-    auto& context = getContext();
-    ConversionTarget target(context);
-    target.addLegalDialect<lmhlo::LmhloDialect>();
-    target.addLegalDialect<StandardOpsDialect>();
-    target.addIllegalDialect<mhlo::MhloDialect>();
-
-    BufferizeTypeConverter converter;
-    // Configure bufferize pattern for functions and lhlo.
-    mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
-
-    // Configure legality and structural patterns.
-    populateBufferizeMaterializationLegality(target);
-    populateShapeStructuralTypeConversionsAndLegality(&context, converter,
-                                                      patterns, target);
-    scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
-                                                         patterns, target);
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      signalPassFailure();
-  }
-};
-
 // TODO(herhut) : This could become a real pattern in bufferize pass. What we
 // would need to do is insert a copy to model the semantics correctly. The same
 // is true for the TensorLoad pattern that is already in there. Then buffer
@@ -104,11 +70,34 @@ class UnrankedTensorStoreTestOnlyPattern
   }
 };

-struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
+// TODO(frgossen): Move this upstream to `populateFuncOpTypeConversionPattern`
+// This pattern is merely needed to materialize type casts for return values so
+// that they match the function signature conversion.
+class ReturnOpTypeConversionPattern
+    : public OpConversionPattern<mlir::ReturnOp> {
+ public:
+  using OpConversionPattern<mlir::ReturnOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      ReturnOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
+    return success();
+  }
+};
+
+#define GEN_PASS_CLASSES
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
+
+struct BufferizePass : public BufferizePassBase<BufferizePass> {
+  explicit BufferizePass(bool allow_partial_bufferization) {
+    allow_partial_bufferization_ = allow_partial_bufferization;
+  }
+
   // TODO(b/173201243): Move to tablegen.
   void getDependentDialects(DialectRegistry& registry) const override {
     registry.insert<AffineDialect, scf::SCFDialect, shape::ShapeDialect,
-                    tf_framework::TFFrameworkDialect>();
+                    tf_framework::TFFrameworkDialect, lmhlo::LmhloDialect>();
   }

  public:
@@ -117,12 +106,17 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
     ConversionTarget target(context);
     target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
                            tf_framework::TFFrameworkDialect, AffineDialect,
-                           shape::ShapeDialect>();
+                           shape::ShapeDialect, lmhlo::LmhloDialect>();
     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();

     target.addIllegalDialect<mhlo::MhloDialect>();
     target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
-                        TensorFromElementsOp, TensorCastOp, TensorLoadOp,
-                        TensorToMemrefOp>();
+                        TensorFromElementsOp, TensorCastOp>();
+    if (!allow_partial_bufferization_) {
+      target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+    }
+
     // Certain operations are no longer legal on tensors but otherwise are.
     target.addDynamicallyLegalOp<ConstantOp, SelectOp>([&](Operation* op) {
       return llvm::none_of(op->getResultTypes(),
@@ -144,8 +138,8 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
       return converter.isLegal(inputs) && converter.isLegal(results) &&
              converter.isLegal(&op.getBody());
     });
-    target.addDynamicallyLegalOp<CallOp, ConstantOp, DimOp, RankOp, SelectOp>(
-        typesAreLegal);
+    target.addDynamicallyLegalOp<CallOp, ConstantOp, DimOp, RankOp, SelectOp,
+                                 ReturnOp>(typesAreLegal);

     OwningRewritePatternList patterns;
     mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
@@ -160,22 +154,27 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
     scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
                                                          patterns, target);
     patterns.insert<UnrankedTensorStoreTestOnlyPattern>(&context);
+    patterns.insert<ReturnOpTypeConversionPattern>(converter, &context);

     auto module = getOperation();
-    if (failed(applyFullConversion(module, target, std::move(patterns)))) {
-      signalPassFailure();
+    if (allow_partial_bufferization_) {
+      if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
+        signalPassFailure();
+      }
+    } else {
+      if (failed(
+              mlir::applyFullConversion(module, target, std::move(patterns)))) {
+        signalPassFailure();
+      }
     }
   }
 };

 } // namespace

-std::unique_ptr<OperationPass<ModuleOp> > CreateHloBufferizePass() {
-  return std::make_unique<HloBufferizePass>();
-}
-
-std::unique_ptr<OperationPass<ModuleOp> > CreateFinalBufferizePass() {
-  return std::make_unique<FinalBufferizePass>();
+std::unique_ptr<OperationPass<ModuleOp> > CreateBufferizePass(
+    bool allow_partial_bufferization) {
+  return std::make_unique<BufferizePass>(allow_partial_bufferization);
 }

 } // namespace transforms
@@ -37,9 +37,8 @@ namespace {

 using tf_framework::TFFrameworkDialect;

-Operation* emitCallToFunc(Location loc, StringRef func_name,
-                          ArrayRef<Type> return_types, ValueRange args,
-                          OpBuilder* b) {
+Operation* emitCallToPrint(Location loc, StringRef func_name, Value arg,
+                           OpBuilder* b) {
   auto caller_func =
       b->getInsertionBlock()->getParent()->getParentOfType<FuncOp>();
   auto callee_func =
@@ -49,12 +48,12 @@ Operation* emitCallToFunc(Location loc, StringRef func_name,

     auto module = caller_func.getParentOfType<ModuleOp>();
     b->setInsertionPointToStart(module.getBody());
-    auto func_type =
-        FunctionType::get(args.getTypes(), return_types, b->getContext());
+    auto func_type = FunctionType::get(arg.getType(), /*results=*/llvm::None,
+                                       b->getContext());
     callee_func = b->create<FuncOp>(module.getLoc(), func_name, func_type);
     callee_func.setPrivate();
   }
-  return b->create<CallOp>(loc, callee_func, args);
+  return b->create<CallOp>(loc, callee_func, arg);
 }

 void EmitPrint(Operation* op, Liveness& liveness, OpBuilder* b) {
@@ -70,28 +69,32 @@ void EmitPrint(Operation* op, Liveness& liveness, OpBuilder* b) {
       liveness.getLiveness(op->getBlock())->getEndOperation(memref, op);
   b->setInsertionPoint(end_op);

+  if (element_type.isIndex()) {
+    element_type = b->getI64Type();
+    memref_type = MemRefType::get(memref_type.getShape(), element_type,
+                                  memref_type.getAffineMaps(),
+                                  memref_type.getMemorySpace());
+    memref = b->create<IndexCastOp>(loc, memref, memref_type);
+  }
+
   auto unranked_type =
       UnrankedMemRefType::get(element_type, memref_type.getMemorySpace());
   Value unranked_memref = b->create<MemRefCastOp>(loc, memref, unranked_type);

   if (element_type.isF32()) {
-    emitCallToFunc(loc, "print_memref_f32", {}, {unranked_memref}, b);
+    emitCallToPrint(loc, "print_memref_f32", unranked_memref, b);
     return;
   }
   if (element_type.isF64()) {
-    emitCallToFunc(loc, "print_memref_f64", {}, {unranked_memref}, b);
-    return;
-  }
-  if (element_type.isIndex()) {
-    emitCallToFunc(loc, "print_memref_index", {}, {unranked_memref}, b);
+    emitCallToPrint(loc, "print_memref_f64", unranked_memref, b);
     return;
   }
   if (element_type.isInteger(32)) {
-    emitCallToFunc(loc, "print_memref_i32", {}, {unranked_memref}, b);
+    emitCallToPrint(loc, "print_memref_i32", unranked_memref, b);
     return;
   }
-  if (element_type.isInteger(64)) {
-    emitCallToFunc(loc, "print_memref_i64", {}, {unranked_memref}, b);
+  if (element_type.isInteger(64) || element_type.isIndex()) {
+    emitCallToPrint(loc, "print_memref_i64", unranked_memref, b);
     return;
   }
 }
@@ -47,13 +47,10 @@ std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass();
 // using memref descriptors.
 std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass();

-// Pass to tranform hlo-level computations on values to their corresponding
-// parts on buffers.
-std::unique_ptr<OperationPass<ModuleOp>> CreateHloBufferizePass();
-
-// Pass to tranform late-dialect level computations (essentially all non-hlo
-// dialects) on values to their corresponding parts on buffers.
-std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass();
+// Pass to transform operations on values to their corresponding parts on
+// buffers.
+std::unique_ptr<OperationPass<ModuleOp>> CreateBufferizePass(
+    bool allow_partial_bufferization = false);

 // Pass to materialize broadcasts.
 std::unique_ptr<FunctionPass> CreateMaterializeBroadcastsPass();
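Because the new parameter is defaulted, existing call sites compile unchanged while the kernel generator can opt in. A minimal usage sketch; the passes.h include path mirrors the .inc path visible in this diff and is an assumption:

    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"  // assumed path

    // Default (false): full bufferization, every value-based op must convert.
    // true: partial bufferization, e.g. shape computations may stay on tensors.
    void AddBufferize(mlir::PassManager& pm, bool allow_partial) {
      pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass(
          /*allow_partial_bufferization=*/allow_partial));
    }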
@@ -38,16 +38,14 @@ def ShapeToDescriptorsPass : Pass<"shape-to-descriptors", "ModuleOp"> {
   let constructor = "transforms::CreateShapeToDescriptorsPass()";
 }

-def HloBufferizePass : Pass<"hlo-bufferize", "ModuleOp"> {
-  let summary = "Pass to transform hlo operations on values to buffer based "
-                "ones.";
-  let constructor = "transforms::CreateHloBufferizePass()";
-}
-
-def FinalBufferizePass : Pass<"final-bufferize", "ModuleOp"> {
-  let summary = "Pass to transform operations from all non-hlo dialects on "
-                "values to buffer-based ones.";
-  let constructor = "transforms::CreateFinalBufferizePass()";
+def BufferizePass : Pass<"bufferize", "ModuleOp"> {
+  let summary = "Pass to transform operations on values to buffer-based ones.";
+  let options = [
+    Option<"allow_partial_bufferization_", "allow-partial-bufferization",
+           "bool", /*default=*/"false", "Allow partial bufferization. "
+           "Value-based operations may remain, e.g. for shape operations.">,
+  ];
+  let constructor = "transforms::CreateBufferizePass()";
 }

 def MaterializeBroadcastsPass : FunctionPass<"materialize-broadcast"> {
@@ -140,6 +140,7 @@ cc_library(
         ":translate_cl_options",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
+        "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
@@ -148,6 +149,8 @@ cc_library(
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_casting_utils",
        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
+        "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
@@ -63,8 +63,11 @@ HloModule SelectAndScatter
   ROOT %add.11 = f32[] add(f32[] %lhs.9, f32[] %rhs.10)
 }

+// CHECK-LABEL: module
+// CHECK: global_memref "private" constant @[[$GLOBAL:.*]] : memref<f32> = dense<0.000000e+00>
 // CHECK-LABEL: func @main
-// CHECK: "lmhlo.select_and_scatter"
+// CHECK: %[[GLOBAL_MEMREF:.*]] = get_global_memref @[[$GLOBAL]] : memref<f32>
+// CHECK: "lmhlo.select_and_scatter"(%{{.*}}, %{{.*}}, %[[GLOBAL_MEMREF]], %{{.*}})
 // CHECK: ^bb0(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>):
 // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG0]], %[[ARG1]]) {comparison_direction = "GE"}
 // CHECK: "mhlo.return"(%[[COMPARE]]) : (tensor<i1>) -> ()
@@ -78,7 +81,7 @@ HloModule SelectAndScatter
 ENTRY main () -> f32[6] {
   %operand = f32[6]{0} parameter(0)
   %source = f32[2]{0} parameter(1)
-  %init = f32[] parameter(2)
+  %init = f32[] constant(0)
   ROOT %select-and-scatter.12 = f32[6]{0} select-and-scatter(f32[6]{0} %operand, f32[2]{0} %source, f32[] %init), window={size=3 stride=3}, select=%ge_F32, scatter=%add_F32
 }

@@ -101,4 +104,20 @@ ENTRY main {
                 s32[] %static),
   custom_call_target="SliceToDynamic",
   backend_config=""
 }
+
+// -----
+
+HloModule Cholesky
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo_gpu.cholesky"
+// CHECK-SAME: is_lower = true
+ENTRY main {
+  %param = f32[3,3] parameter(0)
+  ROOT %custom-call = (f32[3,3], f32[3], s32[]) custom-call(f32[3,3] %param),
+                      custom_call_target="__cusolver$cholesky",
+                      operand_layout_constraints={f32[3,3]},
+                      backend_config="{\"lower\":true}"
+}
@@ -374,3 +374,23 @@ func @main(%arg0: tuple<tuple<tensor<f32>>, tensor<f32>>, %arg1: tuple<tensor<f3

   return %result : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
 }
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo.reduce"({{.*}}) ( {
+// CHECK: ^bb0(%[[VAL1:.*]]: tensor<f32>, %[[VAL2:.*]]: tensor<i32>, %[[VAL3:.*]]: tensor<f32>, %[[VAL4:.*]]: tensor<i32>): // no predecessors
+// CHECK: %[[VAL5:.*]] = mhlo.maximum %[[VAL1]], %[[VAL3]] : tensor<f32>
+// CHECK: %[[VAL6:.*]] = mhlo.maximum %[[VAL2]], %[[VAL4:.*]] : tensor<i32>
+// CHECK: %[[VAL7:.*]] = "mhlo.tuple"(%[[VAL5]], %[[VAL6:.*]]) : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
+// CHECK: "mhlo.return"(%[[VAL7:.*]]) : (tuple<tensor<f32>, tensor<i32>>) -> ()
+// CHECK: })
+func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
+  %result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
+  ^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
+    %fmax = "mhlo.maximum"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %imax = "mhlo.maximum"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "mhlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
+  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
+  return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
+}
@@ -13,10 +13,8 @@
 // CHECK-LABEL: func @add
 func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   // CHECK-NEXT: %[[SUM0:.*]] = mhlo.add %arg0, %arg0 : tensor<2xi32>
-  // CHECK-NEXT: %[[SUM1:.*]] = mhlo.add %[[SUM0]], %arg0 : tensor<2xi32>
-  // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32>
-  %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  // CHECK-NEXT: return %[[SUM0]] : tensor<2xi32>
+  %1 = "tf.AddV2"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %1: tensor<2xi32>
 }

@@ -27,7 +25,7 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
 func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
   // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
   return %0: tensor<1x2xi32>
 }

@@ -37,7 +35,7 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
 func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
   // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
   // CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
   return %0: tensor<4x4x4x4xi32>
 }

@@ -55,7 +53,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
   // CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
   // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor<?x?xi32>
   // CHECK-NEXT: shape.assuming_yield %[[RESULT]]
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
   return %0: tensor<?x?xi32>
 }

@@ -64,7 +62,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
 func @broadcast_add_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
   // CHECK: tf.Add
   // CHLO: chlo.broadcast_add %arg0, %arg1
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
   return %0: tensor<*xi32>
 }

@@ -264,6 +262,13 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1>
   return %0: tensor<*xi1>
 }

+// CHECK-LABEL: func @equal_unsupported_type
+func @equal_unsupported_type(%arg0: tensor<*x!tf.string>, %arg1: tensor<*x!tf.string>) -> tensor<*xi1> {
+  // CHECK: "tf.Equal"
+  %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*x!tf.string>, tensor<*x!tf.string>) -> tensor<*xi1>
+  return %0: tensor<*xi1>
+}
+
 // CHECK-LABEL: func @notequal
 func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
   // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"}

@@ -26,7 +26,7 @@ func @tf_unknown_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
 // -----

 func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-  %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %0: tensor<2xi32>
 }

@@ -38,7 +38,7 @@ func @tf_unknown_known_mix(%arg0: tensor<2xi32>) -> tensor<2xi32> {
   // expected-error@+1 {{'tf.OpA' op is not legalizable}}
   %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   %1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  %2 = "tf.Add"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %2 = "tf.AddV2"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   %3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %2: tensor<2xi32>
 }
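These test updates switch the inputs from tf.Add to tf.AddV2, tracking the TF dialect's treatment of tf.AddV2 as the canonical addition op: the lower_tf.td hunk further down drops the direct tf.Add mapping, and tf.Add is instead expected to be rewritten to tf.AddV2 beforehand. A hedged C++ sketch of such a rewrite pattern, assuming the usual x/y operand accessors on the generated TF ops (the real rewrite in this commit is a tablegen DRR pattern, not this code):

    #include "mlir/IR/PatternMatch.h"
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

    namespace mlir {

    // Sketch: rewrite tf.Add into the canonical tf.AddV2 so downstream
    // legalization only has to handle one addition op.
    struct LowerAddToAddV2 : public OpRewritePattern<TF::AddOp> {
      using OpRewritePattern<TF::AddOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(TF::AddOp op,
                                    PatternRewriter& rewriter) const override {
        rewriter.replaceOpWithNewOp<TF::AddV2Op>(op, op.getType(), op.x(),
                                                 op.y());
        return success();
      }
    };

    }  // namespace mlir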
@@ -3984,31 +3984,35 @@ func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) {
 // tf.Unpack legalization
 //===----------------------------------------------------------------------===//

-// TODO(b/156340000): Re-enable when fixed.
-// // C-HECK-LABEL: @unpack
-// func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
-// // C-HECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
-// // C-HECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32>
-// // C-HECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
-// // C-HECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
-// // C-HECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
-// // C-HECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
-//
-// %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
-// // return %[[RES1]], %[[RES2]], %[[RES3]]
-// return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>
-// }
-
-// // C-HECK-LABEL: @unpack_dynamic
-// func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
-// // C-HECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32>
-// // C-HECK: "mhlo.reshape"(%[[SLICE1]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32>
-// // C-HECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32>
-// // C-HECK: "mhlo.reshape"(%[[SLICE2]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32>
-//
-// %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
-// return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
-// }
+// CHECK-LABEL: @unpack
+func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
+  // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
+  // CHECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
+  // CHECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
+  // CHECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
+
+  %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
+  // return %[[RES1]], %[[RES2]], %[[RES3]]
+  return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>
+}
+
+// CHECK-LABEL: @unpack_dynamic
+func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  // CHECK: tf.Unpack
+  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @unpack_unranked
+func @unpack_unranked(%input: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  // CHECK: tf.Unpack
+  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}

 //===----------------------------------------------------------------------===//
 // tf.UnsortedSegment{Max|Min|Prod|Sum} legalization
@@ -3311,7 +3311,7 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
      auto input_val = GetScalarConstOfType(begin_element_ty, loc,
                                            input_shape[d], &rewriter);
      auto wrapped_index =
-         rewriter.create<TF::AddOp>(loc, input_val, reshaped_index);
+         rewriter.create<TF::AddV2Op>(loc, input_val, reshaped_index);
      auto final_index = rewriter.create<SelectOp>(
          loc, type, index_negative, wrapped_index, reshaped_index);
      slice_begin_indices.push_back(final_index);
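As context for the AddV2 change above: the pattern wraps a negative begin index by adding the dimension size, then selects between the wrapped and the original index. A minimal standalone sketch of that rule, with hypothetical names that are not part of the patch:

#include <cstdint>

// Mirror of the select(index < 0, index + dim_size, index) logic the
// pattern emits for dynamic begin indices of tf.StridedSlice.
int64_t WrapSliceIndex(int64_t index, int64_t dim_size) {
  const int64_t wrapped = index + dim_size;
  return index < 0 ? wrapped : index;
}

For example, WrapSliceIndex(-1, 6) yields 5, the last valid index of a dimension of size 6.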
@@ -4808,7 +4808,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {

  LogicalResult matchAndRewrite(TF::UnpackOp op,
                                PatternRewriter &rewriter) const override {
-   auto value_type = op.value().getType().cast<RankedTensorType>();
+   auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
    if (!value_type) return failure();

    int64_t value_rank = value_type.getRank();
@@ -4820,7 +4820,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
    auto end_indices = llvm::to_vector<4>(value_type.getShape());
    SmallVector<int64_t, 4> strides(value_rank, 1);

-   // All HLO slice+reshape results used to replace the original tf.Unpack op.
+   // All HLO slice+squeeze results used to replace the original tf.Unpack op.
    SmallVector<Value, 4> results;
    results.reserve(op.getNumResults());

@@ -4833,9 +4833,10 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
          GetI64ElementsAttr(end_indices, &rewriter),
          GetI64ElementsAttr(strides, &rewriter));
-     // Reshape to drop the axis dimension.
-     auto reshape_op = rewriter.create<mhlo::ReshapeOp>(
-         op.getLoc(), op.getType(i), slice_op);
-     results.push_back(reshape_op);
+     // Squeeze to drop the axis dimension.
+     auto result =
+         rewriter.create<TF::SqueezeOp>(op.getLoc(), op.getType(i), slice_op,
+                                        rewriter.getI64ArrayAttr(op.axis()));
+     results.push_back(result);
    }

    rewriter.replaceOp(op, results);
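The rewrite above emits one mhlo.slice per output along the unpack axis and now drops the unit axis dimension with tf.Squeeze rather than mhlo.reshape. A standalone sketch of the shape bookkeeping, with hypothetical helper names:

#include <cstdint>
#include <vector>

// Slice bounds and squeezed result shape for output i of an unpack along
// `axis`, mirroring ConvertUnpackOp's begin/end/strides bookkeeping.
struct UnpackSlice {
  std::vector<int64_t> begin, end, result_shape;
};

UnpackSlice MakeUnpackSlice(const std::vector<int64_t>& shape, int axis,
                            int i) {
  UnpackSlice s;
  s.begin.assign(shape.size(), 0);
  s.end = shape;
  s.begin[axis] = i;    // take the i-th element along the axis...
  s.end[axis] = i + 1;  // ...as a slice of extent 1
  for (int d = 0; d < static_cast<int>(shape.size()); ++d) {
    if (d != axis) s.result_shape.push_back(shape[d]);  // squeeze the axis
  }
  return s;
}

For a 4x3x6 input unpacked along axis 1, output 2 is the slice [0, 2, 0]..[4, 3, 6], squeezed to 4x6.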
@@ -89,8 +89,7 @@ class DirectBinaryPat<Op FromOp, Op ToOp>
  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
        (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;

-foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp],
-                         [TF_AddV2Op, HLOClient_BroadcastAddOp],
+foreach fromToBinPair = [[TF_AddV2Op, HLOClient_BroadcastAddOp],
                         [TF_DivOp, HLOClient_BroadcastDivOp],
                         [TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp],
                         [TF_MaximumOp, HLOClient_BroadcastMaxOp],
@@ -225,7 +224,7 @@ class EqualityPat<Op FromOp, StrEnumAttrCase direction>
        (HLOClient_BroadcastCompareOp
          $l, $r, (BinBroadcastDimensions $l, $r), direction,
          (HLO_DEFAULT_COMPARISON_TYPE)),
-       [(AreBroadcastCompatible $l, $r)]>;
+       [(AreBroadcastCompatible $l, $r), (HLO_Tensor $l)]>;

def : EqualityPat<TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ>;
def : EqualityPat<TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE>;
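Both pattern families guard on operand compatibility, and the equality patterns now additionally require the left operand to be an HLO-compatible tensor. As a reminder of the numpy-style rule that a broadcast-compatibility predicate encodes, a minimal sketch in plain C++ (this is not the TableGen predicate itself, only the shape rule, and it ignores dynamic dimensions):

#include <cstdint>
#include <vector>

// Two shapes are broadcast compatible if, aligned from the trailing
// dimension, every pair of extents is equal or one of the two is 1.
bool BroadcastCompatible(const std::vector<int64_t>& a,
                         const std::vector<int64_t>& b) {
  auto ia = a.rbegin();
  auto ib = b.rbegin();
  for (; ia != a.rend() && ib != b.rend(); ++ia, ++ib) {
    if (*ia != *ib && *ia != 1 && *ib != 1) return false;
  }
  return true;  // the shorter shape broadcasts across leading dims
}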
@@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"

+#include <climits>
#include <memory>
#include <tuple>

@@ -38,6 +39,7 @@ limitations under the License.
#include "mlir/Pass/PassOptions.h"  // from @llvm-project
#include "mlir/Translation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
@@ -46,18 +48,21 @@ limitations under the License.
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"

using xla::BufferAllocation;
using xla::BufferAssignment;
using xla::HloComputation;
+using xla::HloCustomCallInstruction;
using xla::HloInstruction;
using xla::HloModule;
using xla::HloModuleProto;
@@ -140,8 +145,9 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
class XlaHloToLhloPass
    : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
  void getDependentDialects(DialectRegistry& registry) const override {
-   registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
-                   mlir::lmhlo::LmhloDialect>();
+   registry
+       .insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
+               mlir::lmhlo::LmhloDialect, mlir::lmhlo_gpu::LmhloGpuDialect>();
  }

 public:
@@ -274,6 +280,10 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
      return EmitSelectAndScatterOp(instr);
    case HloOpcode::kCustomCall:
      return EmitCustomCallOp(instr);
+   case HloOpcode::kConstant:
+     return EmitConstant(instr);
+   case HloOpcode::kReduce:
+     return EmitReduceOp(instr);
    default:
      llvm::errs() << instr->ToString();
      return tensorflow::errors::Internal(
@@ -485,13 +495,18 @@ StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
  return select_and_scatter;
}

-StatusOr<lmhlo::CustomCallOp> LhloDialectEmitter::EmitCustomCallOp(
+StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
    HloInstruction* instr) {
+ auto* custom_call_instr = ::xla::Cast<::xla::HloCustomCallInstruction>(instr);
+
+ if (xla::gpu::IsCustomCallToCusolver(*instr)) {
+   return EmitCholesky(custom_call_instr);
+ }
+
  size_t num_arguments, num_results;
  TF_ASSIGN_OR_RETURN(auto custom_call,
                      CreateOpWithoutAttrs<lmhlo::CustomCallOp>(
                          instr, num_arguments, num_results));
- auto* custom_call_instr = ::xla::Cast<::xla::HloCustomCallInstruction>(instr);
  custom_call.call_target_nameAttr(
      builder_.getStringAttr(custom_call_instr->custom_call_target()));
  custom_call.backend_configAttr(
@@ -500,12 +515,102 @@ StatusOr<lmhlo::CustomCallOp> LhloDialectEmitter::EmitCustomCallOp(
      static_cast<int32_t>(num_results)};
  custom_call.setAttr(lmhlo::CustomCallOp::getOperandSegmentSizeAttr(),
                      builder_.getI32VectorAttr(segments));
- return custom_call;
+ return custom_call.getOperation();
+}
+
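EmitCustomCallOp now peels off cuSOLVER custom calls and routes them to the dedicated Cholesky emitter before falling back to a generic lmhlo.custom_call. A minimal sketch of that dispatch shape in standalone C++ (the target string and types are stand-ins, not the real XLA symbols):

#include <string>

enum class Emitted { kCholesky, kGenericCustomCall };

// Route a custom call by its target name, as the emitter does for
// cuSOLVER Cholesky calls; everything else stays a generic custom call.
Emitted EmitCustomCall(const std::string& call_target) {
  if (call_target == "__cusolver$cholesky") {  // hypothetical target name
    return Emitted::kCholesky;
  }
  return Emitted::kGenericCustomCall;
}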
+StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
+   HloCustomCallInstruction* custom_call) {
+ TF_ASSIGN_OR_RETURN(auto cholesky_op,
+                     CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
+ TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
+                     custom_call->backend_config<xla::CholeskyOptions>());
+ cholesky_op.is_lowerAttr(builder_.getBoolAttr(options.lower()));
+ return cholesky_op;
+}
+
+// Convert an XLA HLO constant to a global_memref + get_global_memref pair.
+StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant(
+   const HloInstruction* instr) {
+ // Insert a global_memref in the module.
+ Location loc = getLocation(instr);
+
+ auto const_instr = ::xla::Cast<::xla::HloConstantInstruction>(instr);
+ TF_RET_CHECK(const_instr->shape().IsArray() &&
+              const_instr->shape().is_static());
+ TF_ASSIGN_OR_RETURN(Type type, ::xla::ConvertShapeToType<MemRefType>(
+                                    const_instr->shape(), builder_));
+ auto memref_type = type.dyn_cast<MemRefType>();
+ TF_RET_CHECK(memref_type != nullptr);
+
+ TF_ASSIGN_OR_RETURN(
+     DenseElementsAttr initial_value,
+     CreateDenseElementsAttrFromLiteral(const_instr->literal(), builder_));
+
+ std::string constant_name = ::xla::llvm_ir::ConstantHloToGlobalName(*instr);
+
+ // Insert the global memref at the top level.
+ {
+   OpBuilder::InsertionGuard guard(builder_);
+   builder_.clearInsertionPoint();
+   auto global_var = builder_.create<GlobalMemrefOp>(
+       loc, constant_name, builder_.getStringAttr("private"),
+       TypeAttr::get(memref_type), initial_value, true);
+   SymbolTable(module_).insert(global_var);
+   global_var.getOperation()->moveBefore(&module_.front());
+ }
+
+ auto get_global_memref =
+     builder_.create<GetGlobalMemrefOp>(loc, memref_type, constant_name);
+
+ // For operations that do not fold this constant value in their codegen, we
+ // still need to materialize it into a buffer. Since buffer allocation is
+ // already done, annotate the get_global_memref with the information to get
+ // to the allocated buffer slice for this constant if need be.
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
+                     assignment_.GetUniqueTopLevelSlice(instr));
+ get_global_memref.setAttr("lmhlo.alloc",
+                           builder_.getIndexAttr(slice.index()));
+ get_global_memref.setAttr("lmhlo.slice_offset",
+                           builder_.getI64IntegerAttr(slice.offset()));
+ get_global_memref.setAttr("lmhlo.slice_size",
+                           builder_.getI64IntegerAttr(slice.size()));
+
+ // Update the cache to remember this value.
+ auto& cached_value = slices_[std::make_pair(instr, ::xla::ShapeIndex())];
+ TF_RET_CHECK(cached_value == nullptr);
+ cached_value = get_global_memref;
+ return get_global_memref;
+}
+
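EmitConstant memoizes its result: the emitted get_global_memref is stored in the same (instruction, shape index) cache that array views use, so later uses of the constant reuse one handle instead of emitting a second global. A standalone sketch of that memoization shape (the key and value types are stand-ins for the real MLIR values):

#include <map>
#include <string>
#include <utility>

using InstrKey = std::pair<const void*, int>;  // (instruction, shape index)

// Emit-once semantics: the first call stores the handle, later calls
// return the cached one, mirroring slices_ in LhloDialectEmitter.
std::string GetOrEmit(std::map<InstrKey, std::string>& cache,
                      const void* instr, const std::string& handle) {
  auto& slot = cache[{instr, 0}];
  if (slot.empty()) slot = handle;  // first emission populates the cache
  return slot;
}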
+StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp(
+   HloInstruction* instr) {
+ TF_ASSIGN_OR_RETURN(auto reduce_op,
+                     CreateOpWithoutAttrs<lmhlo::ReduceOp>(instr));
+ auto* reduce = ::xla::Cast<::xla::HloReduceInstruction>(instr);
+ std::vector<int64_t> dimensions(reduce->dimensions().begin(),
+                                 reduce->dimensions().end());
+ reduce_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions));
+ TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
+     *instr->called_computations()[0], &reduce_op.body(), &builder_));
+ return reduce_op;
}

StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
    const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
    const ::xla::ShapeIndex& shape_index) {
+ // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
+ // shape_index).
+ auto& cached_value = slices_[std::make_pair(instr, shape_index)];
+ if (cached_value) {
+   return cached_value;
+ }
+
+ if (instr->IsConstant() && shape_index.empty()) {
+   TF_ASSIGN_OR_RETURN(Value constant_memref, EmitConstant(instr));
+   return cached_value = constant_memref;
+ }
+
  // If the shape happens to have dynamic dimensions, create the memref using
  // the underlying static shape.
  // TODO(jurahul): Revisit this when we can model memrefs with dynamic shape
@@ -518,7 +623,7 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
                      assignment_.GetUniqueSlice(instr, shape_index));
  Value alloc = allocations_[slice.allocation()];
  if (alloc.getType() == out_type && slice.offset() == 0) {
-   return alloc;
+   return cached_value = alloc;
  }

  auto out_memref_type = out_type.dyn_cast<MemRefType>();
@@ -527,13 +632,6 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
        "Expected memref type when creating a view for leaf type of a "
        "tuple.");

- // Cache generated ViewOp and StaticMemRefCastOp by (instruction,
- // shape_index).
- auto& cached_value = slices_[std::make_pair(instr, shape_index)];
- if (cached_value) {
-   return cached_value;
- }
-
  Value byte_shift =
      builder_.create<ConstantIndexOp>(alloc.getLoc(), slice.offset());

@@ -695,8 +793,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
Status HloToLhloModule(const BufferAssignment& assignment,
                       const HloModule& hlo_module, ModuleOp module) {
  module.getContext()
-     ->loadDialect<StandardOpsDialect, mhlo::MhloDialect,
-                   lmhlo::LmhloDialect>();
+     ->loadDialect<StandardOpsDialect, mhlo::MhloDialect, lmhlo::LmhloDialect,
+                   lmhlo_gpu::LmhloGpuDialect>();
  HloComputation* computation = hlo_module.entry_computation();

  LhloDialectEmitter emitter(assignment, *computation, module);
@@ -16,13 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_

+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace mlir {
@@ -55,8 +57,18 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
  ::xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(::xla::HloInstruction* instr);
  ::xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
      ::xla::HloInstruction* instr);
- ::xla::StatusOr<lmhlo::CustomCallOp> EmitCustomCallOp(
-     ::xla::HloInstruction* instr);
+ ::xla::StatusOr<Operation*> EmitCustomCallOp(::xla::HloInstruction* instr);
+ ::xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
+     ::xla::HloCustomCallInstruction* custom_call);
+
+ ::xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(::xla::HloInstruction* instr);
+ ::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
+     ::xla::HloInstruction* instr) {
+   return EmitConstant(static_cast<const ::xla::HloInstruction*>(instr));
+ }
+ ::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
+     const ::xla::HloInstruction* instr);

  ::xla::Status CreateOperands(::xla::HloInstruction* instr,
                               SmallVectorImpl<Value>& operands,
@@ -122,7 +134,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
                              OpBuilder* b, Location loc);

  // Return an MLIR location for an HLO instruction.
- Location getLocation(::xla::HloInstruction* inst) {
+ Location getLocation(const ::xla::HloInstruction* inst) {
    return NameLoc::get(builder_.getIdentifier(inst->name()),
                        builder_.getContext());
  }
@@ -180,7 +192,7 @@ tensorflow::Status HloToLhloModule(const ::xla::BufferAssignment& assignment,
                                   ModuleOp module);

OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
-                                              mlir::MLIRContext* context);
+                                              MLIRContext* context);

}  // namespace mlir

@@ -85,6 +85,9 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {

+using absl::StrAppend;
+using absl::StrCat;
+
bool IsEngineInput(absl::string_view name) {
  return absl::StartsWith(name, IONamePrefixes::kInputPHName);
}
@@ -92,47 +95,6 @@ bool IsEngineOutput(absl::string_view name) {
  return absl::StartsWith(name, IONamePrefixes::kOutputPHName);
}

-using absl::StrAppend;
-using absl::StrCat;
-
-inline Status TfDataTypeToTrt(DataType tf_dtype,
-                              nvinfer1::DataType* trt_dtype) {
-  switch (tf_dtype) {
-    case DataType::DT_FLOAT:
-      *trt_dtype = nvinfer1::DataType::kFLOAT;
-      break;
-    case DataType::DT_HALF:
-      *trt_dtype = nvinfer1::DataType::kHALF;
-      break;
-    case DataType::DT_INT32:
-      *trt_dtype = nvinfer1::DataType::kINT32;
-      break;
-    default:
-      return errors::InvalidArgument("Unsupported data type ",
-                                     DataTypeString(tf_dtype));
-  }
-  return Status::OK();
-}
-
-inline Status TrtDataTypeToTf(nvinfer1::DataType trt_dtype,
-                              DataType* tf_dtype) {
-  switch (trt_dtype) {
-    case nvinfer1::DataType::kFLOAT:
-      *tf_dtype = DataType::DT_FLOAT;
-      break;
-    case nvinfer1::DataType::kHALF:
-      *tf_dtype = DataType::DT_HALF;
-      break;
-    case nvinfer1::DataType::kINT32:
-      *tf_dtype = DataType::DT_INT32;
-      break;
-    default:
-      return errors::InvalidArgument("Unsupported data type ",
-                                     DebugString(trt_dtype));
-  }
-  return Status::OK();
-}
-
class TFAttrs {
 public:
  explicit TFAttrs(const NodeDef& tf_node) {
@@ -182,7 +144,7 @@ std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
template <>
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
  nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
- TF_CHECK_OK(TfDataTypeToTrt(this->at(key)->type(), &trt_dtype));
+ TF_CHECK_OK(TfTypeToTrtType(this->at(key)->type(), &trt_dtype));
  return trt_dtype;
}

@@ -271,7 +233,7 @@ Status ValidateTensorProperties(const string& producer_node_type,
                                nvinfer1::DataType* trt_dtype,
                                nvinfer1::Dims* trt_dims, int* batch_size) {
  // Convert data type.
- TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, trt_dtype));
+ TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, trt_dtype));

  // Convert shape.
  if (shape.dims() < 0) {
@@ -512,7 +474,7 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
  TFAttrs attrs(params->node_def);
  if (attrs.count(dtype_attr_name)) {
    DataType dtype = attrs.get<DataType>(dtype_attr_name);
-   TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_type));
+   TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, &trt_type));
  }

  // In order to be broadcastable, the number of dims has to match.
@@ -1091,7 +1053,7 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype,
  DataType tf_dtype;
  // TODO(laigd): make it return a status.
  TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape));
- TF_CHECK_OK(TrtDataTypeToTf(trt_dtype, &tf_dtype));
+ TF_CHECK_OK(TrtTypeToTfType(trt_dtype, &tf_dtype));
  // TODO(jie): check weights size_bytes. 0 means type error
  Tensor tensor(tf_dtype, shape);
  TRT_ShapedWeights weights(trt_dtype, dims, tensor);
@@ -2621,6 +2583,7 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input,
  nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation(
      const_cast<nvinfer1::ITensor* const*>(concat_inputs.data()),
      concat_inputs.size());
+ SetLayerName(concat_layer, params->node_def, "concat", op_instance);
  concat_layer->setAxis(0);
  nvinfer1::ITensor* new_shape = concat_layer->getOutput(0);
  // Reshape input using new shape
@@ -4219,7 +4182,7 @@ Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,

  // Verify that the dtype is supported by TensorRT. Otherwise, return an error.
  nvinfer1::DataType trt_dtype;
- TF_RETURN_IF_ERROR(TfDataTypeToTrt(converted_dtype, &trt_dtype));
+ TF_RETURN_IF_ERROR(TfTypeToTrtType(converted_dtype, &trt_dtype));

  if (tensor.NumElements() == 0) {
    // Return empty weights.
@@ -6244,7 +6207,7 @@ Status ConvertGraphDefToEngine(
  TFAttrs attrs(node_def);
  DataType tf_dtype = attrs.get<DataType>("T");
  nvinfer1::DataType trt_dtype;
- TF_RETURN_IF_ERROR(TfDataTypeToTrt(tf_dtype, &trt_dtype));
+ TF_RETURN_IF_ERROR(TfTypeToTrtType(tf_dtype, &trt_dtype));
  if (output_tensors.size() <= slot_number) {
    output_tensors.resize(slot_number + 1);
  }
@@ -135,20 +135,6 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
  return os;
}

-nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) {
-  nvinfer1::DataType trt_type;
-  Status status = TfTypeToTrtType(tf_type, &trt_type);
-  EXPECT_EQ(status, Status::OK());
-  return trt_type;
-}
-
-DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) {
-  DataType tf_type;
-  Status status = TrtTypeToTfType(trt_type, &tf_type);
-  EXPECT_EQ(status, Status::OK());
-  return tf_type;
-}
-
NodeDef MakeNodeDef(const string& name, const string& op,
                    const std::vector<string>& inputs,
                    const std::map<string, AttrValue> attrs = {}) {
@@ -1048,8 +1034,10 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) {

template <typename T>
void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) {
- TRT_ShapedWeights weights = weight_store->GetTempWeights(
-     TfDataTypeToTrt(DataTypeToEnum<T>::v()), GetTestDims({2, 3}));
+ nvinfer1::DataType trt_type;
+ TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &trt_type));
+ TRT_ShapedWeights weights =
+     weight_store->GetTempWeights(trt_type, GetTestDims({2, 3}));
  const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)};
  memcpy(weights.GetValues(), values.data(), weights.size_bytes());

@@ -1445,7 +1433,8 @@ class OpConverterTest : public ::testing::Test {
    ASSERT_NE(-1, input_index);
    const nvinfer1::DataType trt_dtype =
        engine_->getBindingDataType(input_index);
-   const DataType tf_type = TrtDataTypeToTf(trt_dtype);
+   DataType tf_type;
+   TF_ASSERT_OK(TrtTypeToTfType(trt_dtype, &tf_type));
    ASSERT_EQ(data.tensor.dtype(), tf_type)
        << DataTypeString(data.tensor.dtype()) << " vs. "
        << DataTypeString(tf_type);
@@ -1457,8 +1446,9 @@ class OpConverterTest : public ::testing::Test {
    // Mark the output tensor as TRT engine output.
    std::vector<Converter::EngineOutputInfo> output_info;
    for (const auto& data : *output_data) {
-     output_info.push_back(
-         {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())});
+     nvinfer1::DataType trt_type;
+     TF_RETURN_IF_ERROR(TfTypeToTrtType(data.tensor.dtype(), &trt_type));
+     output_info.push_back({data.name, data.name, trt_type});
    }
    TF_RETURN_IF_ERROR(converter_->RenameAndMarkOutputTensors(output_info));

@@ -1519,7 +1509,8 @@ class OpConverterTest : public ::testing::Test {
      const string& name, const std::vector<int32>& dims,
      nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT,
      Status add_input_status = Status::OK()) {
-   DataType tf_type = TrtDataTypeToTf(trt_type);
+   DataType tf_type;
+   TF_ASSERT_OK(TrtTypeToTfType(trt_type, &tf_type));
    ops::Placeholder::Attrs attrs;
    TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_));

@@ -1566,7 +1557,8 @@ class OpConverterTest : public ::testing::Test {
    node_inputs_[name] = ops::Const(scope_.WithOpName(name), t);

    // Add weights for conversion.
-   const nvinfer1::DataType dtype = TfDataTypeToTrt(DataTypeToEnum<T>::v());
+   nvinfer1::DataType dtype;
+   TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &dtype));
    const nvinfer1::Dims trt_dims = GetTestDims(dims);
    const int64_t num_elements = TrtWeightDimsNumElements(trt_dims);
    QCHECK_EQ(num_elements, values.size())
@@ -1800,8 +1792,9 @@ class ParameterizedOpConverterTestBase
        partial_shape = dims;
      }
    }
-   AddTestTensorWithTFDims(name, partial_shape, TfDataTypeToTrt(tf_type),
-                           add_input_status);
+   nvinfer1::DataType trt_type;
+   TF_ASSERT_OK(TfTypeToTrtType(tf_type, &trt_type));
+   AddTestTensorWithTFDims(name, partial_shape, trt_type, add_input_status);
    if (!values.empty()) {
      VLOG(2) << "Adding test tensor: " << name << " "
              << DataTypeString(tf_type);
@@ -2032,7 +2025,7 @@ TEST_F(OpConverterTest, ConvertConst) {
    Reset();
    NodeDef node_def = MakeConstNodeDef<double>("my_const", {});
    RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
-                              "Unsupported data type double");
+                              "Unsupported tensorflow data type double");
  }
  {
    Reset();
@@ -2805,8 +2798,9 @@ void TestAddN(OpConverterTest* test) {
  test->Reset();
  DataVec input_data;
  for (const auto name : {"inp1", "inp2", "inp3"}) {
-   test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/2,
-                       TfDataTypeToTrt(dtype));
+   nvinfer1::DataType trt_type;
+   TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+   test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/2, trt_type);
    input_data.push_back({name, test->AsTensor<CType>({CType(1), CType(2),
                                                       CType(3), CType(4)})});
  }
@@ -2828,8 +2822,9 @@ void TestAddN(OpConverterTest* test) {
  test->Reset();
  DataVec input_data;
  for (const auto name : {"inp1", "inp2"}) {
-   test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/1,
-                       TfDataTypeToTrt(dtype));
+   nvinfer1::DataType trt_type;
+   TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+   test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/1, trt_type);
    input_data.push_back({name, test->AsTensor<CType>({CType(1), CType(2)})});
  }
  test->AddTestWeights("inp3", /*dims=*/{1, 1, 2},
@@ -4252,8 +4247,9 @@ TEST_P(OpConverterTest1, ConvertConv2D) {
  Reset();
  NodeDef node_def = get_conv2d_nodedef();
  // Channel dim unknown, should fail.
- AddTestTensorWithTFDims("input", {-1, -1, -1, -1},
-                         TfDataTypeToTrt(tf_type_));
+ nvinfer1::DataType trt_type;
+ TF_ASSERT_OK(TfTypeToTrtType(tf_type_, &trt_type));
+ AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, trt_type);
  AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1});
  RunValidationAndConversion(
      node_def, error::INVALID_ARGUMENT,
@@ -5018,8 +5014,9 @@ TEST_F(OpConverterTest, ConvertTopK) {
  {
    // K is a tensor, should fail.
    Reset();
-   AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1,
-                 /*trt_dtype=*/TfDataTypeToTrt(dtype));
+   nvinfer1::DataType trt_type;
+   TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+   AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1, trt_type);
    AddTestTensor("weights", {2});
    RunValidationAndConversion(
        node_def, error::UNIMPLEMENTED,
@@ -5590,8 +5587,10 @@ void TestConvertConcat(OpConverterTest* test) {
  NodeDef node_def = get_concat_nodedef(dtype, num_inputs);
  // Create inputs.
  for (int j = 0; j < num_inputs; ++j) {
+   nvinfer1::DataType trt_type;
+   TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
    test->AddTestTensor(StrCat("values_", j), ok_params[i].input_shapes[j], 1,
-                       TfDataTypeToTrt(dtype));
+                       trt_type);
  }
  test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
  test->RunValidationAndConversion(node_def);
@@ -5752,8 +5751,9 @@ void TestConvertSplit(OpConverterTest* test) {
  NodeDef node_def = get_split_nodedef(dtype, ok_params[i].num_split);
  // Create inputs.
  test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
- test->AddTestTensor("value", ok_params[i].input_shape, 1,
-                     TfDataTypeToTrt(dtype));
+ nvinfer1::DataType trt_type;
+ TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+ test->AddTestTensor("value", ok_params[i].input_shape, 1, trt_type);
  // Convert.
  test->RunValidationAndConversion(node_def);

@@ -5929,8 +5929,9 @@ void TestConvertUnpack(OpConverterTest* test) {
  NodeDef node_def =
      get_unpack_nodedef(dtype, ok_params[i].num, ok_params[i].axis);
  // Create inputs.
- test->AddTestTensor("value", ok_params[i].input_shape, 1,
-                     TfDataTypeToTrt(dtype));
+ nvinfer1::DataType trt_type;
+ TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+ test->AddTestTensor("value", ok_params[i].input_shape, 1, trt_type);
  // Convert.
  test->RunValidationAndConversion(node_def);

@@ -6272,8 +6273,10 @@ void TestConvertArgMinMax(OpConverterTest* test) {

  NodeDef node_def = GetArgMinMaxNodeDef<OpType>(dtype, DT_INT32);
  // Create inputs.
+ nvinfer1::DataType trt_type;
+ TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
  test->AddTestTensor("input", params[i].input_shape, /*batch_size=*/1,
-                     /*trt_dtype=*/TfDataTypeToTrt(dtype));
+                     /*trt_dtype=*/trt_type);
  test->AddTestWeights<int32>("dimension", {1}, {params[i].axis});
  test->RunValidationAndConversion(node_def);

@@ -6374,8 +6377,9 @@ void TestConvertDepthSpaceShuffle(

  NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>(
      dtype, params[i].block_size, params[i].data_format);
- test->AddTestTensor("input", params[i].input_dims, 1,
-                     TfDataTypeToTrt(dtype));
+ nvinfer1::DataType trt_type;
+ TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+ test->AddTestTensor("input", params[i].input_dims, 1, trt_type);
  test->RunValidationAndConversion(node_def);

  TRT_TensorOrWeights output;
@@ -6648,7 +6652,9 @@ void TestConvertClipByValue(OpConverterTest* test) {
  test->Reset();

  NodeDef node_def = GetClipByValueNodeDef(dtype);
- test->AddTestTensor("t", params[i].dims, 1, TfDataTypeToTrt(dtype));
+ nvinfer1::DataType trt_type;
+ TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
+ test->AddTestTensor("t", params[i].dims, 1, trt_type);
  test->AddTestWeights<CType>("clip_value_min", {1},
                              {params[i].clip_value_min});
  test->AddTestWeights<CType>("clip_value_max", {1},
@@ -6848,9 +6854,11 @@ void TestConvertResize(OpConverterTest* test) {
  // Create resize node.
  NodeDef node_def =
      MakeResizeNodeDef<OpType>("my_resize", dtype, params[i].align_corners);
+ nvinfer1::DataType trt_type;
+ TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
  // Create input tensor
  test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1,
-                     /*trt_dtype=*/TfDataTypeToTrt(dtype));
+                     /*trt_dtype=*/trt_type);
  // Create output size.
  test->AddTestWeights<int32>("size", {2}, params[i].output_resize_dims);

@@ -6949,8 +6957,10 @@ void TestConvertPad(OpConverterTest* test) {
  // Create pad node.
  NodeDef node_def = MakePadNodeDef("my_pad", dtype);
  // Create input tensor
+ nvinfer1::DataType trt_type;
+ TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
  test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1,
-                     /*trt_dtype=*/TfDataTypeToTrt(dtype));
+                     /*trt_dtype=*/trt_type);
  // Create output size.
  test->AddTestWeights<int32>("padding", params[i].pad_dims,
                              {0, 0, 1, 0, 0, 1, 0, 0});
@@ -198,7 +198,8 @@ Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
      *trt_type = nvinfer1::DataType::kINT32;
      break;
    default:
-     return errors::Internal("Unsupported tensorflow type");
+     return errors::InvalidArgument("Unsupported tensorflow data type ",
+                                    DataTypeString(tf_type));
  }
  return Status::OK();
}
@@ -215,7 +216,7 @@ Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
      *tf_type = DT_INT32;
      break;
    default:
-     return errors::Internal("Invalid TRT type");
+     return errors::InvalidArgument("Invalid TRT data type");
  }
  return Status::OK();
}
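TfTypeToTrtType and TrtTypeToTfType now return InvalidArgument with the offending type spelled out instead of a generic Internal error. A standalone sketch of this Status-returning mapping shape (the enums and Status here are stand-ins, not the TF or TensorRT types):

#include <string>

enum class TfType { kFloat, kHalf, kInt32, kDouble };
enum class TrtType { kFLOAT, kHALF, kINT32 };

struct Status {
  bool ok = true;
  std::string message;
};

// Map a framework dtype to the engine dtype, rejecting anything the
// engine cannot represent with a descriptive InvalidArgument-style error.
Status TfToTrt(TfType tf, TrtType* trt) {
  switch (tf) {
    case TfType::kFloat: *trt = TrtType::kFLOAT; return {};
    case TfType::kHalf:  *trt = TrtType::kHALF;  return {};
    case TfType::kInt32: *trt = TrtType::kINT32; return {};
    default: return {false, "Unsupported tensorflow data type"};
  }
}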
@@ -103,13 +103,16 @@ class TRTEngineOp : public AsyncOpKernel {
                             TRTEngineCacheResource* cache_res,
                             AsyncHelper* helper);

- // Construct a function handle for executing native funcdef graph
- // These are the exact same function.
- Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                const string& device_name,
-                                bool allow_soft_placement = false,
-                                size_t num_inputs = 0, size_t num_outputs = 0);
+ // Constructs a function handle for the segment of the TRTEngineOp.
+ StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle(
+     FunctionLibraryRuntime* lib, const string& device_name,
+     bool allow_soft_placement = false, size_t num_inputs = 0,
+     size_t num_outputs = 0);
+
+ // Imports the GraphDef for the segment of the TRTEngineOp to
+ // segment_graph_def_.
+ Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
+                              const string& device_name);

  // Executes replaced native segment as function Op.
  void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
@@ -175,12 +178,16 @@ class TRTEngineOp : public AsyncOpKernel {
  // Whether to build TensorRT engines at runtime.
  bool allow_build_at_runtime_;

+ // Whether to allow soft placement when the graph is executed with native
+ // TensorFlow.
+ bool allow_soft_placement_;
+
  // Maximum number of cached engines.
  int max_cached_engines_;

  int64 workspace_size_;
  mutex engine_mutex_;
- FunctionLibraryRuntime::Handle func_handle_;
+ FunctionLibraryRuntime::Handle native_execution_func_handle_;

  // The finalized calibrator for inference.
  std::unique_ptr<TRTInt8Calibrator> calibrator_;
@@ -260,11 +267,9 @@ static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
  return Status::OK();
}

-Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
-                                            const string& device_name,
-                                            bool allow_soft_placement,
-                                            size_t num_inputs,
-                                            size_t num_outputs) {
+StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle(
+   FunctionLibraryRuntime* lib, const string& device_name,
+   bool allow_soft_placement, size_t num_inputs, size_t num_outputs) {
  VLOG(1) << "Constructing function handle";
  if (lib == nullptr) {
    return errors::Internal("Context function library is null");
@@ -298,8 +303,20 @@ Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
      inst_ops.config_proto.set_allow_soft_placement(true);
    }
  }
- return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
-                         &func_handle_);
+ FunctionLibraryRuntime::Handle func_handle;
+ Status status = lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
+                                  inst_ops, &func_handle);
+ if (status.ok()) {
+   return func_handle;
+ }
+ return status;
+}
+
+Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
+                                          const string& device_name) {
+ TF_ASSIGN_OR_RETURN(FunctionLibraryRuntime::Handle func_handle,
+                     ConstructFunctionHandle(lib, device_name));
+ return FunctionDefToGraphDef(func_handle, lib, &segment_graph_def_);
}

TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
|
|||||||
<< context->device()->name()
|
<< context->device()->name()
|
||||||
<< ", thus setting _allow_build_at_runtime=true";
|
<< ", thus setting _allow_build_at_runtime=true";
|
||||||
allow_build_at_runtime_ = true;
|
allow_build_at_runtime_ = true;
|
||||||
|
} else {
|
||||||
|
OP_REQUIRES_OK(context, status);
|
||||||
}
|
}
|
||||||
func_handle_ = kInvalidHandle;
|
|
||||||
|
status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_);
|
||||||
|
if (status.code() == tensorflow::error::NOT_FOUND) {
|
||||||
|
allow_soft_placement_ = true;
|
||||||
|
} else {
|
||||||
|
OP_REQUIRES_OK(context, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
native_execution_func_handle_ = kInvalidHandle;
|
||||||
if (!static_engine_) {
|
if (!static_engine_) {
|
||||||
FunctionLibraryRuntime* lib = context->function_library();
|
OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(),
|
||||||
OP_REQUIRES_OK(context,
|
context->device()->name()));
|
||||||
ConstructFunctionHandle(lib, context->device()->name()));
|
|
||||||
OP_REQUIRES_OK(
|
|
||||||
context, FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_));
|
|
||||||
}
|
}
|
||||||
// TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
|
// TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
|
||||||
// backward compatibility reasons. Remove it once all known users switch to
|
// backward compatibility reasons. Remove it once all known users switch to
|
||||||
@ -411,13 +435,13 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
|
|||||||
AsyncHelper* helper) {
|
AsyncHelper* helper) {
|
||||||
std::vector<Tensor> inputs;
|
std::vector<Tensor> inputs;
|
||||||
std::vector<Tensor>* outputs = new std::vector<Tensor>();
|
std::vector<Tensor>* outputs = new std::vector<Tensor>();
|
||||||
if (func_handle_ == kInvalidHandle) {
|
if (native_execution_func_handle_ == kInvalidHandle) {
|
||||||
OP_REQUIRES_OK_ASYNC(
|
StatusOr<FunctionLibraryRuntime::Handle> status_or_handle =
|
||||||
ctx,
|
|
||||||
ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
|
ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
|
||||||
/*allow_soft_placement=*/true,
|
allow_soft_placement_, ctx->num_inputs(),
|
||||||
ctx->num_inputs(), ctx->num_outputs()),
|
ctx->num_outputs());
|
||||||
*helper);
|
OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), *helper);
|
||||||
|
native_execution_func_handle_ = status_or_handle.ValueOrDie();
|
||||||
}
|
}
|
||||||
auto lib = ctx->function_library();
|
auto lib = ctx->function_library();
|
||||||
FunctionLibraryRuntime::Options opts;
|
FunctionLibraryRuntime::Options opts;
|
||||||
@ -430,7 +454,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
|
|||||||
}
|
}
|
||||||
helper->Ref(); // Increment count for calculating native graph
|
helper->Ref(); // Increment count for calculating native graph
|
||||||
VLOG(1) << "Executing native segment: " << name();
|
VLOG(1) << "Executing native segment: " << name();
|
||||||
lib->Run(opts, func_handle_, inputs, outputs,
|
lib->Run(opts, native_execution_func_handle_, inputs, outputs,
|
||||||
[this, ctx, outputs, helper](const Status& s) {
|
[this, ctx, outputs, helper](const Status& s) {
|
||||||
core::ScopedUnref sc(helper);
|
core::ScopedUnref sc(helper);
|
||||||
OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
|
OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
|
||||||
@ -854,12 +878,8 @@ StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
|
|||||||
return std::pair<EngineContext*, int>(&empty_context, 0);
|
return std::pair<EngineContext*, int>(&empty_context, 0);
|
||||||
}
|
}
|
||||||
if (segment_graph_def_.node().empty()) {
|
if (segment_graph_def_.node().empty()) {
|
||||||
FunctionLibraryRuntime* lib = ctx->function_library();
|
Status status = ImportSegmentGraphDef(ctx->function_library(),
|
||||||
auto status = ConstructFunctionHandle(lib, ctx->device()->name());
|
ctx->device()->name());
|
||||||
if (status.ok()) {
|
|
||||||
status =
|
|
||||||
FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_);
|
|
||||||
}
|
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for "
|
LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Getting segment graph for "
|
||||||
<< name() << " failed. "
|
<< name() << " failed. "
|
||||||
|
@@ -91,6 +91,13 @@ class TRTEngineOpTestBase : public OpsTestBase {
    OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
    NameAttrList function;
    function.set_name(StrCat(op_name, "_native_segment"));
+   // We disable allow_soft_placement when executing the native segment of
+   // the TRTEngineOp for the following reasons:
+   // OpsTestBase only allows one device in the device manager.
+   // We need to define the GPU device to test TRTEngineOp.
+   // When allow_soft_placement is true, the TensorFlow runtime produces an
+   // error if a CPU device is not defined
+   // (see ProcessFunctionLibraryRuntime::InstantiateMultiDevice).
    TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp")
                     .Input(FakeInput(1, dtype))
                     .Attr("input_shapes", {shape})
@@ -105,6 +112,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
                     .Attr("use_calibration", false)
                     .Attr("_use_implicit_batch", use_implicit_batch)
                     .Attr("_allow_build_at_runtime", allow_build_at_runtime)
+                    .Attr("_allow_soft_placement", false)
                     .Attr("OutT", {dtype})
                     .Finalize(OpsTestBase::node_def()));
    TF_ASSERT_OK(InitOpWithFunctionLibrary());
@@ -52,7 +52,11 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {

   bool IsEnabled(const ConfigProto& config_proto,
                  const Graph& graph) const override {
-    return IsMlirBridgePassEnabled(graph, config_proto);
+    // Do not run the bridge if it is merely enabled by the graph analysis;
+    // only run it if it is explicitly enabled by the user.
+    MlirBridgeRolloutPolicy policy =
+        GetMlirBridgeRolloutPolicy(graph, config_proto);
+    return policy == MlirBridgeRolloutPolicy::kEnabledByUser;
   }

   // This should be used as a thin mapper around mlir::ModulePass::runOnModule
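Note on the hunk above: the V1-compat bridge pass no longer runs just because graph analysis deems it safe; the user must opt in explicitly. A minimal sketch of such an opt-in from C++ follows, assuming the ConfigProto.Experimental.enable_mlir_bridge field (an assumption here; check tensorflow/core/protobuf/config.proto for the authoritative definition, and note that GetMlirBridgeRolloutPolicy above is the actual consumer of this setting):

    // Sketch: opting a session into the MLIR bridge so that
    // GetMlirBridgeRolloutPolicy(graph, config_proto) can return
    // kEnabledByUser. The enable_mlir_bridge field is assumed from
    // tensorflow/core/protobuf/config.proto; verify before relying on it.
    #include "tensorflow/core/protobuf/config.pb.h"

    tensorflow::ConfigProto MakeBridgeEnabledConfig() {
      tensorflow::ConfigProto config;
      config.mutable_experimental()->set_enable_mlir_bridge(true);
      return config;
    }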
@@ -361,6 +361,7 @@ tf_cc_test(
         ":util",
         ":xla_data_proto_cc",
         "//tensorflow/core:lib",
+        "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "@com_google_absl//absl/strings",
     ],
@@ -234,6 +234,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/compiler/xla/service:shape_inference",
         "//tensorflow/core:lib",
+        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
@@ -59,7 +59,6 @@ cc_library(
     srcs = ["comparators.cc"],
     hdrs = [
         "comparators.h",
-        "//tensorflow/compiler/xla:literal_util",
     ],
     deps = [
         ":constants",
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -46,6 +47,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/stream_executor/lib/statusor.h"

 namespace xla {

@@ -1266,7 +1268,8 @@ StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
 }

 XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
-                      const PrecisionConfig* precision_config) {
+                      const PrecisionConfig* precision_config,
+                      absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));

@@ -1278,15 +1281,17 @@ XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
   });
 }

-XlaOp XlaBuilder::DotGeneral(XlaOp lhs, XlaOp rhs,
-                             const DotDimensionNumbers& dimension_numbers,
-                             const PrecisionConfig* precision_config) {
+XlaOp XlaBuilder::DotGeneral(
+    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
-    TF_ASSIGN_OR_RETURN(Shape shape,
-                        ShapeInference::InferDotOpShape(*lhs_shape, *rhs_shape,
-                                                        dimension_numbers));
+    TF_ASSIGN_OR_RETURN(
+        Shape shape,
+        ShapeInference::InferDotOpShape(
+            *lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
     return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
                               precision_config);
   });
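The two hunks above thread the new optional preferred_element_type argument through Dot and DotGeneral into ShapeInference::InferDotOpShape, so a caller can request a result (and accumulation) element type wider than the operands. A minimal usage sketch against the free-function signature this patch adds later; the builder setup and shapes here are illustrative, not taken from the patch:

    // Sketch: requesting an F32 result for a BF16 dot. Without the last
    // argument the result element type follows the operands (BF16); with
    // it, shape inference produces an F32 result shape.
    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    xla::XlaOp BuildMixedPrecisionDot(xla::XlaBuilder* b) {
      xla::XlaOp lhs = xla::Parameter(
          b, 0, xla::ShapeUtil::MakeShape(xla::BF16, {128, 256}), "lhs");
      xla::XlaOp rhs = xla::Parameter(
          b, 1, xla::ShapeUtil::MakeShape(xla::BF16, {256, 64}), "rhs");
      return xla::Dot(lhs, rhs, /*precision_config=*/nullptr,
                      /*preferred_element_type=*/xla::F32);
    }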
@@ -1353,28 +1358,33 @@ Status XlaBuilder::VerifyConvolution(
 XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
                        absl::Span<const int64> window_strides, Padding padding,
                        int64 feature_group_count, int64 batch_group_count,
-                       const PrecisionConfig* precision_config) {
+                       const PrecisionConfig* precision_config,
+                       absl::optional<PrimitiveType> preferred_element_type) {
   return ConvWithGeneralDimensions(
       lhs, rhs, window_strides, padding,
       CreateDefaultConvDimensionNumbers(window_strides.size()),
-      feature_group_count, batch_group_count, precision_config);
+      feature_group_count, batch_group_count, precision_config,
+      preferred_element_type);
 }

 XlaOp XlaBuilder::ConvWithGeneralPadding(
     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
     absl::Span<const std::pair<int64, int64>> padding,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ConvGeneral(lhs, rhs, window_strides, padding,
                      CreateDefaultConvDimensionNumbers(window_strides.size()),
-                     feature_group_count, batch_group_count, precision_config);
+                     feature_group_count, batch_group_count, precision_config,
+                     preferred_element_type);
 }

 XlaOp XlaBuilder::ConvWithGeneralDimensions(
     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
@@ -1402,7 +1412,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
                        MakePadding(base_area_dimensions, window_dimensions,
                                    window_strides, padding),
                        dimension_numbers, feature_group_count,
-                       batch_group_count, precision_config);
+                       batch_group_count, precision_config,
+                       preferred_element_type);
   });
 }

@@ -1411,10 +1422,12 @@ XlaOp XlaBuilder::ConvGeneral(
     absl::Span<const std::pair<int64, int64>> padding,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                             dimension_numbers, feature_group_count,
-                            batch_group_count, precision_config);
+                            batch_group_count, precision_config,
+                            preferred_element_type);
 }

 XlaOp XlaBuilder::ConvGeneralDilated(
@@ -1423,7 +1436,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
@@ -1442,10 +1456,11 @@ XlaOp XlaBuilder::ConvGeneralDilated(
                         ShapeInference::InferWindowFromDimensions(
                             window_dimensions, window_strides, padding,
                             lhs_dilation, rhs_dilation));
-    TF_ASSIGN_OR_RETURN(Shape shape,
-                        ShapeInference::InferConvolveShape(
-                            *lhs_shape, *rhs_shape, feature_group_count,
-                            batch_group_count, window, dimension_numbers));
+    TF_ASSIGN_OR_RETURN(
+        Shape shape,
+        ShapeInference::InferConvolveShape(
+            *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
+            window, dimension_numbers, preferred_element_type));
     return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
                                       padding, lhs_dilation, rhs_dilation,
                                       dimension_numbers, feature_group_count,
@@ -1459,7 +1474,8 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
   TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
   std::vector<int64> window_dimensions(
@@ -1472,10 +1488,11 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
   TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
                                          window_dimensions, window_strides,
                                          padding, lhs_dilation, rhs_dilation));
-  TF_ASSIGN_OR_RETURN(Shape shape,
-                      ShapeInference::InferConvolveShape(
-                          *lhs_shape, *rhs_shape, feature_group_count,
-                          batch_group_count, window, dimension_numbers));
+  TF_ASSIGN_OR_RETURN(
+      Shape shape,
+      ShapeInference::InferConvolveShape(
+          *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
+          window, dimension_numbers, preferred_element_type));

   HloInstructionProto instr;
   *instr.mutable_shape() = shape.ToProto();
@@ -1499,14 +1516,15 @@ XlaOp XlaBuilder::DynamicConvInputGrad(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(
         HloInstructionProto instr,
-        DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
-                               rhs_dilation, dimension_numbers,
-                               feature_group_count, batch_group_count,
-                               precision_config, padding_type));
+        DynamicConvInstruction(
+            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
+            dimension_numbers, feature_group_count, batch_group_count,
+            precision_config, padding_type, preferred_element_type));

     instr.set_custom_call_target("DynamicConvolutionInputGrad");

@@ -1521,14 +1539,16 @@ XlaOp XlaBuilder::DynamicConvKernelGrad(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(
         HloInstructionProto instr,
         DynamicConvInstruction(activations, gradients, window_strides, padding,
                                lhs_dilation, rhs_dilation, dimension_numbers,
                                feature_group_count, batch_group_count,
-                               precision_config, padding_type));
+                               precision_config, padding_type,
+                               preferred_element_type));

     instr.set_custom_call_target("DynamicConvolutionKernelGrad");
     // The gradient of kernel has kernel shape and shouldn't have any dynamic
@@ -1545,14 +1565,15 @@ XlaOp XlaBuilder::DynamicConvForward(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(
         HloInstructionProto instr,
-        DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
-                               rhs_dilation, dimension_numbers,
-                               feature_group_count, batch_group_count,
-                               precision_config, padding_type));
+        DynamicConvInstruction(
+            lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
+            dimension_numbers, feature_group_count, batch_group_count,
+            precision_config, padding_type, preferred_element_type));
     instr.set_custom_call_target("DynamicConvolutionForward");

     return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
@@ -3331,6 +3352,11 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   }

   if (!need_rewrite) {
+    if (opcode == HloOpcode::kCompare) {
+      CHECK(!instr_proto->comparison_type().empty());
+      new_instr->set_comparison_type(
+          ComparisonTypeToString(Comparison::DefaultComparisonType(PRED)));
+    }
     *new_instr->mutable_name() =
         GetFullName(instr_proto->opcode(), kNameSeparator, id);
     return Status::OK();
@@ -3990,11 +4016,26 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
 }

+static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
+                               absl::Span<const int64> broadcast_dimensions,
+                               ComparisonDirection comparison_direction) {
+  auto b = lhs.builder();
+  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
+    auto operand_element_type = operand_shape.element_type();
+    auto compare_type =
+        primitive_util::IsFloatingPointType(operand_element_type)
+            ? Comparison::Type::kFloatTotalOrder
+            : Comparison::DefaultComparisonType(operand_element_type);
+    return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
+                   compare_type);
+  });
+}
+
 XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kEq);
 }

 XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
@@ -4004,9 +4045,8 @@ XlaOp Ne(const XlaOp lhs, const XlaOp rhs,

 XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kNe);
 }

 XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
@@ -4016,9 +4056,8 @@ XlaOp Ge(const XlaOp lhs, const XlaOp rhs,

 XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kGe);
 }

 XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
@@ -4028,9 +4067,8 @@ XlaOp Gt(const XlaOp lhs, const XlaOp rhs,

 XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kGt);
 }

 XlaOp Le(const XlaOp lhs, const XlaOp rhs,
@@ -4040,10 +4078,10 @@ XlaOp Le(const XlaOp lhs, const XlaOp rhs,

 XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  auto compare_type = Comparison::Type::kFloatTotalOrder;
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe,
-                 compare_type);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kLe);
 }

 XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions) {
   return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
@@ -4051,8 +4089,8 @@ XlaOp Lt(const XlaOp lhs, const XlaOp rhs,

 XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                    absl::Span<const int64> broadcast_dimensions) {
-  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt,
-                 Comparison::Type::kFloatTotalOrder);
+  return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
+                           ComparisonDirection::kLt);
 }

 XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
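The hunks above collapse five copies of the same boilerplate into one CompareTotalOrder helper, and also correct its behavior for non-float operands: total order is requested only for floating-point element types, with everything else falling back to the default comparison type. Total order matters around NaN, which compares false under the default partial order but has a defined position (positive NaN above +Inf) under kFloatTotalOrder. A small sketch, with the builder usage illustrative rather than taken from the patch:

    // Sketch: default vs. total-order comparison around NaN.
    #include <limits>

    #include "tensorflow/compiler/xla/client/xla_builder.h"

    void BuildNanCompares(xla::XlaBuilder* b) {
      xla::XlaOp nan =
          xla::ConstantR0<float>(b, std::numeric_limits<float>::quiet_NaN());
      xla::XlaOp one = xla::ConstantR0<float>(b, 1.0f);
      // Default partial order: any comparison involving NaN is false.
      xla::XlaOp partial = xla::Gt(nan, one);
      // Total order: positive NaN sorts above +Inf, so this compares true.
      xla::XlaOp total = xla::GtTotalOrder(nan, one);
      (void)partial;
      (void)total;
    }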
@@ -4074,44 +4112,49 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
 }

 XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
-          const PrecisionConfig* precision_config) {
-  return lhs.builder()->Dot(lhs, rhs, precision_config);
+          const PrecisionConfig* precision_config,
+          absl::optional<PrimitiveType> preferred_element_type) {
+  return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
 }

 XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
                  const DotDimensionNumbers& dimension_numbers,
-                 const PrecisionConfig* precision_config) {
+                 const PrecisionConfig* precision_config,
+                 absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
-                                   precision_config);
+                                   precision_config, preferred_element_type);
 }

 XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
            absl::Span<const int64> window_strides, Padding padding,
            int64 feature_group_count, int64 batch_group_count,
-           const PrecisionConfig* precision_config) {
+           const PrecisionConfig* precision_config,
+           absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
                              feature_group_count, batch_group_count,
-                             precision_config);
+                             precision_config, preferred_element_type);
 }

-XlaOp ConvWithGeneralPadding(const XlaOp lhs, const XlaOp rhs,
-                             absl::Span<const int64> window_strides,
+XlaOp ConvWithGeneralPadding(
+    const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
     absl::Span<const std::pair<int64, int64>> padding,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->ConvWithGeneralPadding(
       lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
-      precision_config);
+      precision_config, preferred_element_type);
 }

 XlaOp ConvWithGeneralDimensions(
     const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config) {
+    const PrecisionConfig* precision_config,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->ConvWithGeneralDimensions(
       lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
-      batch_group_count, precision_config);
+      batch_group_count, precision_config, preferred_element_type);
 }

 XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
@@ -4119,10 +4162,11 @@ XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const std::pair<int64, int64>> padding,
                   const ConvolutionDimensionNumbers& dimension_numbers,
                   int64 feature_group_count, int64 batch_group_count,
-                  const PrecisionConfig* precision_config) {
-  return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
-                                    dimension_numbers, feature_group_count,
-                                    batch_group_count, precision_config);
+                  const PrecisionConfig* precision_config,
+                  absl::optional<PrimitiveType> preferred_element_type) {
+  return lhs.builder()->ConvGeneral(
+      lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
+      batch_group_count, precision_config, preferred_element_type);
 }

 XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
@@ -4132,26 +4176,27 @@ XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
                          absl::Span<const int64> rhs_dilation,
                          const ConvolutionDimensionNumbers& dimension_numbers,
                          int64 feature_group_count, int64 batch_group_count,
-                         const PrecisionConfig* precision_config) {
+                         const PrecisionConfig* precision_config,
+                         absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->ConvGeneralDilated(
       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
       dimension_numbers, feature_group_count, batch_group_count,
-      precision_config);
+      precision_config, preferred_element_type);
 }

-XlaOp DynamicConvInputGrad(XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
-                           absl::Span<const int64> window_strides,
-                           absl::Span<const std::pair<int64, int64>> padding,
-                           absl::Span<const int64> lhs_dilation,
-                           absl::Span<const int64> rhs_dilation,
+XlaOp DynamicConvInputGrad(
+    XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
+    absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config,
-    PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->DynamicConvInputGrad(
       input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
       rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
-      precision_config, padding_type);
+      precision_config, padding_type, preferred_element_type);
 }

 XlaOp DynamicConvKernelGrad(
@@ -4160,11 +4205,12 @@ XlaOp DynamicConvKernelGrad(
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config, PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return activations.builder()->DynamicConvKernelGrad(
       activations, gradients, window_strides, padding, lhs_dilation,
       rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
-      precision_config, padding_type);
+      precision_config, padding_type, preferred_element_type);
 }

 XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
@@ -4175,11 +4221,12 @@ XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
     const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count, int64 batch_group_count,
-    const PrecisionConfig* precision_config,
-    PaddingType padding_type) {
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type) {
   return lhs.builder()->DynamicConvForward(
       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
       dimension_numbers, feature_group_count, batch_group_count,
-      precision_config, padding_type);
+      precision_config, padding_type, preferred_element_type);
 }

 XlaOp Fft(const XlaOp operand, FftType fft_type,
@@ -521,56 +521,63 @@ class XlaBuilder {
                                       XlaOp tuple_data,
                                       int64 index);

-  XlaOp Dot(XlaOp lhs, XlaOp rhs,
-            const PrecisionConfig* precision_config = nullptr);
+  XlaOp Dot(
+      XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

-  XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
-                   const DotDimensionNumbers& dimension_numbers,
-                   const PrecisionConfig* precision_config = nullptr);
+  XlaOp DotGeneral(
+      XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

-  XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
-             Padding padding, int64 feature_group_count = 1,
-             int64 batch_group_count = 1,
-             const PrecisionConfig* precision_config = nullptr);
+  XlaOp Conv(
+      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+      Padding padding, int64 feature_group_count = 1,
+      int64 batch_group_count = 1,
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

   XlaOp ConvWithGeneralPadding(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
       int64 feature_group_count = 1, int64 batch_group_count = 1,
-      const PrecisionConfig* precision_config = nullptr);
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

   XlaOp ConvWithGeneralDimensions(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count = 1, int64 batch_group_count = 1,
-      const PrecisionConfig* precision_config = nullptr);
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

-  XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
-                    absl::Span<const int64> window_strides,
+  XlaOp ConvGeneral(
+      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count = 1, int64 batch_group_count = 1,
-      const PrecisionConfig* precision_config = nullptr);
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

-  XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
-                           absl::Span<const int64> window_strides,
+  XlaOp ConvGeneralDilated(
+      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
       absl::Span<const int64> lhs_dilation,
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
-      int64 feature_group_count = 1,
-      int64 batch_group_count = 1,
-      const PrecisionConfig* precision_config = nullptr);
+      int64 feature_group_count = 1, int64 batch_group_count = 1,
+      const PrecisionConfig* precision_config = nullptr,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

-  XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
-                           absl::Span<const int64> window_strides,
+  XlaOp DynamicConvForward(
+      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
       absl::Span<const int64> lhs_dilation,
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config,
-      PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

   XlaOp DynamicConvInputGrad(
       XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
@@ -580,7 +587,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

   XlaOp DynamicConvKernelGrad(
       XlaOp activations, XlaOp gradients,
@@ -590,7 +598,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

   StatusOr<HloInstructionProto> DynamicConvInstruction(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@@ -599,7 +608,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

   virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
       const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
@@ -1098,10 +1108,12 @@ class XlaBuilder {
                        ComparisonDirection direction,
                        Comparison::Type compare_type);
   friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
-                   const PrecisionConfig* precision_config);
+                   const PrecisionConfig* precision_config,
+                   absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
                           const DotDimensionNumbers& dimension_number,
-                          const PrecisionConfig* precision_config);
+                          const PrecisionConfig* precision_config,
+                          absl::optional<PrimitiveType> preferred_element_type);
   virtual StatusOr<XlaOp> DotGeneralInternal(
       const Shape& shape, XlaOp lhs, XlaOp rhs,
       const DotDimensionNumbers& dimension_number,
@@ -1109,23 +1121,27 @@ class XlaBuilder {
   friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
                     absl::Span<const int64> window_strides, Padding padding,
                     int64 feature_group_count, int64 batch_group_count,
-                    const PrecisionConfig* precision_config);
+                    const PrecisionConfig* precision_config,
+                    absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp ConvWithGeneralPadding(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config);
+      const PrecisionConfig* precision_config,
+      absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp ConvWithGeneralDimensions(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_confige);
-  friend XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
-                           absl::Span<const int64> window_strides,
-                           absl::Span<const std::pair<int64, int64>> padding,
-                           const ConvolutionDimensionNumbers& dimension_numbers,
-                           int64 feature_group_count, int64 batch_group_count,
-                           const PrecisionConfig* precision_config);
+      const PrecisionConfig* precision_config,
+      absl::optional<PrimitiveType> preferred_element_type);
+  friend XlaOp ConvGeneral(
+      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+      absl::Span<const std::pair<int64, int64>> padding,
+      const ConvolutionDimensionNumbers& dimension_numbers,
+      int64 feature_group_count, int64 batch_group_count,
+      const PrecisionConfig* precision_config,
+      absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp DynamicConvForward(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
@@ -1133,7 +1149,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp DynamicConvKernelGrad(
       XlaOp activations, XlaOp gradients,
       absl::Span<const int64> window_strides,
@@ -1142,7 +1159,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp DynamicConvInputGrad(
       XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
       absl::Span<const int64> window_strides,
@@ -1151,7 +1169,8 @@ class XlaBuilder {
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config, PaddingType padding_type);
+      const PrecisionConfig* precision_config, PaddingType padding_type,
+      absl::optional<PrimitiveType> preferred_element_type);

   friend XlaOp ConvKernelGrad(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@@ -1160,7 +1179,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config);
+      const PrecisionConfig* precision_config,
+      absl::optional<PrimitiveType> preferred_element_type);

   friend XlaOp ConvGeneralDilated(
       XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@@ -1169,7 +1189,8 @@ class XlaBuilder {
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count, int64 batch_group_count,
-      const PrecisionConfig* precision_config);
+      const PrecisionConfig* precision_config,
+      absl::optional<PrimitiveType> preferred_element_type);
   friend XlaOp Fft(XlaOp operand, FftType fft_type,
                    absl::Span<const int64> fft_length);
   friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
@@ -1813,28 +1834,31 @@ XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);

 // Enqueues a dot instruction onto the computation.
 XlaOp Dot(XlaOp lhs, XlaOp rhs,
-          const PrecisionConfig* precision_config = nullptr);
+          const PrecisionConfig* precision_config = nullptr,
+          absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

 // Enqueues a general dot instruction onto the computation.
-XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
-                 const DotDimensionNumbers& dimension_numbers,
-                 const PrecisionConfig* precision_config = nullptr);
+XlaOp DotGeneral(
+    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

 // Enqueues a convolution instruction onto the computation, which uses the
 // default convolution dimension numbers.
-XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
-           Padding padding, int64 feature_group_count = 1,
-           int64 batch_group_count = 1,
-           const PrecisionConfig* precision_config = nullptr);
+XlaOp Conv(
+    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+    Padding padding, int64 feature_group_count = 1, int64 batch_group_count = 1,
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

 // Enqueues a convolution instruction onto the computation, with the caller
 // provided padding configuration in the format returned by MakePadding().
-XlaOp ConvWithGeneralPadding(XlaOp lhs, XlaOp rhs,
-                             absl::Span<const int64> window_strides,
-                             absl::Span<const std::pair<int64, int64>> padding,
-                             int64 feature_group_count = 1,
-                             int64 batch_group_count = 1,
-                             const PrecisionConfig* precision_config = nullptr);
+XlaOp ConvWithGeneralPadding(
+    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    int64 feature_group_count = 1, int64 batch_group_count = 1,
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

 // Enqueues a convolution instruction onto the computation, with the caller
 // provided dimension numbers configuration.
@@ -1842,47 +1866,48 @@ XlaOp ConvWithGeneralDimensions(
     XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count = 1, int64 batch_group_count = 1,
-    const PrecisionConfig* precision_config = nullptr);
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

 // Enqueues a convolution instruction onto the computation, with the caller
 // provided padding configuration as well as the dimension numbers.
-XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
-                  absl::Span<const std::pair<int64, int64>> padding,
-                  const ConvolutionDimensionNumbers& dimension_numbers,
-                  int64 feature_group_count = 1, int64 batch_group_count = 1,
-                  const PrecisionConfig* precision_config = nullptr);
+XlaOp ConvGeneral(
+    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count = 1, int64 batch_group_count = 1,
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

 // Enqueues a convolution instruction onto the computation, with the caller
 // provided padding configuration, dilation factors and dimension numbers.
-XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
-                         absl::Span<const int64> window_strides,
-                         absl::Span<const std::pair<int64, int64>> padding,
-                         absl::Span<const int64> lhs_dilation,
-                         absl::Span<const int64> rhs_dilation,
-                         const ConvolutionDimensionNumbers& dimension_numbers,
-                         int64 feature_group_count = 1,
-                         int64 batch_group_count = 1,
-                         const PrecisionConfig* precision_config = nullptr);
+XlaOp ConvGeneralDilated(
+    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count = 1, int64 batch_group_count = 1,
+    const PrecisionConfig* precision_config = nullptr,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

-XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
-                         absl::Span<const int64> window_strides,
-                         absl::Span<const std::pair<int64, int64>> padding,
-                         absl::Span<const int64> lhs_dilation,
-                         absl::Span<const int64> rhs_dilation,
-                         const ConvolutionDimensionNumbers& dimension_numbers,
-                         int64 feature_group_count, int64 batch_group_count,
-                         const PrecisionConfig* precision_config,
-                         PaddingType padding_type);
+XlaOp DynamicConvForward(
+    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count, int64 batch_group_count,
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

-XlaOp DynamicConvInputGrad(XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
-                           absl::Span<const int64> window_strides,
-                           absl::Span<const std::pair<int64, int64>> padding,
-                           absl::Span<const int64> lhs_dilation,
-                           absl::Span<const int64> rhs_dilation,
-                           const ConvolutionDimensionNumbers& dimension_numbers,
-                           int64 feature_group_count, int64 batch_group_count,
-                           const PrecisionConfig* precision_config,
-                           PaddingType padding_type);
+XlaOp DynamicConvInputGrad(
+    XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
+    absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count, int64 batch_group_count,
+    const PrecisionConfig* precision_config, PaddingType padding_type,
+    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||||
XlaOp DynamicConvKernelGrad(
|
XlaOp DynamicConvKernelGrad(
|
||||||
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
|
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
|
||||||
@ -1890,7 +1915,8 @@ XlaOp DynamicConvKernelGrad(
|
|||||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||||
int64 feature_group_count, int64 batch_group_count,
|
int64 feature_group_count, int64 batch_group_count,
|
||||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||||
|
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||||
|
|
||||||
// Enqueues an FFT instruction onto the computation, of the given type and
|
// Enqueues an FFT instruction onto the computation, of the given type and
|
||||||
// with the given FFT length.
|
// with the given FFT length.
|
||||||
|
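Aside: every convolution entry point above now takes an optional `preferred_element_type`, which asks the builder for a result (accumulator) element type different from the operands'. A minimal sketch of a call site, assuming only the declarations in this hunk; the shapes and the helper name `BuildInt8Conv` are illustrative, not from this commit:

```cpp
// Sketch under stated assumptions; not part of this commit.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Builds an s8 x s8 convolution that accumulates into s32, so partial sums
// are not clamped to the 8-bit operand range.
xla::XlaOp BuildInt8Conv(xla::XlaBuilder* b) {
  auto lhs = xla::Parameter(
      b, 0, xla::ShapeUtil::MakeShape(xla::S8, {1, 8, 8, 16}), "lhs");
  auto rhs = xla::Parameter(
      b, 1, xla::ShapeUtil::MakeShape(xla::S8, {3, 3, 16, 32}), "rhs");
  return xla::Conv(lhs, rhs, /*window_strides=*/{1, 1}, xla::Padding::kSame,
                   /*feature_group_count=*/1, /*batch_group_count=*/1,
                   /*precision_config=*/nullptr,
                   /*preferred_element_type=*/xla::S32);
}
```

Because the parameter defaults to `absl::nullopt`, existing callers that omit it keep the old behavior.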
@@ -338,8 +338,7 @@ TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) {
                  /*broadcast_dimensions=*/{0, 1, 2});
   auto statusor = BuildHloModule(&b);
   ASSERT_FALSE(statusor.ok());
-  EXPECT_THAT(statusor.status().error_message(),
-              HasSubstr("shape's dimensions must not be < 0"));
+  EXPECT_THAT(statusor.status().error_message(), HasSubstr("invalid shape"));
 }
 
 TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
@@ -1066,6 +1065,56 @@ TEST_F(XlaBuilderTest, DynamicTranspose) {
       << result_shape;
 }
 
+TEST_F(XlaBuilderTest, DotWithPreferredElementType) {
+  XlaBuilder b(TestName());
+  Shape p0_shape = ShapeUtil::MakeShape(U8, {2, 3});
+  Shape p1_shape = ShapeUtil::MakeShape(U16, {3, 2});
+  auto p0 = Parameter(&b, 0, p0_shape, "p0");
+  auto p1 = Parameter(&b, 1, p1_shape, "p1");
+
+  DotDimensionNumbers dnums;
+  dnums.add_lhs_contracting_dimensions(1);
+  dnums.add_rhs_contracting_dimensions(0);
+  DotGeneral(p0, p1, dnums, /*precision_config=*/nullptr,
+             /*preferred_element_type=*/U32);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  const Shape& result_shape =
+      module->entry_computation()->root_instruction()->shape();
+  ASSERT_TRUE(
+      ShapeUtil::Equal(ShapeUtil::MakeShape(U32, {2, 2}), result_shape));
+}
+
+TEST_F(XlaBuilderTest, ConvolutionWithPreferredElementType) {
+  XlaBuilder b(TestName());
+  Shape p0_shape = ShapeUtil::MakeShape(S16, {1, 2, 2, 128});
+  Shape p1_shape = ShapeUtil::MakeShape(S8, {2, 2, 128, 8});
+  auto p0 = Parameter(&b, 0, p0_shape, "p0");
+  auto p1 = Parameter(&b, 1, p1_shape, "p1");
+
+  ConvolutionDimensionNumbers dnums;
+  dnums.set_input_batch_dimension(0);
+  dnums.set_output_batch_dimension(0);
+  dnums.add_input_spatial_dimensions(1);
+  dnums.add_output_spatial_dimensions(1);
+  dnums.add_input_spatial_dimensions(2);
+  dnums.add_output_spatial_dimensions(2);
+  dnums.set_input_feature_dimension(3);
+  dnums.set_output_feature_dimension(3);
+  dnums.add_kernel_spatial_dimensions(0);
+  dnums.add_kernel_spatial_dimensions(1);
+  dnums.set_kernel_input_feature_dimension(2);
+  dnums.set_kernel_output_feature_dimension(3);
+  ConvWithGeneralDimensions(p0, p1, {1, 1}, Padding::kValid, dnums,
+                            /*feature_group_count=*/1, /*batch_group_count=*/1,
+                            /*precision_config=*/nullptr,
+                            /*preferred_element_type=*/S32);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  const Shape& result_shape =
+      module->entry_computation()->root_instruction()->shape();
+  ASSERT_TRUE(
+      ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {1, 1, 1, 8}), result_shape));
+}
+
 TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
   XlaBuilder b(TestName());
   AfterAll(&b, {CreateToken(&b), ConstantR0<float>(&b, 1.0)});
@@ -243,6 +243,7 @@ cc_library(
     hdrs = ["nvidia_gpu_device.h"],
     deps = [
         ":pjrt_client",
+        "@com_google_absl//absl/container:flat_hash_map",
         "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
@@ -15,12 +15,15 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
 
+#include "absl/container/flat_hash_map.h"
+
 #ifdef NCCL_ENABLED
 #include "third_party/nccl/nccl.h"
 #endif  // NCCL_ENABLED
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/common_runtime/device/device_host_allocator.h"
 #include "tensorflow/core/common_runtime/device/device_id.h"
@@ -166,56 +169,57 @@ std::unique_ptr<tensorflow::BFCAllocator> GetGpuHostAllocator(
 
 // A table mapping NcclCliqueKeys to ncclUniqueId values encoded as strings.
 // In a distributed setup the table of NCCL IDs is kept on the master node
-// (node 0). Currently node 0 is the only node that generates ncclUniqueIds;
-// see the TODO below.
+// (node 0). The node of the first participating device will create the unique
+// id.
 class NcclIdStore {
  public:
-  NcclIdStore(int node_id, std::shared_ptr<DistributedRuntimeClient> client)
-      : node_id_(node_id), client_(std::move(client)) {}
+  NcclIdStore(int node_id, std::shared_ptr<DistributedRuntimeClient> client,
+              absl::flat_hash_map<GlobalDeviceId, int> device_to_node)
+      : node_id_(node_id),
+        client_(std::move(client)),
+        device_to_node_(std::move(device_to_node)) {}
 
   StatusOr<std::string> GetNcclUniqueId(const NcclCliqueKey& key);
 
  private:
   const int node_id_;
   const std::shared_ptr<DistributedRuntimeClient> client_;
+  const absl::flat_hash_map<GlobalDeviceId, int> device_to_node_;
 
   absl::Mutex mu_;
-  absl::flat_hash_map<std::string, std::string> cache_ ABSL_GUARDED_BY(mu_);
+  absl::flat_hash_map<NcclCliqueKey, std::string> cache_ ABSL_GUARDED_BY(mu_);
 };
 
 StatusOr<std::string> NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) {
-  std::string key_string = GlobalDeviceIdsToString(key.devices());
+  // The caller must ensure that threads calling this method concurrently have
+  // unique keys, otherwise the global key-value store may hold the wrong value.
   {
     absl::MutexLock lock(&mu_);
-    auto it = cache_.find(key_string);
+    auto it = cache_.find(key);
     if (it != cache_.end()) {
       return it->second;
     }
   }
-  auto result = [&]() -> StatusOr<std::string> {
-    // TODO(phawkins): this will deadlock if node 0 is not involved in the
-    // computation. Add support for computations that only use a subset of
-    // replicas.
-    if (node_id_ == 0) {
+  std::string id_string;
+  int primary_node_id = device_to_node_.at(key.devices()[0]);
+  if (node_id_ == primary_node_id) {
 #ifdef NCCL_ENABLED
-      ncclUniqueId id;
-      ncclResult_t r = ncclGetUniqueId(&id);
-      TF_RET_CHECK(r == ncclSuccess);
-      std::string value(id.internal, NCCL_UNIQUE_ID_BYTES);
-      TF_RETURN_IF_ERROR(client_->KeyValueSet(key_string, value));
-      return value;
+    ncclUniqueId id;
+    ncclResult_t r = ncclGetUniqueId(&id);
+    TF_RET_CHECK(r == ncclSuccess);
+    id_string = std::string(id.internal, NCCL_UNIQUE_ID_BYTES);
+    TF_RETURN_IF_ERROR(client_->KeyValueSet(key.ToString(), id_string));
 #else
-      return FailedPrecondition("NCCL support was not built into XLA binary.");
+    return FailedPrecondition("NCCL support was not built into XLA binary.");
 #endif
-    } else {
-      return client_->BlockingKeyValueGet(key_string, absl::Minutes(5));
-    }
-  }();
-  if (!result.ok()) {
-    return result.status();
-  }
+  } else {
+    TF_ASSIGN_OR_RETURN(id_string, client_->BlockingKeyValueGet(
+                                       key.ToString(), absl::Minutes(5)));
+  }
   absl::MutexLock lock(&mu_);
-  return cache_.emplace(key_string, result.ValueOrDie()).first->second;
+  auto result = cache_.emplace(key, std::move(id_string));
+  TF_RET_CHECK(result.second) << "Unique ID already in cache.";
+  return result.first->second;
 }
 
 std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
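The scheme above is a rendezvous through the distributed key-value store: the node that owns the first device in the clique generates the NCCL unique ID and publishes it under the clique key; every other participant blocks until that key appears, then caches the value locally. A minimal sketch of the same pattern against an in-memory store, to make the control flow easier to see (all names here are illustrative, not the PJRT API):

```cpp
#include <condition_variable>
#include <map>
#include <mutex>
#include <string>

// Toy stand-in for the distributed key-value store used above.
class ToyKeyValueStore {
 public:
  void Set(const std::string& key, std::string value) {
    std::lock_guard<std::mutex> lock(mu_);
    entries_[key] = std::move(value);
    cv_.notify_all();
  }
  // Blocks until some node has published `key`.
  std::string BlockingGet(const std::string& key) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [&] { return entries_.count(key) > 0; });
    return entries_[key];
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::map<std::string, std::string> entries_;
};

// One node (the "primary" for this clique) creates the ID; the rest wait.
std::string GetUniqueId(ToyKeyValueStore& store, const std::string& clique_key,
                        int node_id, int primary_node_id) {
  if (node_id == primary_node_id) {
    std::string id = "generated-id-for-" + clique_key;  // ncclGetUniqueId here
    store.Set(clique_key, id);
    return id;
  }
  return store.BlockingGet(clique_key);
}
```

Compared with the old code, keying the primary off `device_to_node_` instead of hard-coding node 0 removes the deadlock when node 0 does not participate in the clique.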
@@ -258,8 +262,11 @@ Status BuildDistributedDevices(
       distributed_client->EnumerateDevices(local_topology, &global_topology));
 
   std::vector<GlobalDeviceId> gpu_device_ids(local_device_states.size());
+  absl::flat_hash_map<GlobalDeviceId, int> device_to_node;
   for (const LocalTopologyProto& node : global_topology.nodes()) {
     for (const DeviceProto& device_proto : node.devices()) {
+      GlobalDeviceId global_device_id(device_proto.global_device_id());
+      device_to_node[global_device_id] = node.node_id();
       std::unique_ptr<LocalDeviceState> local_device;
       if (node.node_id() == node_id) {
         TF_RET_CHECK(device_proto.local_device_ordinal() >= 0 &&
@@ -269,8 +276,7 @@ Status BuildDistributedDevices(
                      nullptr);
         local_device =
             std::move(local_device_states[device_proto.local_device_ordinal()]);
-        gpu_device_ids[device_proto.local_device_ordinal()] =
-            GlobalDeviceId(device_proto.global_device_id());
+        gpu_device_ids[device_proto.local_device_ordinal()] = global_device_id;
       }
       auto device = absl::make_unique<GpuDevice>(
           device_proto.global_device_id(), std::move(local_device),
@@ -283,8 +289,8 @@ Status BuildDistributedDevices(
   }
   gpu_executable_run_options->set_gpu_global_device_ids(
       std::move(gpu_device_ids));
-  auto nccl_id_store =
-      std::make_shared<NcclIdStore>(node_id, distributed_client);
+  auto nccl_id_store = std::make_shared<NcclIdStore>(
+      node_id, distributed_client, device_to_node);
   gpu_executable_run_options->set_nccl_unique_id_callback(
       [nccl_id_store](const NcclCliqueKey& key) {
         return nccl_id_store->GetNcclUniqueId(key);
@@ -597,19 +597,25 @@ struct TypeDescriptor<uint16> {
   static int Dtype() { return NPY_UINT16; }
 };
 
+// We register "int", "long", and "long long" types for portability across
+// Linux, where "long" and "long long" have the same width, and Windows, where
+// "int" and "long" have the same width.
 template <>
-struct TypeDescriptor<uint32> {
-  typedef uint32 T;
-  static int Dtype() { return NPY_UINT32; }
+struct TypeDescriptor<unsigned int> {
+  typedef unsigned int T;
+  static int Dtype() { return NPY_UINT; }
 };
 
-template <typename Uint64Type>
-struct TypeDescriptor<
-    Uint64Type, typename std::enable_if<std::is_integral<Uint64Type>::value &&
-                                        !std::is_signed<Uint64Type>::value &&
-                                        sizeof(Uint64Type) == 8>::type> {
-  typedef Uint64Type T;
-  static int Dtype() { return NPY_UINT64; }
+template <>
+struct TypeDescriptor<unsigned long> {  // NOLINT
+  typedef unsigned long T;              // NOLINT
+  static int Dtype() { return NPY_ULONG; }
+};
+
+template <>
+struct TypeDescriptor<unsigned long long> {  // NOLINT
+  typedef unsigned long long T;              // NOLINT
+  static int Dtype() { return NPY_ULONGLONG; }
 };
 
 template <>
@@ -625,18 +631,21 @@ struct TypeDescriptor<int16> {
 };
 
 template <>
-struct TypeDescriptor<int32> {
-  typedef int32 T;
-  static int Dtype() { return NPY_INT32; }
+struct TypeDescriptor<int> {
+  typedef int T;
+  static int Dtype() { return NPY_INT; }
 };
 
-template <typename Int64Type>
-struct TypeDescriptor<
-    Int64Type, typename std::enable_if<std::is_integral<Int64Type>::value &&
-                                       std::is_signed<Int64Type>::value &&
-                                       sizeof(Int64Type) == 8>::type> {
-  typedef Int64Type T;
-  static int Dtype() { return NPY_INT64; }
+template <>
+struct TypeDescriptor<long> {  // NOLINT
+  typedef long T;              // NOLINT
+  static int Dtype() { return NPY_LONG; }
+};
+
+template <>
+struct TypeDescriptor<long long> {  // NOLINT
+  typedef long long T;              // NOLINT
+  static int Dtype() { return NPY_LONGLONG; }
 };
 
 template <>
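The reason for specializing on the built-in types rather than fixed-width aliases is that `int`, `long`, and `long long` are always three distinct C++ types for overload resolution and template specialization, even when two of them share a width; which two coincide differs between LP64 Linux (`long` and `long long` are both 64 bits) and LLP64 Windows (`int` and `long` are both 32 bits). A standalone sketch, not from this diff, that makes the distinction visible:

```cpp
#include <cstdio>
#include <type_traits>

int main() {
  // Same width does not mean same type: a template specialized on `long`
  // is never selected for `long long`, and vice versa.
  std::printf("sizeof(int)=%zu sizeof(long)=%zu sizeof(long long)=%zu\n",
              sizeof(int), sizeof(long), sizeof(long long));
  static_assert(!std::is_same<long, long long>::value,
                "distinct types even when the sizes match");
  static_assert(!std::is_same<int, long>::value,
                "distinct types even when the sizes match");
  return 0;
}
```

Specializing a single fixed-width alias such as `int64` therefore covers only one of the built-in types that happens to have that width on the build platform, which is why the old `enable_if`-based specializations missed NumPy dtypes like `np.intc` and `np.longlong` on some platforms.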
@@ -1354,7 +1363,15 @@ bool Initialize() {
   if (!RegisterBfloat16Cast<uint16>(NPY_UINT16, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<uint32>(NPY_UINT32, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT, /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG,  // NOLINT
+                                           /*cast_is_safe=*/false)) {
+    return false;
+  }
+  if (!RegisterBfloat16Cast<unsigned long long>(  // NOLINT
+          NPY_ULONGLONG, /*cast_is_safe=*/false)) {
     return false;
   }
   if (!RegisterBfloat16Cast<uint64>(NPY_UINT64, /*cast_is_safe=*/false)) {
@@ -1366,14 +1383,15 @@ bool Initialize() {
   if (!RegisterBfloat16Cast<int16>(NPY_INT16, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<int32>(NPY_INT32, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<int>(NPY_INT, /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<int64>(NPY_INT64, /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<long>(NPY_LONG,  // NOLINT
+                                  /*cast_is_safe=*/false)) {
     return false;
   }
-  if (!RegisterBfloat16Cast<npy_longlong>(NPY_LONGLONG,
-                                          /*cast_is_safe=*/false)) {
+  if (!RegisterBfloat16Cast<long long>(  // NOLINT
+          NPY_LONGLONG, /*cast_is_safe=*/false)) {
     return false;
   }
   // Following the numpy convention. imag part is dropped when converting to
@@ -293,7 +293,7 @@ class Bfloat16NumPyTest(parameterized.TestCase):
     for dtype in [
         np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
         np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
-        np.uint64
+        np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
     ]:
       x = np.array([[1, 2, 3]], dtype=dtype)
       y = x.astype(bfloat16)
@@ -595,7 +595,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
 
   static const auto* xla_module =
       new py::module(py::module::import("jax.interpreters.xla"));
-  const auto& device_array = xla_module->attr("DeviceArray");
+  const auto& device_array = xla_module->attr("_DeviceArray");
 
   static const auto* numpy_module = new py::module(py::module::import("numpy"));
   const auto& np_array = numpy_module->attr("array");
@@ -613,7 +613,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
   for (py::handle arg : arguments.flat_dynamic_args) {
     // We specifically only deal with DeviceArray (not ShardedDeviceArray).
     // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
-    if (arg.get_type().is(device_array)) {
+    if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) {
       xla::PyBuffer* buffer;
       if (arg.attr("_device").is_none()) {  // Skip non-sticky devices.
         continue;
@@ -653,7 +653,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
   xla::PjRtClient* pjrt_client = data_device->client();
 
   for (py::handle arg : arguments.flat_dynamic_args) {
-    if (arg.get_type().is(device_array)) {
+    if (py::isinstance<PyBuffer>(arg) || arg.get_type().is(device_array)) {
       if (!HasTrivialLazyExpr(arg)) {
         return InvalidArgument(
             "Non-trivial lazy expression not supported in C++. "
@@ -108,7 +108,8 @@ void BuildOpsSubmodule(py::module* m) {
           py::arg("lhs_dilation"), py::arg("rhs_dilation"),
           py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
          py::arg("batch_group_count") = 1,
-          py::arg("precision_config") = nullptr);
+          py::arg("precision_config") = nullptr,
+          py::arg("preferred_element_type") = absl::nullopt);
   ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
           py::arg("new_element_type"));
   ops.def(
@@ -136,9 +137,11 @@ void BuildOpsSubmodule(py::module* m) {
           py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
           py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false);
   ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
-          py::arg("precision_config") = nullptr);
+          py::arg("precision_config") = nullptr,
+          py::arg("preferred_element_type") = absl::nullopt);
   ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
-          py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
+          py::arg("dimension_numbers"), py::arg("precision_config") = nullptr,
+          py::arg("preferred_element_type") = absl::nullopt);
   ops.def("DynamicSlice",
           static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
                                 absl::Span<const int64>)>(&DynamicSlice),
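These bindings expose the new parameter as a trailing keyword argument that defaults to "absent", so existing Python call sites keep working unchanged. A self-contained pybind11 module using the same defaulted-optional pattern; the module and function names are made up for illustration, and it uses `std::optional` where the code above uses `absl::nullopt`:

```cpp
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // type casters for std::optional

#include <optional>
#include <string>

namespace py = pybind11;

// Mirrors the ops.def(...) pattern above: a trailing optional keyword
// argument defaulting to None keeps old call sites valid.
std::string Describe(int value, std::optional<int> preferred_width) {
  if (preferred_width) {
    return "value=" + std::to_string(value) +
           " width=" + std::to_string(*preferred_width);
  }
  return "value=" + std::to_string(value) + " width=default";
}

PYBIND11_MODULE(example_ops, m) {
  m.def("describe", &Describe, py::arg("value"),
        py::arg("preferred_width") = std::nullopt);
}
```

From Python, both `describe(3)` and `describe(3, preferred_width=32)` are then valid, which is exactly the compatibility property the `preferred_element_type` defaults provide here.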
Some files were not shown because too many files have changed in this diff.