Merge branch 'master' into fix_minimum_maximum
This commit is contained in:
commit
811c2a08ff
88
.bazelrc
88
.bazelrc
@ -100,9 +100,9 @@ build --apple_platform_type=macos
|
||||
# iOS configs for each architecture and the fat binary builds.
|
||||
build:ios --apple_platform_type=ios
|
||||
build:ios --apple_bitcode=embedded --copt=-fembed-bitcode
|
||||
build:ios --copt=-Wno-c++11-narrowing
|
||||
build:ios_armv7 --config=ios
|
||||
build:ios_armv7 --cpu=ios_armv7
|
||||
build:ios_armv7 --copt -Wno-c++11-narrowing
|
||||
build:ios_arm64 --config=ios
|
||||
build:ios_arm64 --cpu=ios_arm64
|
||||
build:ios_i386 --config=ios
|
||||
@ -111,7 +111,6 @@ build:ios_x86_64 --config=ios
|
||||
build:ios_x86_64 --cpu=ios_x86_64
|
||||
build:ios_fat --config=ios
|
||||
build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64
|
||||
build:ios_fat --copt -Wno-c++11-narrowing
|
||||
|
||||
# Config to use a mostly-static build and disable modular op registration
|
||||
# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
|
||||
@ -202,18 +201,25 @@ build --define=allow_oversize_protos=true
|
||||
build --spawn_strategy=standalone
|
||||
build -c opt
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
build --cxxopt=-std=c++14
|
||||
build --host_cxxopt=-std=c++14
|
||||
|
||||
# Make Bazel print out all options from rc files.
|
||||
build --announce_rc
|
||||
|
||||
# Other build flags.
|
||||
build --define=grpc_no_ares=true
|
||||
|
||||
# Prevent regression of https://github.com/bazelbuild/bazel/issues/7362
|
||||
build --incompatible_remove_legacy_whole_archive
|
||||
# See https://github.com/bazelbuild/bazel/issues/7362 for information on what
|
||||
# --incompatible_remove_legacy_whole_archive flag does.
|
||||
# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate
|
||||
# Tensorflow to the default, however test coverage wasn't enough to catch the
|
||||
# errors.
|
||||
# There is ongoing work on Bazel team's side to provide support for transitive
|
||||
# shared libraries. As part of migrating to transitive shared libraries, we
|
||||
# hope to provide a better mechanism for control over symbol exporting, and
|
||||
# then tackle this issue again.
|
||||
#
|
||||
# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library
|
||||
# archives in -whole_archive -no_whole_archive.
|
||||
build --noincompatible_remove_legacy_whole_archive
|
||||
|
||||
# Modular TF build options
|
||||
build:dynamic_kernels --define=dynamic_loaded_kernels=true
|
||||
@ -224,13 +230,55 @@ build:c++17 --cxxopt=-std=c++1z
|
||||
build:c++17 --cxxopt=-stdlib=libc++
|
||||
build:c++1z --config=c++17
|
||||
|
||||
# Default paths for TF_SYSTEM_LIBS
|
||||
build --define=PREFIX=/usr
|
||||
build --define=LIBDIR=$(PREFIX)/lib
|
||||
build --define=INCLUDEDIR=$(PREFIX)/include
|
||||
# Enable using platform specific build settings
|
||||
build --enable_platform_specific_config
|
||||
|
||||
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
|
||||
build --copt=-w
|
||||
build:linux --copt=-w
|
||||
build:macos --copt=-w
|
||||
build:windows --copt=/w
|
||||
|
||||
# Default paths for TF_SYSTEM_LIBS
|
||||
build:linux --define=PREFIX=/usr
|
||||
build:linux --define=LIBDIR=$(PREFIX)/lib
|
||||
build:linux --define=INCLUDEDIR=$(PREFIX)/include
|
||||
build:macos --define=PREFIX=/usr
|
||||
build:macos --define=LIBDIR=$(PREFIX)/lib
|
||||
build:macos --define=INCLUDEDIR=$(PREFIX)/include
|
||||
# TF_SYSTEM_LIBS do not work on windows.
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
build:linux --cxxopt=-std=c++14
|
||||
build:linux --host_cxxopt=-std=c++14
|
||||
build:macos --cxxopt=-std=c++14
|
||||
build:macos --host_cxxopt=-std=c++14
|
||||
build:windows --cxxopt=/std:c++14
|
||||
build:windows --host_cxxopt=/std:c++14
|
||||
|
||||
# On windows, we still link everything into a single DLL.
|
||||
build:windows --config=monolithic
|
||||
|
||||
# Make sure to include as little of windows.h as possible
|
||||
build:windows --copt=-DWIN32_LEAN_AND_MEAN
|
||||
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
|
||||
build:windows --copt=-DNOGDI
|
||||
build:windows --host_copt=-DNOGDI
|
||||
|
||||
# Misc build options we need for windows.
|
||||
build:windows --linkopt=/DEBUG
|
||||
build:windows --host_linkopt=/DEBUG
|
||||
build:windows --linkopt=/OPT:REF
|
||||
build:windows --host_linkopt=/OPT:REF
|
||||
build:windows --linkopt=/OPT:ICF
|
||||
build:windows --host_linkopt=/OPT:ICF
|
||||
build:windows --experimental_strict_action_env=true
|
||||
build:windows --incompatible_windows_native_test_wrapper
|
||||
|
||||
# Verbose failure logs when something goes wrong
|
||||
build:windows --verbose_failures
|
||||
|
||||
# On windows, we never cross compile
|
||||
build:windows --distinct_host_configuration=false
|
||||
|
||||
# Suppress all warning messages.
|
||||
build:short_logs --output_filter=DONT_MATCH_ANYTHING
|
||||
@ -335,20 +383,6 @@ build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_
|
||||
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
||||
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
|
||||
|
||||
# Misc build options we need for windows
|
||||
build:rbe_win --copt=-DWIN32_LEAN_AND_MEAN
|
||||
build:rbe_win --host_copt=-DWIN32_LEAN_AND_MEAN
|
||||
build:rbe_win --copt=-DNOGDI
|
||||
build:rbe_win --host_copt=-DNOGDI
|
||||
build:rbe_win --linkopt=/DEBUG
|
||||
build:rbe_win --host_linkopt=/DEBUG
|
||||
build:rbe_win --linkopt=/OPT:REF
|
||||
build:rbe_win --host_linkopt=/OPT:REF
|
||||
build:rbe_win --linkopt=/OPT:ICF
|
||||
build:rbe_win --host_linkopt=/OPT:ICF
|
||||
build:rbe_win --config=monolithic
|
||||
build:rbe_win --experimental_strict_action_env=true
|
||||
build:rbe_win --incompatible_windows_native_test_wrapper
|
||||
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
|
||||
build:rbe_win --define=override_eigen_strong_inline=true
|
||||
|
||||
|
@ -89,7 +89,7 @@ swift_rules_dependencies()
|
||||
# files, in case the parsing of those build files depends on the bazel
|
||||
# version we require here.
|
||||
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
|
||||
check_bazel_version_at_least("0.19.0")
|
||||
check_bazel_version_at_least("1.0.0")
|
||||
|
||||
load("//third_party/android:android_configure.bzl", "android_configure")
|
||||
android_configure(name="local_config_android")
|
||||
|
16
configure.py
16
configure.py
@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
|
||||
_TF_WORKSPACE_ROOT = ''
|
||||
_TF_BAZELRC = ''
|
||||
_TF_CURRENT_BAZEL_VERSION = None
|
||||
_TF_MIN_BAZEL_VERSION = '0.27.1'
|
||||
_TF_MIN_BAZEL_VERSION = '1.0.0'
|
||||
_TF_MAX_BAZEL_VERSION = '1.1.0'
|
||||
|
||||
NCCL_LIB_PATHS = [
|
||||
@ -1232,20 +1232,6 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
|
||||
|
||||
def set_windows_build_flags(environ_cp):
|
||||
"""Set Windows specific build options."""
|
||||
# The non-monolithic build is not supported yet
|
||||
write_to_bazelrc('build --config monolithic')
|
||||
# Suppress warning messages
|
||||
write_to_bazelrc('build --copt=-w --host_copt=-w')
|
||||
# Fix winsock2.h conflicts
|
||||
write_to_bazelrc(
|
||||
'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN '
|
||||
'--copt=-DNOGDI --host_copt=-DNOGDI')
|
||||
# Output more verbose information when something goes wrong
|
||||
write_to_bazelrc('build --verbose_failures')
|
||||
# The host and target platforms are the same in Windows build. So we don't
|
||||
# have to distinct them. This avoids building the same targets twice.
|
||||
write_to_bazelrc('build --distinct_host_configuration=false')
|
||||
|
||||
if is_reduced_optimize_huge_functions_available(environ_cp):
|
||||
write_to_bazelrc(
|
||||
'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions'
|
||||
|
@ -2,10 +2,7 @@
|
||||
# TensorFlow is a computational framework, primarily for use in machine
|
||||
# learning applications.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "VERSION")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
|
||||
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
"tf_additional_binary_deps",
|
||||
@ -450,6 +447,7 @@ config_setting(
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//learning/brain/swift/x10/...",
|
||||
"//perftools/accelerators/xprof/api/...",
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_estimator/python/estimator/...",
|
||||
|
@ -119,11 +119,11 @@ def _running_from_pip_package():
|
||||
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
||||
|
||||
if _running_from_pip_package():
|
||||
for s in _site_packages_dirs:
|
||||
for _s in _site_packages_dirs:
|
||||
# TODO(gunan): Add sanity checks to loaded modules here.
|
||||
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
|
||||
if _fi.file_exists(plugin_dir):
|
||||
_ll.load_library(plugin_dir)
|
||||
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||
if _fi.file_exists(_plugin_dir):
|
||||
_ll.load_library(_plugin_dir)
|
||||
|
||||
# Add module aliases
|
||||
if hasattr(_current_module, 'keras'):
|
||||
@ -136,3 +136,5 @@ if hasattr(_current_module, 'keras'):
|
||||
setattr(_current_module, "optimizers", optimizers)
|
||||
setattr(_current_module, "initializers", initializers)
|
||||
# pylint: enable=undefined-variable
|
||||
|
||||
# __all__ PLACEHOLDER
|
||||
|
@ -132,9 +132,10 @@ def _running_from_pip_package():
|
||||
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
||||
|
||||
if _running_from_pip_package():
|
||||
for s in _site_packages_dirs:
|
||||
for _s in _site_packages_dirs:
|
||||
# TODO(gunan): Add sanity checks to loaded modules here.
|
||||
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
|
||||
if _fi.file_exists(plugin_dir):
|
||||
_ll.load_library(plugin_dir)
|
||||
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||
if _fi.file_exists(_plugin_dir):
|
||||
_ll.load_library(_plugin_dir)
|
||||
|
||||
# __all__ PLACEHOLDER
|
||||
|
@ -196,6 +196,12 @@ cc_library(
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_status_headers",
|
||||
hdrs = ["tf_status.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_file_statistics",
|
||||
hdrs = ["tf_file_statistics.h"],
|
||||
|
@ -231,7 +231,14 @@ cc_library(
|
||||
srcs = ["shape_inference_helpers.cc"],
|
||||
hdrs = ["shape_inference_helpers.h"],
|
||||
visibility = [":friends"],
|
||||
deps = ["//tensorflow/core:graph"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:graph",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
# Internal targets below this point.
|
||||
|
@ -880,11 +880,8 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
||||
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
|
||||
|
||||
// Parses output_arrays_order from command line option.
|
||||
absl::flat_hash_set<std::string> output_set;
|
||||
std::vector<std::string> output_arrays_order;
|
||||
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &output_set,
|
||||
&output_arrays_order)
|
||||
.ok()) {
|
||||
std::vector<std::string> outputs;
|
||||
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) {
|
||||
return emitError(loc, "parsing output array info failed ")
|
||||
<< output_arrays_string,
|
||||
nullptr;
|
||||
@ -892,7 +889,7 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
||||
|
||||
return tflite::FlatBufferToMlir(
|
||||
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
||||
context, loc, output_arrays_order);
|
||||
context, loc, outputs);
|
||||
}
|
||||
|
||||
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
|
||||
|
@ -384,7 +384,7 @@ class Translator {
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
|
||||
|
||||
// Returns opcode index for op identified by the op_name, if already
|
||||
// available. Otherwise, creates a new OperactorCode using the given `builtin`
|
||||
// available. Otherwise, creates a new OperatorCode using the given `builtin`
|
||||
// operator and associates it with `op_name`.
|
||||
uint32_t GetOpcodeIndex(const std::string& op_name,
|
||||
tflite::BuiltinOperator builtin);
|
||||
|
@ -720,7 +720,8 @@ static LogicalResult Verify(PackOp op) {
|
||||
for (Value *operand : op.getOperands()) {
|
||||
auto other_type = operand->getType().cast<ShapedType>();
|
||||
if (input_type != other_type)
|
||||
return op.emitOpError("operands should be of the same type");
|
||||
return op.emitOpError("operands should be of the same type. got ")
|
||||
<< input_type << ", " << other_type;
|
||||
}
|
||||
|
||||
return success();
|
||||
@ -857,8 +858,8 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
//
|
||||
// => Value [5, 8, 9]
|
||||
// TODO(b/133341698): Move to tablegen when variadic is supported.
|
||||
struct RemoveRedunantUnpackPack : public RewritePattern {
|
||||
explicit RemoveRedunantUnpackPack(MLIRContext *context)
|
||||
struct RemoveRedundantUnpackPack : public RewritePattern {
|
||||
explicit RemoveRedundantUnpackPack(MLIRContext *context)
|
||||
: RewritePattern(PackOp::getOperationName(), 2, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
@ -896,7 +897,7 @@ struct RemoveRedunantUnpackPack : public RewritePattern {
|
||||
|
||||
void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<RemoveRedunantUnpackPack>(context);
|
||||
results.insert<RemoveRedundantUnpackPack>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1041,7 +1042,7 @@ struct DropFakeQuant : public RewritePattern {
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
// Replace the matched FakeQuantOp by its primiary operand.
|
||||
// Replace the matched FakeQuantOp by its primary operand.
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
}
|
||||
};
|
||||
|
@ -32,7 +32,7 @@ def TFL_Dialect : Dialect {
|
||||
Invariants:
|
||||
|
||||
* All values are of Tensor type (in particular, scalars are
|
||||
represented using zero-dimentional tensors);
|
||||
represented using zero-dimensional tensors);
|
||||
}];
|
||||
|
||||
let cppNamespace = "TFL";
|
||||
@ -603,7 +603,7 @@ def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
|
||||
def TFL_FCWO_Shuffled4x16i8 : StrEnumAttrCase<"SHUFFLED4x16INT8">;
|
||||
|
||||
def TFL_FullyConnectedOptionsWeightFormatAttr :
|
||||
StrEnumAttr<"FullyConectedOptionsWeightsFormat",
|
||||
StrEnumAttr<"FullyConnectedOptionsWeightsFormat",
|
||||
"fully connected options weights format", [
|
||||
TFL_FCWO_Default, TFL_FCWO_Shuffled4x16i8
|
||||
]>;
|
||||
@ -1873,9 +1873,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
|
||||
x -> max(0, x)
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$x);
|
||||
let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
|
||||
|
||||
let results = (outs AnyTensor:$y);
|
||||
let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
|
||||
}
|
||||
|
||||
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
|
||||
@ -1888,9 +1888,24 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
|
||||
x -> max(0, min(6, x))
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$x);
|
||||
let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
|
||||
|
||||
let results = (outs AnyTensor:$y);
|
||||
let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
|
||||
}
|
||||
|
||||
def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultsScale]> {
|
||||
let summary = "Relu1 operator";
|
||||
|
||||
let description = [{
|
||||
Element-wise Relu1 operator
|
||||
x -> max(-1, min(1, x))
|
||||
}];
|
||||
|
||||
let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
|
||||
|
||||
let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
|
||||
}
|
||||
|
||||
def TFL_ReshapeOp: TFL_Op<"reshape", [
|
||||
|
@ -237,8 +237,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
// Parse output arrays.
|
||||
std::vector<string> output_arrays(model_flags.output_arrays().begin(),
|
||||
model_flags.output_arrays().end());
|
||||
TF_RETURN_IF_ERROR(tensorflow::ParseOutputArrayInfo(
|
||||
output_arrays, &specs.output_arrays, &specs.output_arrays_order));
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
|
||||
|
||||
// Other flags.
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
|
@ -82,7 +82,7 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
|
||||
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
|
||||
}
|
||||
|
||||
// A method to retrive the name for the given op.
|
||||
// A method to retrieve the name for the given op.
|
||||
OperationToName op_to_name_;
|
||||
|
||||
// We split the normal names and regex names, since the former can use hash
|
||||
|
@ -102,7 +102,7 @@ class AffineOpCoefficient<int dim, int index> : NativeOpTrait<
|
||||
StrJoinInt<[dim, index]>.result,
|
||||
">::Impl")>;
|
||||
|
||||
// Specify this trait if the op doesn't have quantizable ouput. We shouldn't
|
||||
// Specify this trait if the op doesn't have quantizable output. We shouldn't
|
||||
// apply quantization on this op.
|
||||
def NoQuantizableResult : NativeOpTrait<"quant::NoQuantizableResult">;
|
||||
|
||||
|
@ -54,7 +54,7 @@ class SameOperandsAndResultsScale
|
||||
// OpTrait::quant::FixedResultUniformScale<
|
||||
// 8, -128, 390625, -8, 0, 255, false>::Impl> {
|
||||
//
|
||||
// TODO(fengliuai): create a better way to epxress floating point scale in the
|
||||
// TODO(fengliuai): create a better way to express floating point scale in the
|
||||
// template argument list.
|
||||
template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
|
||||
int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
|
||||
|
@ -133,10 +133,10 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
|
||||
// quantization parameters are annotated by the Q/DQ op pairs. Each
|
||||
// matched pattern are rewritten by its quantized alternatives.
|
||||
//
|
||||
// The concret pattern, extends from this base pattern, can specify whether it
|
||||
// The concrete pattern, extends from this base pattern, can specify whether it
|
||||
// allows "hybrid" operands or results. These "hybrid" operands and results
|
||||
// don't have quantization parameters propagated to, so will be in float in the
|
||||
// quantized results. The concret pattern should define the following two
|
||||
// quantized results. The concrete pattern should define the following two
|
||||
// functions:
|
||||
//
|
||||
// bool AllowHybridOperand() const
|
||||
|
@ -114,8 +114,8 @@ func @fakequant_notdropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @RemoveRedunantUnpackPack
|
||||
func @RemoveRedunantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
|
||||
// CHECK-LABEL: @RemoveRedundantUnpackPack
|
||||
func @RemoveRedundantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
|
||||
%0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
|
||||
%1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>)
|
||||
return %1: tensor<2x5xf32>
|
||||
@ -125,8 +125,8 @@ func @RemoveRedunantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @RemoveRedunantPack
|
||||
func @RemoveRedunantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) {
|
||||
// CHECK-LABEL: @RemoveRedundantPack
|
||||
func @RemoveRedundantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) {
|
||||
%0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
|
||||
%1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>)
|
||||
return %1, %0#0: tensor<2x5xf32>, tensor<5xf32>
|
||||
|
@ -106,8 +106,8 @@ func @extractStackInputOutputOphint() {
|
||||
// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
||||
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
||||
// CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
|
||||
// CHECK-DAG: %[[OUPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK-DAG: %[[OUPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK-DAG: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
// CHECK-DAG: %[[OUTPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
|
||||
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
|
||||
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||
|
@ -1,22 +1,22 @@
|
||||
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure
|
||||
|
||||
func @addRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%1 = "tf.Add"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
%3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
%4 = "tf.Add"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
%6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%7 = "tf.Relu6"(%6) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
return %7: tensor<1xi32>
|
||||
func @addRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%1 = "tf.Add"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%4 = "tf.Add"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%7 = "tf.Relu6"(%6) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
return %7: tensor<1xf32>
|
||||
|
||||
// CHECK-LABEL: addRelu
|
||||
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
|
||||
// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
|
||||
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
|
||||
// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xi32>
|
||||
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
|
||||
// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
|
||||
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
|
||||
// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xf32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
@ -244,32 +244,32 @@ func @zeros_like(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
// CHECK: "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @divRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
|
||||
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%1 = "tf.Div"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
%3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
%4 = "tf.Div"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
return %5: tensor<1xi32>
|
||||
func @divRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%1 = "tf.Div"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%4 = "tf.Div"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
return %5: tensor<1xf32>
|
||||
|
||||
// CHECK-LABEL: divRelu
|
||||
// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
|
||||
// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
|
||||
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
|
||||
// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
|
||||
// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
|
||||
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @squaredDifferenceRelu(tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> {
|
||||
^bb0(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>):
|
||||
%0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%1 = "tf.Relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
return %1: tensor<1xi32>
|
||||
func @squaredDifferenceRelu(tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> {
|
||||
^bb0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>):
|
||||
%0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%1 = "tf.Relu6"(%0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
return %1: tensor<1xf32>
|
||||
|
||||
// CHECK-LABEL: squaredDifferenceRelu
|
||||
// CHECK: tfl.squared_difference %arg0, %arg1 : tensor<1xi32>
|
||||
// CHECK: %1 = "tfl.relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: tfl.squared_difference %arg0, %arg1 : tensor<1xf32>
|
||||
// CHECK: %1 = "tfl.relu6"(%0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
|
@ -434,7 +434,7 @@ func @testEluI32(%arg0: tensor<? x i32>) -> tensor<? x i32> {
|
||||
|
||||
// -----
|
||||
|
||||
func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
// CHECK: "NONE"
|
||||
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4xi32>
|
||||
// CHECK: "RELU"
|
||||
@ -452,7 +452,7 @@ func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -
|
||||
|
||||
// -----
|
||||
|
||||
func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// expected-error @+1 {{attribute 'fused_activation_function' failed to satisfy constraint: fused activation enum}}
|
||||
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "Relu6"} : tensor<4xi32>
|
||||
return %0: tensor<4xi32>
|
||||
@ -1047,7 +1047,7 @@ func @testConcatInvalidOperandDimSize(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3x
|
||||
|
||||
// -----
|
||||
|
||||
func @testConcatInvalidOperandDimSizeComaredToPrevInput(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3xi32>) -> tensor<?x?xi32> {
|
||||
func @testConcatInvalidOperandDimSizeComparedToPrevInput(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3xi32>) -> tensor<?x?xi32> {
|
||||
// expected-error @+1 {{'tfl.concatenation' op dimension size of dimension #1 of operand #1 must be equal to dimension size of dimension #1 of operand #0, expected 2, got 3}}
|
||||
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x3xi32>) -> tensor<?x?xi32>
|
||||
return %0 : tensor<?x?xi32>
|
||||
|
@ -278,7 +278,7 @@ func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x
|
||||
%cst2 = constant dense<3.0> : tensor<112x2xf32>
|
||||
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
// We cannot fuse this tfl.mul into the preceding conv op becuase %cst2 is not broadcast-compatible to %cst0.
|
||||
// We cannot fuse this tfl.mul into the preceding conv op because %cst2 is not broadcast-compatible to %cst0.
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
|
||||
|
||||
return %1 : tensor<1x112x112x2xf32>
|
||||
@ -600,3 +600,25 @@ func @squeezeToReshape(%arg0: tensor<1x1x2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor<1x1x2xf32>, tensor<1xi32>) -> tensor<2xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Relu1
|
||||
func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%cst = constant dense<-1.0> : tensor<f32>
|
||||
%cst1 = constant dense<1.0> : tensor<f32>
|
||||
%0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||
%1 = "tfl.minimum"(%0, %cst1) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||
return %1 : tensor<2x3xf32>
|
||||
|
||||
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Relu1_2
|
||||
func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%cst = constant dense<-1.0> : tensor<f32>
|
||||
%cst1 = constant dense<1.0> : tensor<f32>
|
||||
%0 = "tfl.minimum"(%arg0, %cst1) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||
%1 = "tfl.maximum"(%0, %cst) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||
return %1 : tensor<2x3xf32>
|
||||
|
||||
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
|
||||
opt<bool> use_splatted_constant(
|
||||
"use-splatted-constant",
|
||||
llvm::cl::desc(
|
||||
"Replace constants with randonmly generated splatted tensors"),
|
||||
"Replace constants with randomly generated splatted tensors"),
|
||||
llvm::cl::init(false), llvm::cl::Hidden);
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> input_mlir(
|
||||
|
@ -282,7 +282,7 @@ struct OphintCompositeOp {
|
||||
// Since we have different aggregation strategies, e.g., "first", "last",
|
||||
// "stack". We don't somehow aggregated to get the outputs for the funcOp.
|
||||
// This function is simply compute the RankedTensorType (shape & element type)
|
||||
std::map<int, Type> GetAggregatedOuputTypes(OpBuilder* builder) {
|
||||
std::map<int, Type> GetAggregatedOutputTypes(OpBuilder* builder) {
|
||||
std::map<int, Type> aggregated_output_types;
|
||||
for (const auto& kv : outputs) {
|
||||
const AggregatedOperand& operand = kv.second;
|
||||
@ -387,11 +387,12 @@ struct OphintCompositeOp {
|
||||
// inputs/outputs indicate edges) Assume the graph is acyclic. The preprocess
|
||||
// does the following:
|
||||
// Compute each operations's in-degress (how many input nodes they're taken)
|
||||
// Get all consumer operations for every operations. (operation_to_ouputs)
|
||||
// Get all consumer operations for every operations. (operation_to_outputs)
|
||||
// Get the init_queue (those operations will be processed first).
|
||||
void PreprocessTopoSortGraph(
|
||||
Block* block, std::queue<Operation*>* init_queue,
|
||||
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>>* operation_to_ouputs,
|
||||
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>>*
|
||||
operation_to_outputs,
|
||||
llvm::DenseMap<Operation*, int>* operation_to_in_degrees) {
|
||||
for (auto& op : *block) {
|
||||
if (&op == block->getTerminator()) continue;
|
||||
@ -412,9 +413,9 @@ void PreprocessTopoSortGraph(
|
||||
}
|
||||
operation_to_in_degrees->try_emplace(&op, input_ops.size());
|
||||
for (auto* input_op : input_ops) {
|
||||
auto preceeding_op_it = operation_to_ouputs->find(input_op);
|
||||
if (preceeding_op_it == operation_to_ouputs->end()) {
|
||||
auto result = operation_to_ouputs->try_emplace(
|
||||
auto preceeding_op_it = operation_to_outputs->find(input_op);
|
||||
if (preceeding_op_it == operation_to_outputs->end()) {
|
||||
auto result = operation_to_outputs->try_emplace(
|
||||
input_op, llvm::DenseSet<Operation*>());
|
||||
preceeding_op_it = result.first;
|
||||
}
|
||||
@ -442,19 +443,19 @@ bool IsSideEffectOp(Operation* op) {
|
||||
// Also assume the block has no arguments.
|
||||
LogicalResult TopoSortOperations(OpBuilder* builder) {
|
||||
std::queue<Operation*> init_queue;
|
||||
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>> operation_to_ouputs;
|
||||
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>> operation_to_outputs;
|
||||
llvm::DenseMap<Operation*, int> operation_to_in_degrees;
|
||||
std::vector<Operation*> sorted_ops;
|
||||
|
||||
PreprocessTopoSortGraph(builder->getBlock(), &init_queue,
|
||||
&operation_to_ouputs, &operation_to_in_degrees);
|
||||
&operation_to_outputs, &operation_to_in_degrees);
|
||||
while (!init_queue.empty()) {
|
||||
Operation* current_op = init_queue.front();
|
||||
init_queue.pop();
|
||||
sorted_ops.push_back(current_op);
|
||||
|
||||
auto current_op_to_output_it = operation_to_ouputs.find(current_op);
|
||||
if (current_op_to_output_it == operation_to_ouputs.end()) {
|
||||
auto current_op_to_output_it = operation_to_outputs.find(current_op);
|
||||
if (current_op_to_output_it == operation_to_outputs.end()) {
|
||||
continue;
|
||||
}
|
||||
for (Operation* output_op : current_op_to_output_it->second) {
|
||||
@ -467,7 +468,7 @@ LogicalResult TopoSortOperations(OpBuilder* builder) {
|
||||
operation_to_in_degrees.erase(output_op_it);
|
||||
}
|
||||
}
|
||||
operation_to_ouputs.erase(current_op_to_output_it);
|
||||
operation_to_outputs.erase(current_op_to_output_it);
|
||||
}
|
||||
|
||||
// Before we performs the sort. We need to make sure we didn't mess the
|
||||
@ -629,11 +630,11 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
|
||||
|
||||
// Step 4, get aggregated output types.
|
||||
const std::map<int, Type>& aggregated_output_types =
|
||||
ophint_composite_op.GetAggregatedOuputTypes(builder);
|
||||
ophint_composite_op.GetAggregatedOutputTypes(builder);
|
||||
|
||||
// Step 5, create & place the fused op and rewire the inputs.
|
||||
// Here we use a funcOp to represent the fused op. This "funcOp" will be
|
||||
// coonverted to other ops (like UnidirectionalSequenceRNNOp) in the
|
||||
// converted to other ops (like UnidirectionalSequenceRNNOp) in the
|
||||
// legalization phase.
|
||||
Operation* inserted_before_op = ophint_composite_op.GetFirstOutputOp();
|
||||
Operation* fused_op = BuildFusedFuncOp(
|
||||
|
@ -191,10 +191,10 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTfLiteFusedOpIfAvaiable(StringRef func_name,
|
||||
FuncOp composite_func_op,
|
||||
CallOp call_op,
|
||||
OpBuilder* builder) {
|
||||
LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name,
|
||||
FuncOp composite_func_op,
|
||||
CallOp call_op,
|
||||
OpBuilder* builder) {
|
||||
Operation* fused_op = nullptr;
|
||||
if (func_name == kUnidirectionalSequenceRnn) {
|
||||
// TODO(renjieliu): Validate the func op inputs.
|
||||
@ -243,8 +243,8 @@ LogicalResult ConvertCallOps(llvm::StringMap<FuncOp>* composite_func_ops,
|
||||
StringRef func_name = composite_func_op.getAttr(kTfLiteFunctionName)
|
||||
.cast<StringAttr>()
|
||||
.getValue();
|
||||
if (failed(ConvertTfLiteFusedOpIfAvaiable(func_name, composite_func_op,
|
||||
call_op, &builder)))
|
||||
if (failed(ConvertTfLiteFusedOpIfAvailable(func_name, composite_func_op,
|
||||
call_op, &builder)))
|
||||
return failure();
|
||||
|
||||
composite_func_ops->erase(it);
|
||||
|
@ -140,6 +140,8 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
|
||||
return ExpandTo4DForConvImpl(a, true);
|
||||
}
|
||||
|
||||
// Returns shape of a ranked tensor.
|
||||
// Precondition: output_val's is ranked tensor.
|
||||
DenseElementsAttr GetShape(Value *output_val) {
|
||||
auto output_type = output_val->getType().cast<RankedTensorType>();
|
||||
auto shape_vector = output_type.getShape();
|
||||
|
@ -267,9 +267,27 @@ multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
|
||||
foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
|
||||
in defm : FuseTileBroadcastIntoFollowingBinary<BroadcastingOp>;
|
||||
|
||||
// Returns shape of a ranked tensor.
|
||||
// if called without a ranked tensor it will fail.
|
||||
def GetShape: NativeCodeCall<"GetShape($0)">;
|
||||
|
||||
def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
|
||||
(TFL_ReshapeOp $input,
|
||||
(ConstantOp (GetShape $squeeze_op))),
|
||||
[(AnyStaticShapeTensor $squeeze_op)]>;
|
||||
|
||||
class ValueEquals<string val> : Constraint<CPred<
|
||||
"$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
|
||||
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
|
||||
|
||||
def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
|
||||
(ConstantOp $NegOne)),
|
||||
(ConstantOp $One)),
|
||||
(TFL_Relu1Op $input),
|
||||
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
|
||||
|
||||
def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
|
||||
(ConstantOp $One)),
|
||||
(ConstantOp $NegOne)),
|
||||
(TFL_Relu1Op $input),
|
||||
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
|
||||
|
@ -68,11 +68,11 @@ class ConvertEmbeddedLookupFunc {
|
||||
if (func_.getNumArguments() != 2) {
|
||||
return func_.emitError()
|
||||
<< "Invalid number of arguments in the embedding "
|
||||
"matmal composite function";
|
||||
"matmul composite function";
|
||||
}
|
||||
if (func_.getType().getNumResults() != 1) {
|
||||
return func_.emitError() << "Invalid number of results in the embedding "
|
||||
"matmal composite function";
|
||||
"matmul composite function";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ limitations under the License.
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::list<std::string> quantize_whitelist(
|
||||
"tfl-test-quantize-whitelist", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma seprarated list of whitelisted functions to be "
|
||||
llvm::cl::desc("comma separated list of whitelisted functions to be "
|
||||
"quantized. Only used in tests"),
|
||||
llvm::cl::CommaSeparated);
|
||||
|
||||
|
@ -400,7 +400,7 @@ class ConvertTFDepthwiseConv2dNative
|
||||
}
|
||||
};
|
||||
|
||||
// StridedSlice can have complicated atributes like begin_axis_mask,
|
||||
// StridedSlice can have complicated attributes like begin_axis_mask,
|
||||
// end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
|
||||
// masks will complicate the strided_slice computation logic, we can simplify
|
||||
// the logic by inserting a reshape op to pad the inputs so strided_slice can
|
||||
|
@ -247,7 +247,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
|
||||
}
|
||||
|
||||
if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
|
||||
// Input dimensions must be compatible for multipication.
|
||||
// Input dimensions must be compatible for multiplication.
|
||||
return this->matchFailure();
|
||||
}
|
||||
|
||||
|
@ -126,7 +126,7 @@ class ConvertLSTMCellSimpleToFusedLSTM {
|
||||
Value* input2cell_;
|
||||
Value* input2output_;
|
||||
|
||||
// reccurrent -> cifg
|
||||
// recurrent -> cifg
|
||||
Value* rec2input_;
|
||||
Value* rec2forget_;
|
||||
Value* rec2cell_;
|
||||
|
@ -28,7 +28,7 @@ config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm')
|
||||
config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR'])
|
||||
config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'],
|
||||
'local_config_mlir')
|
||||
# TODO(jpienaar): Replace with sufffices in build rule.
|
||||
# TODO(jpienaar): Replace with suffices in build rule.
|
||||
config.suffixes = ['.td', '.mlir', '.pbtxt']
|
||||
|
||||
mlir_tf_tools_dirs = [
|
||||
|
@ -217,6 +217,7 @@ cc_library(
|
||||
"transforms/shape_inference.cc",
|
||||
"transforms/shape_inference_pass.cc",
|
||||
"transforms/sink_constant.cc",
|
||||
"transforms/test_side_effect_analysis.cc",
|
||||
"transforms/tpu_cluster_formation.cc",
|
||||
"transforms/tpu_merge_variables_with_execute.cc",
|
||||
"transforms/tpu_rewrite_pass.cc",
|
||||
@ -239,6 +240,7 @@ cc_library(
|
||||
":error_util",
|
||||
":export_tf_dialect_op",
|
||||
":mangling_util",
|
||||
":side_effect_analysis",
|
||||
":tensorflow",
|
||||
":tensorflow_optimize_inc_gen",
|
||||
":tpu_rewrite_device_util",
|
||||
@ -467,6 +469,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:types",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
@ -669,6 +672,7 @@ cc_library(
|
||||
":import_utils",
|
||||
":mangling_util",
|
||||
":mlir_roundtrip_flags",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -981,3 +985,19 @@ cc_library(
|
||||
"@local_config_mlir//:Pass",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "side_effect_analysis",
|
||||
srcs = ["analysis/side_effect_analysis.cc"],
|
||||
hdrs = ["analysis/side_effect_analysis.h"],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/core:framework",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,374 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <initializer_list>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Block.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr int64_t kUnknownResourceId = -1;
|
||||
|
||||
// Returns if a VarHandleOp is anonymous, which means it always creates a new
|
||||
// variable.
|
||||
bool IsResourceHandleAnonymous(TF::VarHandleOp handle) {
|
||||
return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME;
|
||||
}
|
||||
|
||||
// Returns a string unique identifier for a non-anonymous VarHandleOp.
|
||||
std::string GetVarHandleStringId(TF::VarHandleOp handle) {
|
||||
auto device = handle.getAttrOfType<StringAttr>("device");
|
||||
return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(),
|
||||
"/", device ? device.getValue().str() : std::string(""));
|
||||
}
|
||||
|
||||
// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always
|
||||
// creates a new ID; otherwise, tries to reuse the existing ID for the
|
||||
// referenced variable if it exists, or creates a new one if not.
|
||||
int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id,
|
||||
llvm::StringMap<int64_t>* name_id_map) {
|
||||
// Always create a new ID for anonymous handle.
|
||||
if (IsResourceHandleAnonymous(handle)) return (*next_id)++;
|
||||
|
||||
auto name = GetVarHandleStringId(handle);
|
||||
auto emplace_res = name_id_map->try_emplace(name, *next_id);
|
||||
// New ID created, increment next_id.
|
||||
if (emplace_res.second) ++(*next_id);
|
||||
return emplace_res.first->second;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
|
||||
auto func_op = llvm::dyn_cast<FuncOp>(op);
|
||||
if (!func_op) return;
|
||||
AnalyzeFunction(func_op);
|
||||
}
|
||||
|
||||
void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
||||
// This function populates resource_value_to_ids_.
|
||||
//
|
||||
// TODO(yuanzx): Pass variable aliasing information to functions so we can
|
||||
// properly resolve aliasing arguments.
|
||||
//
|
||||
// Before having that, we assume function arguments do not alias each other.
|
||||
int64_t next_unique_id = 0;
|
||||
for (auto arg : func_op.getArguments()) {
|
||||
if (!mlir::getElementTypeOrSelf(arg->getType()).isa<TF::ResourceType>())
|
||||
continue;
|
||||
resource_value_to_ids_[arg].insert(next_unique_id++);
|
||||
}
|
||||
llvm::StringMap<int64_t> var_handle_name_id_map;
|
||||
auto forward_input_to_output = [&](Value* operand, Value* result) {
|
||||
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
|
||||
return;
|
||||
auto operand_it = resource_value_to_ids_.find(operand);
|
||||
assert(operand_it != resource_value_to_ids_.end() &&
|
||||
"A resource-type output does not have the corresponding "
|
||||
"resource-type input.");
|
||||
resource_value_to_ids_[result].insert(operand_it->getSecond().begin(),
|
||||
operand_it->getSecond().end());
|
||||
};
|
||||
// TODO(yuanzx): Consider control-flow ops.
|
||||
func_op.walk([&](Operation* op) {
|
||||
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
|
||||
resource_value_to_ids_[var_handle.resource()].insert(
|
||||
GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
|
||||
&var_handle_name_id_map));
|
||||
} else if (llvm::isa<TF::IdentityNOp>(op) ||
|
||||
llvm::isa<TF::IdentityOp>(op)) {
|
||||
for (auto operand_and_result :
|
||||
llvm::zip(op->getOperands(), op->getResults())) {
|
||||
forward_input_to_output(std::get<0>(operand_and_result),
|
||||
std::get<1>(operand_and_result));
|
||||
}
|
||||
} else {
|
||||
for (auto result : op->getResults()) {
|
||||
if (!mlir::getElementTypeOrSelf(result->getType())
|
||||
.isa<TF::ResourceType>())
|
||||
continue;
|
||||
resource_value_to_ids_[result].insert(kUnknownResourceId);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
bool ResourceAliasAnalysis::IsUnknownResource(const Value* resource) const {
|
||||
auto it = resource_value_to_ids_.find(resource);
|
||||
assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
|
||||
// The set is sorted so we only need to check the first element since
|
||||
// kUnknownResourceId < 0.
|
||||
static_assert(kUnknownResourceId < 0,
|
||||
"kUnknownResourceId should be negative");
|
||||
return *it->getSecond().begin() == kUnknownResourceId;
|
||||
}
|
||||
|
||||
const llvm::SmallSet<int64_t, 8>& ResourceAliasAnalysis::GetResourceUniqueIds(
|
||||
const Value* resource) const {
|
||||
auto it = resource_value_to_ids_.find(resource);
|
||||
assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
|
||||
return it->getSecond();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns a set that contains only kUnknownResourceId.
|
||||
llvm::SmallDenseSet<int64_t, 8> UnknownResourceSet() {
|
||||
llvm::SmallDenseSet<int64_t, 8> unknown_set;
|
||||
unknown_set.insert(kUnknownResourceId);
|
||||
return unknown_set;
|
||||
}
|
||||
|
||||
// Returns all resources that could be accessed by op, or UnknownResourceSet()
|
||||
// if we cannot find all of them.
|
||||
llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
|
||||
Operation* op, const ResourceAliasAnalysis& alias_analysis) {
|
||||
llvm::SmallDenseSet<int64_t, 8> resources;
|
||||
|
||||
for (auto operand : op->getOperands()) {
|
||||
if (!mlir::getElementTypeOrSelf(operand->getType()).isa<TF::ResourceType>())
|
||||
continue;
|
||||
if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet();
|
||||
const auto& ids = alias_analysis.GetResourceUniqueIds(operand);
|
||||
resources.insert(ids.begin(), ids.end());
|
||||
}
|
||||
for (auto result : op->getResults()) {
|
||||
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
|
||||
continue;
|
||||
if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet();
|
||||
const auto& ids = alias_analysis.GetResourceUniqueIds(result);
|
||||
resources.insert(ids.begin(), ids.end());
|
||||
}
|
||||
return resources;
|
||||
}
|
||||
|
||||
// Returns an XlaResourceOpInfo (or nullptr if it does not exist) that specifies
|
||||
// the resource access type of the op. It tells whether the op is read only,
|
||||
// etc.
|
||||
//
|
||||
// TODO(yuanzx): Define this information in a different place. Currently we use
|
||||
// tensorflow/compiler/tf2xla/resource_operation_table.h.
|
||||
const tensorflow::XlaResourceOpInfo* GetResourceInfoForOp(Operation* op) {
|
||||
auto op_name = op->getName().getStringRef().str();
|
||||
if (op->getName().getDialect() !=
|
||||
TF::TensorFlowDialect::getDialectNamespace()) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::GetResourceOpInfoForOp(
|
||||
op->getName().getStringRef().split('.').second.str());
|
||||
}
|
||||
|
||||
// Returns whether `op` accesses resources and it is known to be read-only.
|
||||
bool OpIsReadOnly(Operation* op) {
|
||||
auto resource_op_info = GetResourceInfoForOp(op);
|
||||
return resource_op_info &&
|
||||
resource_op_info->kind() == tensorflow::XlaResourceOpKind::kRead;
|
||||
}
|
||||
|
||||
// Returns if `op` is a resource declaration.
|
||||
bool OpIsDeclaration(Operation* op,
|
||||
const ResourceAliasAnalysis& alias_analysis) {
|
||||
// TODO(yuanzx): Add other types of resources.
|
||||
return llvm::isa<TF::VarHandleOp>(op) ||
|
||||
((llvm::isa<TF::IdentityNOp>(op) || llvm::isa<TF::IdentityOp>(op)) &&
|
||||
!FindAccessedResources(op, alias_analysis).empty());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op,
|
||||
bool read_only) {
|
||||
if (resource_id == kUnknownResourceId) {
|
||||
if (read_only) {
|
||||
// New unknown read is not tracked by any known resource access.
|
||||
for (auto& entry : per_resource_access_info_) {
|
||||
entry.getSecond().tracked_last_unknown_read = false;
|
||||
}
|
||||
} else {
|
||||
// Unknown write can clear all other tracked information, since it acts
|
||||
// like a barrier.
|
||||
per_resource_access_info_.clear();
|
||||
}
|
||||
}
|
||||
auto& info = per_resource_access_info_[resource_id];
|
||||
if (read_only) {
|
||||
info.reads_since_last_write.push_back(op);
|
||||
// Resource read must have carried control dependencies of unknown write.
|
||||
info.tracked_last_unknown_write = true;
|
||||
} else {
|
||||
// Resource write must have carried control dependencies of unknown access.
|
||||
info.tracked_last_unknown_write = true;
|
||||
info.tracked_last_unknown_read = true;
|
||||
info.last_write = op;
|
||||
info.reads_since_last_write.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id,
|
||||
Operation* op,
|
||||
bool read_only) {
|
||||
auto it = per_resource_access_info_.find(resource_id);
|
||||
if (it == per_resource_access_info_.end()) return;
|
||||
const auto& access_info = it->getSecond();
|
||||
auto& control_predecessors = control_predecessors_[op];
|
||||
bool read_tracked = false;
|
||||
if (!read_only) {
|
||||
control_predecessors.insert(access_info.reads_since_last_write.begin(),
|
||||
access_info.reads_since_last_write.end());
|
||||
read_tracked = !access_info.reads_since_last_write.empty();
|
||||
}
|
||||
if (access_info.last_write && !read_tracked) {
|
||||
control_predecessors.insert(access_info.last_write);
|
||||
}
|
||||
}
|
||||
|
||||
void SideEffectAnalysis::AnalyzeFunction(
|
||||
FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) {
|
||||
// This function populates control_predecessors_ and control_successors_ by
|
||||
// walking through func_op's body, and tracking resource accesses in
|
||||
// per_resource_access_info_.
|
||||
|
||||
// Returns whether an access to `resource` can skip control edges from
|
||||
// prevoius accesses to unknown resources, due to that earlier accesses to
|
||||
// `resource` already indirectly tracked previous accesses to uknown
|
||||
// resources. `read_only` specifies the type of access of the current op being
|
||||
// considered.
|
||||
auto unknown_access_indirectly_tracked_by_resource = [&](int64_t resource,
|
||||
bool read_only) {
|
||||
auto it = per_resource_access_info_.find(resource);
|
||||
if (it == per_resource_access_info_.end()) return false;
|
||||
auto unknown_it = per_resource_access_info_.find(kUnknownResourceId);
|
||||
const bool no_unknown_read =
|
||||
unknown_it == per_resource_access_info_.end() ||
|
||||
unknown_it->getSecond().reads_since_last_write.empty();
|
||||
return read_only
|
||||
? it->second.tracked_last_unknown_write
|
||||
: it->second.tracked_last_unknown_write &&
|
||||
(it->second.tracked_last_unknown_read || no_unknown_read);
|
||||
};
|
||||
|
||||
func_op.walk([&](Operation* op) {
|
||||
// We do not need explicit control edges for declaration ops.
|
||||
if (OpIsDeclaration(op, alias_analysis)) return;
|
||||
|
||||
auto resource_op_info = GetResourceInfoForOp(op);
|
||||
if (!resource_op_info && op->hasNoSideEffect()) return;
|
||||
|
||||
llvm::SmallDenseSet<int64_t, 8> resources =
|
||||
resource_op_info ? FindAccessedResources(op, alias_analysis)
|
||||
: UnknownResourceSet();
|
||||
assert(!resources.empty());
|
||||
const bool is_unknown = resources.count(kUnknownResourceId) > 0;
|
||||
const bool read_only = OpIsReadOnly(op);
|
||||
bool indirectly_tracked_unknown_access = false;
|
||||
// First add edges from known resources.
|
||||
if (is_unknown) {
|
||||
for (auto& entry : per_resource_access_info_) {
|
||||
if (entry.getFirst() == kUnknownResourceId) continue;
|
||||
AddPredecessorsForAccess(entry.getFirst(), op, read_only);
|
||||
indirectly_tracked_unknown_access |=
|
||||
unknown_access_indirectly_tracked_by_resource(entry.getFirst(),
|
||||
read_only);
|
||||
}
|
||||
} else {
|
||||
for (int64_t resource : resources) {
|
||||
AddPredecessorsForAccess(resource, op, read_only);
|
||||
indirectly_tracked_unknown_access |=
|
||||
unknown_access_indirectly_tracked_by_resource(resource, read_only);
|
||||
// Update access info for known resources.
|
||||
TrackAccess(resource, op, read_only);
|
||||
}
|
||||
}
|
||||
// If not indirectly tracked, add edges from the unknown resource.
|
||||
if (!indirectly_tracked_unknown_access) {
|
||||
AddPredecessorsForAccess(kUnknownResourceId, op, read_only);
|
||||
}
|
||||
if (is_unknown) {
|
||||
// Update access info for unknown resource.
|
||||
TrackAccess(kUnknownResourceId, op, read_only);
|
||||
}
|
||||
});
|
||||
|
||||
// Populate control_successors_ based on control_predecessors_.
|
||||
for (auto& entry : control_predecessors_) {
|
||||
auto op = entry.getFirst();
|
||||
for (auto predecessor : entry.getSecond()) {
|
||||
control_successors_[predecessor].insert(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SmallVector<Operation*, 8> SideEffectAnalysis::DirectControlPredecessors(
|
||||
Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
|
||||
llvm::SmallVector<Operation*, 8> result;
|
||||
auto it = control_predecessors_.find(op);
|
||||
if (it == control_predecessors_.end()) return result;
|
||||
result.reserve(it->getSecond().size());
|
||||
for (auto predecessor : it->getSecond()) {
|
||||
if (!filter || filter(predecessor)) result.push_back(predecessor);
|
||||
}
|
||||
llvm::sort(result,
|
||||
[](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
|
||||
return result;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Operation*, 8> SideEffectAnalysis::DirectControlSuccessors(
|
||||
Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
|
||||
llvm::SmallVector<Operation*, 8> result;
|
||||
auto it = control_successors_.find(op);
|
||||
if (it == control_successors_.end()) return result;
|
||||
result.reserve(it->getSecond().size());
|
||||
for (auto successor : it->getSecond()) {
|
||||
if (!filter || filter(successor)) result.push_back(successor);
|
||||
}
|
||||
llvm::sort(result,
|
||||
[](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
|
||||
return result;
|
||||
}
|
||||
|
||||
SideEffectAnalysis::SideEffectAnalysis(Operation* op) {
|
||||
auto func_op = llvm::dyn_cast<FuncOp>(op);
|
||||
if (!func_op) return;
|
||||
ResourceAliasAnalysis alias_analysis(op);
|
||||
AnalyzeFunction(func_op, alias_analysis);
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
@ -0,0 +1,125 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Region.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
// An analysis that runs on a function and maps each resource-type value to a
|
||||
// set of unique int64_t IDs representing the possible resources it could alias.
|
||||
class ResourceAliasAnalysis {
|
||||
public:
|
||||
explicit ResourceAliasAnalysis(Operation* op);
|
||||
~ResourceAliasAnalysis() = default;
|
||||
ResourceAliasAnalysis(ResourceAliasAnalysis&&) = default;
|
||||
|
||||
// Returns if the analysis fails to resolve a resource-type value.
|
||||
bool IsUnknownResource(const Value* resource) const;
|
||||
|
||||
// Returns the set unique IDs which `resource` could alias. Requires that
|
||||
// IsUnknownResource(resource) == true.
|
||||
const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(
|
||||
const Value* resource) const;
|
||||
|
||||
private:
|
||||
ResourceAliasAnalysis() = default;
|
||||
|
||||
// Runs the analysis on `func_op` and populates resource_value_to_ids_.
|
||||
void AnalyzeFunction(FuncOp func_op);
|
||||
|
||||
// Maps each resource-type value to a set of unique IDs that it could alias.
|
||||
llvm::SmallDenseMap<const Value*, llvm::SmallSet<int64_t, 8>, 8>
|
||||
resource_value_to_ids_;
|
||||
};
|
||||
|
||||
// An analysis that runs on a function and infers the control predecessors and
|
||||
// successors for each op, based on side-effects on known and unknown resources.
|
||||
// Side-effecting ops on uknown resources are conservatively treated as
|
||||
// interfering with all known resource op accesses. It distinguishes accesses
|
||||
// based on whether they are read-only, and read-only ops do not interfer with
|
||||
// each other.
|
||||
class SideEffectAnalysis {
|
||||
public:
|
||||
explicit SideEffectAnalysis(Operation* op);
|
||||
SideEffectAnalysis(SideEffectAnalysis&& other) = default;
|
||||
~SideEffectAnalysis() = default;
|
||||
|
||||
// Returns a vector of ops that are direct control predecessors of `op`,
|
||||
// sorted in program order. If `filter` is provided, only predecessors that
|
||||
// pass the filter (returning true) will be included.
|
||||
llvm::SmallVector<Operation*, 8> DirectControlPredecessors(
|
||||
Operation* op,
|
||||
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
|
||||
|
||||
// Returns a vector of ops that are direct control successors of `op`, sorted
|
||||
// in program order. If `filter` is provided, only successors that pass the
|
||||
// filter (returning true) will be included.
|
||||
llvm::SmallVector<Operation*, 8> DirectControlSuccessors(
|
||||
Operation* op,
|
||||
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
|
||||
|
||||
private:
|
||||
// Runs the analysis on `func_op` and populates control_predecessors_ and
|
||||
// control_successors_.
|
||||
void AnalyzeFunction(FuncOp func_op,
|
||||
const ResourceAliasAnalysis& alias_analysis);
|
||||
|
||||
// Updates control_predecessors_ for `op` that is being visted, on the given
|
||||
// `resource_id`.
|
||||
void AddPredecessorsForAccess(int64_t resource_id, Operation* op,
|
||||
bool read_only);
|
||||
|
||||
// Adds op's access to per_resource_access_info_.
|
||||
void TrackAccess(int64_t resource_id, Operation* op, bool read_only);
|
||||
|
||||
// Maps from an op to its control predecessors.
|
||||
llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>
|
||||
control_predecessors_;
|
||||
// Maps from an op to its control successors.
|
||||
llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>
|
||||
control_successors_;
|
||||
|
||||
// Internal per-resource data structure when we build the dependencies.
|
||||
struct PerResourceAcessInfo {
|
||||
// Last op that writes the resource before the current op being analyzed.
|
||||
Operation* last_write = nullptr;
|
||||
// Read ops since last_write before the current op being analyzed.
|
||||
llvm::SmallVector<Operation*, 8> reads_since_last_write;
|
||||
// Whether previous accesses of this resource already tracked last unknown
|
||||
// read/write.
|
||||
bool tracked_last_unknown_read = false;
|
||||
bool tracked_last_unknown_write = false;
|
||||
};
|
||||
llvm::SmallDenseMap<int64_t, PerResourceAcessInfo, 8>
|
||||
per_resource_access_info_;
|
||||
};
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
|
@ -26,7 +26,7 @@ static DialectRegistration<TFControlFlow::TFControlFlowDialect>
|
||||
tf_control_flow_ops;
|
||||
static DialectRegistration<TF::TensorFlowDialect> tf_ops;
|
||||
static DialectRegistration<tf_executor::TensorFlowExecutorDialect>
|
||||
tf_excutor_dialect;
|
||||
tf_executor_dialect;
|
||||
static DialectRegistration<tf_device::TensorFlowDeviceDialect>
|
||||
tf_device_dialect;
|
||||
static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect>
|
||||
|
@ -26,11 +26,13 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
|
@ -576,6 +576,44 @@ endian orderings will give different results.
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_BitwiseOrOp : TF_Op<"BitwiseOr", [Broadcastable, Commutative, NoSideEffect]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Elementwise computes the bitwise OR of `x` and `y`.";
|
||||
|
||||
let description = [{
|
||||
The result will have those bits set, that are set in `x`, `y` or both. The
|
||||
computation is performed on the underlying representations of `x` and `y`.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.ops import bitwise_ops
|
||||
dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64,
|
||||
tf.uint8, tf.uint16, tf.uint32, tf.uint64]
|
||||
|
||||
for dtype in dtype_list:
|
||||
lhs = tf.constant([0, 5, 3, 14], dtype=dtype)
|
||||
rhs = tf.constant([5, 0, 7, 11], dtype=dtype)
|
||||
exp = tf.constant([5, 5, 7, 15], dtype=tf.float32)
|
||||
|
||||
res = bitwise_ops.bitwise_or(lhs, rhs)
|
||||
tf.assert_equal(tf.cast(res, tf.float32), exp) # TRUE
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_IntTensor:$x,
|
||||
TF_IntTensor:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_IntTensor:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Return the reduction indices for computing gradients of s0 op s1 with broadcast.
|
||||
@ -1320,6 +1358,10 @@ Comparison with `numpy.einsum`:
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_EluOp : TF_Op<"Elu", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
@ -5726,6 +5768,8 @@ If two elements are equal, the lower-index element appears first.
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
|
||||
|
@ -39,7 +39,7 @@ This dialect maps to TensorFlow operations.
|
||||
Invariants:
|
||||
|
||||
* All values are of Tensor type (in particular, scalars are
|
||||
represented using zero-dimentional tensors);
|
||||
represented using zero-dimensional tensors);
|
||||
|
||||
TODO: Make invariants more structured so that we can reference them in ops.
|
||||
}];
|
||||
|
@ -513,7 +513,7 @@ void ConstOp::build(Builder *builder, OperationState &result, Attribute value) {
|
||||
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
|
||||
value.isa<IntegerAttr>()) {
|
||||
// All TensorFlow types must be tensor types. In the build() method,
|
||||
// we want to provide more flexiblity by allowing attributes of scalar
|
||||
// we want to provide more flexibility by allowing attributes of scalar
|
||||
// types. But we need to wrap it up with ElementsAttr to construct
|
||||
// valid TensorFlow constants.
|
||||
type = RankedTensorType::get(/*shape=*/{}, value.getType());
|
||||
@ -674,6 +674,21 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<DivWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EinsumOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies that,
|
||||
// * Arity of the op is at most two.
|
||||
//
|
||||
// TODO(hinsu): Verify einsum equation attribute.
|
||||
static LogicalResult Verify(EinsumOp op) {
|
||||
if (op.N() > 2) {
|
||||
return op.emitOpError("supports at most two operands");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EmptyTensorListOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1683,6 +1698,21 @@ static LogicalResult Verify(TensorListStackOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TopKV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(TopKV2Op op) {
|
||||
if (!HasRankAtLeast(op.input(), 1))
|
||||
return op.emitOpError(
|
||||
"requires input operand to have at least 1 dimension");
|
||||
|
||||
if (!IsOfRankOrUnranked(op.k(), 0))
|
||||
return op.emitOpError("requires k operand to be 0D tensor");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransposeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -41,7 +41,7 @@ class TensorFlowDialect : public Dialect {
|
||||
|
||||
static StringRef getDialectNamespace() { return "tf"; }
|
||||
|
||||
// Gradient attribute ("tf.gradient") in the list of NamedAttibutes in a
|
||||
// Gradient attribute ("tf.gradient") in the list of NamedAttributes in a
|
||||
// function references to its gradient function. This attribute in TensorFlow
|
||||
// Dialect is used to model TF GradientDef. GetGradientAttrName() returns the
|
||||
// string description of gradient attribute.
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
|
@ -101,6 +101,17 @@ func @testDifferentCastType(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64
|
||||
// CHECK: return %0, %1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testCompatibleCastType
|
||||
func @testCompatibleCastType(%arg0: tensor<?xf32>) -> (tensor<10xf32>, tensor<10xf32>) {
|
||||
%0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<?xf32>) -> tensor<10xf32>
|
||||
%1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<?xf32>) -> tensor<10xf32>
|
||||
return %0, %1: tensor<10xf32>, tensor<10xf32>
|
||||
|
||||
// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<?xf32>) -> tensor<10xf32>
|
||||
// CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<?xf32>) -> tensor<10xf32>
|
||||
// CHECK: return %0, %1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSameCastTypeAcrossBasicBlocks
|
||||
func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> {
|
||||
^bb0(%arg0: tensor<8x16x32x64xf32>):
|
||||
|
@ -232,12 +232,12 @@ module {
|
||||
|
||||
// -----
|
||||
|
||||
// Single device with non-continous instructions in original block.
|
||||
// Single device with non-continuous instructions in original block.
|
||||
|
||||
module {
|
||||
// CHECK-LABEL: func @noncontinoussinglecluster
|
||||
// CHECK-LABEL: func @noncontinuoussinglecluster
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
|
||||
func @noncontinoussinglecluster(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
func @noncontinuoussinglecluster(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = tf_executor.graph {
|
||||
%1:2 = tf_executor.island {
|
||||
|
||||
|
@ -0,0 +1,237 @@
|
||||
// RUN: tf-opt -split-input-file -tf-test-side-effect-analysis -verify-diagnostics %s | FileCheck %s --dump-input=fail
|
||||
|
||||
// Tests that the pass tracks control dependencies for reads/writes on the same
|
||||
// resource.
|
||||
|
||||
// CHECK-LABEL: func @non_aliasing_reads_writes
|
||||
func @non_aliasing_reads_writes(
|
||||
// expected-remark@above {{ID: 13}}
|
||||
// expected-remark@above {{Predecessors: {12}}}
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg2: tensor<32xf32>) -> (tensor<32xf32>) {
|
||||
%graph = tf_executor.graph {
|
||||
// expected-remark@above {{ID: 11}}
|
||||
// expected-remark@above {{Predecessors: {10}}}
|
||||
// expected-remark@above {{Successors: {12}}}
|
||||
// CHECK: tf_executor.island
|
||||
%island:2 = tf_executor.island {
|
||||
// expected-remark@above {{ID: 9}}
|
||||
// expected-remark@above {{Predecessors: {8}}}
|
||||
// expected-remark@above {{Successors: {10}}}
|
||||
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 0}}
|
||||
// expected-remark@above {{Successors: {1}}}
|
||||
"tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 1}}
|
||||
// expected-remark@above {{Predecessors: {0}}}
|
||||
// expected-remark@above {{Successors: {6}}}
|
||||
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Successors: {5}}}
|
||||
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// expected-remark@above {{ID: 3}}
|
||||
%read2 = "tf.ReadVariableOp"(%var_handle) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Successors: {8}}}
|
||||
"tf.AssignVariableOp"(%arg1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Predecessors: {2}}}
|
||||
// expected-remark@above {{Successors: {8}}}
|
||||
"tf.AssignVariableOp"(%arg0, %read2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
%read3 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 7}}
|
||||
// expected-remark@above {{Predecessors: {6}}}
|
||||
// expected-remark@above {{Successors: {8}}}
|
||||
tf_executor.yield %read3 : tensor<32xf32>
|
||||
// expected-remark@above {{ID: 8}}
|
||||
// expected-remark@above {{Predecessors: {4,5,7}}}
|
||||
// expected-remark@above {{Successors: {9}}}
|
||||
}
|
||||
tf_executor.fetch %island#0 : tensor<32xf32>
|
||||
// expected-remark@above {{ID: 10}}
|
||||
// expected-remark@above {{Predecessors: {9}}}
|
||||
// expected-remark@above {{Successors: {11}}}
|
||||
}
|
||||
return %graph : tensor<32xf32>
|
||||
// expected-remark@above {{ID: 12}}
|
||||
// expected-remark@above {{Predecessors: {11}}}
|
||||
// expected-remark@above {{Successors: {13}}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass tracks control dependencies for reads/writes on the two
|
||||
// resource handles that refer to the same variable.
|
||||
|
||||
// CHECK-LABEL: func @aliasing_reads_writes
|
||||
func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () {
|
||||
// expected-remark@above {{ID: 14}}
|
||||
// expected-remark@above {{Predecessors: {13}}}
|
||||
tf_executor.graph {
|
||||
// expected-remark@above {{ID: 12}}
|
||||
// expected-remark@above {{Predecessors: {11}}}
|
||||
// expected-remark@above {{Successors: {13}}}
|
||||
// CHECK: tf_executor.island
|
||||
%island = tf_executor.island {
|
||||
// expected-remark@above {{ID: 10}}
|
||||
// expected-remark@above {{Predecessors: {9}}}
|
||||
// expected-remark@above {{Successors: {11}}}
|
||||
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// expected-remark@above {{ID: 0}}
|
||||
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// expected-remark@above {{ID: 1}}
|
||||
%vh1_id:2 = "tf.IdentityN"(%vh1, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>)
|
||||
// expected-remark@above {{ID: 2}}
|
||||
%read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
"tf.AssignVariableOp"(%vh1_id#0, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {3}}}
|
||||
// expected-remark@above {{Successors: {5,6}}}
|
||||
%read1 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
%read2 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
"tf.AssignVariableOp"(%vh0, %read2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 7}}
|
||||
// expected-remark@above {{Predecessors: {5,6}}}
|
||||
// expected-remark@above {{Successors: {8}}}
|
||||
"tf.AssignVariableOp"(%vh1_id#0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 8}}
|
||||
// expected-remark@above {{Predecessors: {7}}}
|
||||
// expected-remark@above {{Successors: {9}}}
|
||||
tf_executor.yield
|
||||
// expected-remark@above {{ID: 9}}
|
||||
// expected-remark@above {{Predecessors: {8}}}
|
||||
// expected-remark@above {{Successors: {10}}}
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
// expected-remark@above {{ID: 11}}
|
||||
// expected-remark@above {{Predecessors: {10}}}
|
||||
// expected-remark@above {{Successors: {12}}}
|
||||
}
|
||||
return
|
||||
// expected-remark@above {{ID: 13}}
|
||||
// expected-remark@above {{Predecessors: {12}}}
|
||||
// expected-remark@above {{Successors: {14}}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass tracks control dependencies for side-effecting on unknown
|
||||
// resources.
|
||||
|
||||
// CHECK-LABEL: func @unknown_side_effecting_op
|
||||
func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
|
||||
// expected-remark@above {{ID: 13}}
|
||||
// expected-remark@above {{Predecessors: {12}}}
|
||||
tf_executor.graph {
|
||||
// expected-remark@above {{ID: 11}}
|
||||
// expected-remark@above {{Predecessors: {10}}}
|
||||
// expected-remark@above {{Successors: {12}}}
|
||||
// CHECK: tf_executor.island
|
||||
%island = tf_executor.island {
|
||||
// expected-remark@above {{ID: 9}}
|
||||
// expected-remark@above {{Predecessors: {8}}}
|
||||
// expected-remark@above {{Successors: {10}}}
|
||||
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// expected-remark@above {{ID: 0}}
|
||||
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// expected-remark@above {{ID: 1}}
|
||||
%read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
"tf.AssignVariableOp"(%vh1, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
"tf._UnknownSideEffectingOp_"() : () -> ()
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {2,3}}}
|
||||
// expected-remark@above {{Successors: {5,6}}}
|
||||
%read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
"tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
// expected-remark@above {{Successors: {8}}}
|
||||
"tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 7}}
|
||||
// expected-remark@above {{Predecessors: {5}}}
|
||||
// expected-remark@above {{Successors: {8}}}
|
||||
tf_executor.yield
|
||||
// expected-remark@above {{ID: 8}}
|
||||
// expected-remark@above {{Predecessors: {6,7}}}
|
||||
// expected-remark@above {{Successors: {9}}}
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
// expected-remark@above {{ID: 10}}
|
||||
// expected-remark@above {{Predecessors: {9}}}
|
||||
// expected-remark@above {{Successors: {11}}}
|
||||
}
|
||||
return
|
||||
// expected-remark@above {{ID: 12}}
|
||||
// expected-remark@above {{Predecessors: {11}}}
|
||||
// expected-remark@above {{Successors: {13}}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass tracks control dependencies for read-only ops on unknown
|
||||
// resources.
|
||||
|
||||
// CHECK-LABEL: func @read_only_unknown_resource
|
||||
func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () {
|
||||
// expected-remark@above {{ID: 10}}
|
||||
// expected-remark@above {{Predecessors: {9}}}
|
||||
tf_executor.graph {
|
||||
// expected-remark@above {{ID: 8}}
|
||||
// expected-remark@above {{Predecessors: {7}}}
|
||||
// expected-remark@above {{Successors: {9}}}
|
||||
// CHECK: tf_executor.island
|
||||
%island = tf_executor.island {
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {5}}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
%vh0 = "tf._UnknownSideEffectingOp_"() : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// expected-remark@above {{ID: 0}}
|
||||
// expected-remark@above {{Successors: {2,3}}}
|
||||
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// expected-remark@above {{ID: 1}}
|
||||
%read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {0}}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
%read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Predecessors: {0}}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
"tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {2,3}}}
|
||||
// expected-remark@above {{Successors: {5}}}
|
||||
tf_executor.yield
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
// expected-remark@above {{Successors: {6}}}
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
// expected-remark@above {{ID: 7}}
|
||||
// expected-remark@above {{Predecessors: {6}}}
|
||||
// expected-remark@above {{Successors: {8}}}
|
||||
}
|
||||
return
|
||||
// expected-remark@above {{ID: 9}}
|
||||
// expected-remark@above {{Predecessors: {8}}}
|
||||
// expected-remark@above {{Successors: {10}}}
|
||||
}
|
@ -1650,3 +1650,27 @@ func @testSplitSmallSplitDim(%input: tensor<4x8xf32>) {
|
||||
%0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x8xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTernaryEinsum(%arg0: tensor<2x3xf32>){
|
||||
// expected-error @+1 {{supports at most two operands}}
|
||||
%0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTopKV2WrongInputRank(%input: tensor<f32>, %k: tensor<i32>) {
|
||||
// expected-error @+1 {{op requires input operand to have at least 1 dimension}}
|
||||
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<f32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) {
|
||||
// expected-error @+1 {{op requires k operand to be 0D tensor}}
|
||||
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>)
|
||||
return
|
||||
}
|
||||
|
@ -523,7 +523,7 @@ func @invalid_merge(%arg0: tensor<*x!tf.resource>, %arg1: tensor<4x!tf.resource>
|
||||
// -----
|
||||
|
||||
// Check that if result is a ref type, all operands need to be ref too.
|
||||
func @inavlid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> {
|
||||
func @invalid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> {
|
||||
%result = tf_executor.graph {
|
||||
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
|
||||
// expected-error@-1 {{'tf_executor.Merge' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
|
||||
|
@ -68,14 +68,17 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
|
||||
namespace TF {
|
||||
|
||||
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
|
||||
bool enable_logging) {
|
||||
bool enable_logging,
|
||||
bool enable_inliner) {
|
||||
PassManager bridge(module.getContext());
|
||||
|
||||
// Add logger to bridge passmanager.
|
||||
if (enable_logging)
|
||||
bridge.addInstrumentation(std::make_unique<tensorflow::BridgeLogger>());
|
||||
|
||||
CreateTFStandardPipeline(bridge);
|
||||
StandardPipelineOptions pipeline_options;
|
||||
pipeline_options.enable_inliner.setValue(enable_inliner);
|
||||
CreateTFStandardPipeline(bridge, pipeline_options);
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||
LogicalResult result = bridge.run(module);
|
||||
(void)result;
|
||||
|
@ -31,11 +31,13 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging);
|
||||
|
||||
namespace TF {
|
||||
|
||||
// Run all passes involved in transforming or optimizing an MLIR graph without
|
||||
// Runs all passes involved in transforming or optimizing an MLIR graph without
|
||||
// any target specialization. When enable_logging is true, enables
|
||||
// tensorflow::BridgeLogger.
|
||||
// tensorflow::BridgeLogger. When enable_inliner is true, enables the inliner
|
||||
// pass.
|
||||
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
|
||||
bool enable_logging);
|
||||
bool enable_logging,
|
||||
bool enable_inliner);
|
||||
|
||||
} // namespace TF
|
||||
|
||||
|
@ -29,10 +29,4 @@ mlir::PassPipelineRegistration<> tpu_pipeline(
|
||||
"that it is suitable for targeting TPUs.",
|
||||
mlir::TFTPU::CreateTPUBridge);
|
||||
|
||||
mlir::PassPipelineRegistration<> standard_pipeline(
|
||||
"tf-standard-bridge",
|
||||
"Run all passes involved in transforming or optimizing an MLIR graph"
|
||||
"without any target specialization.",
|
||||
mlir::TF::CreateTFStandardPipeline);
|
||||
|
||||
} // anonymous namespace
|
||||
|
@ -22,6 +22,9 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
def SingleResultAndOperandHaveSameElementType : Constraint<
|
||||
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
|
||||
|
||||
def SingleResultAndOperandHaveSameType : Constraint<
|
||||
CPred<"$0->getType() == $1->getType()">>;
|
||||
|
||||
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -75,8 +78,7 @@ def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
|
||||
|
||||
def CastSameType : Pat<(TF_CastOp:$res $arg, $truncate),
|
||||
(replaceWithValue $arg),
|
||||
[(SingleResultAndOperandHaveSameElementType $res,
|
||||
$arg)]>;
|
||||
[(SingleResultAndOperandHaveSameType $res, $arg)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conj op patterns.
|
||||
|
@ -45,7 +45,8 @@ struct TFOptimizePass : public FunctionPass<TFOptimizePass> {
|
||||
} // namespace
|
||||
|
||||
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
|
||||
void CreateTFStandardPipeline(OpPassManager &pm) {
|
||||
void CreateTFStandardPipeline(OpPassManager &pm,
|
||||
const StandardPipelineOptions &options) {
|
||||
OpPassManager &func_pm = pm.nest<FuncOp>();
|
||||
|
||||
// First operates on the executor dialect:
|
||||
@ -59,8 +60,12 @@ void CreateTFStandardPipeline(OpPassManager &pm) {
|
||||
// Hopefully there is a single island left, or there wasn't any to begin with.
|
||||
// We now run the optimizer which operates mostly inside islands.
|
||||
func_pm.addPass(createCanonicalizerPass());
|
||||
func_pm.addPass(CreateTFOptimizePass());
|
||||
func_pm.addPass(createCSEPass());
|
||||
if (options.enable_inliner) {
|
||||
pm.addPass(createInlinerPass());
|
||||
}
|
||||
pm.addNestedPass<FuncOp>(CreateTFShapeInferencePass());
|
||||
pm.addNestedPass<FuncOp>(CreateTFOptimizePass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass() {
|
||||
@ -70,7 +75,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass() {
|
||||
static PassRegistration<TFOptimizePass> pass("tf-optimize", "Optimizes TF.");
|
||||
|
||||
// Registers a pipeline builder function for the default canonicalize/optimizer.
|
||||
static mlir::PassPipelineRegistration<> pipeline(
|
||||
static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
|
||||
"tf-standard-pipeline",
|
||||
"Run all the passes involved in transforming/optimizing the graph after "
|
||||
"importing into MLIR, without any target specialization.",
|
||||
|
@ -46,10 +46,17 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateTFShapeInferencePass();
|
||||
// Optimizes Tensorflow graph.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass();
|
||||
|
||||
struct StandardPipelineOptions : public PassOptions<StandardPipelineOptions> {
|
||||
Option<bool> enable_inliner{*this, "enable-inliner",
|
||||
llvm::cl::desc("Enable inliner."),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
// Propagates the pass manager with the passes involved in transforming or
|
||||
// optimizing an MLIR graph without any target specialization.
|
||||
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
|
||||
void CreateTFStandardPipeline(OpPassManager& pm);
|
||||
void CreateTFStandardPipeline(OpPassManager& pm,
|
||||
const StandardPipelineOptions& options);
|
||||
} // namespace TF
|
||||
|
||||
namespace TFControlFlow {
|
||||
|
@ -0,0 +1,77 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_executor {
|
||||
|
||||
namespace {
|
||||
|
||||
// A pass that adds "Predecessors" and "Successors" remarks for each op based on
|
||||
// SideEffectAnalysis result. For testing purpose only.
|
||||
struct TestSideEffectAnalysis
|
||||
: public mlir::FunctionPass<TestSideEffectAnalysis> {
|
||||
void runOnFunction() override {
|
||||
int64_t next_id = 0;
|
||||
llvm::SmallDenseMap<Operation*, int64_t, 8> ids;
|
||||
getFunction().walk([&](Operation* op) {
|
||||
ids[op] = next_id++;
|
||||
op->emitRemark("ID: ") << ids[op];
|
||||
});
|
||||
auto join_ids = [&](const llvm::ArrayRef<Operation*> ops) {
|
||||
llvm::SmallVector<std::string, 8> id_vec;
|
||||
id_vec.reserve(ops.size());
|
||||
for (auto op : ops) id_vec.push_back(std::to_string(ids[op]));
|
||||
return llvm::join(id_vec, ",");
|
||||
};
|
||||
auto& analysis = getAnalysis<TF::SideEffectAnalysis>();
|
||||
getFunction().walk([&](Operation* op) {
|
||||
if (!analysis.DirectControlPredecessors(op).empty()) {
|
||||
op->emitRemark("Predecessors: ")
|
||||
<< "{" << join_ids(analysis.DirectControlPredecessors(op)) << "}";
|
||||
}
|
||||
if (!analysis.DirectControlSuccessors(op).empty()) {
|
||||
op->emitRemark("Successors: ")
|
||||
<< "{" << join_ids(analysis.DirectControlSuccessors(op)) << "}";
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
static mlir::PassRegistration<TestSideEffectAnalysis> pass(
|
||||
"tf-test-side-effect-analysis",
|
||||
"Add remarks based on side-effect analysis result, for testing purpose.");
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
} // namespace tf_executor
|
||||
} // namespace mlir
|
@ -141,7 +141,7 @@ static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::list<std::string> cl_pass_list(
|
||||
"graph-passes", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma seprarated list of GraphOptimizationPass to run."),
|
||||
llvm::cl::desc("comma separated list of GraphOptimizationPass to run."),
|
||||
llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory));
|
||||
|
||||
class GraphOptByNamePass : public GraphOptPass {
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/escaping.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
@ -75,6 +76,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
@ -500,7 +502,8 @@ Status ImporterBase::GetInputOutputNodes(
|
||||
TF_RETURN_IF_ERROR(add_node(input.first));
|
||||
}
|
||||
|
||||
for (const auto& output_node_name : specs_.output_arrays) {
|
||||
for (const auto& output : specs_.outputs) {
|
||||
auto output_node_name = std::string(ParseTensorName(output).first);
|
||||
TF_RETURN_IF_ERROR(add_node(output_node_name));
|
||||
}
|
||||
|
||||
@ -535,7 +538,7 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
||||
auto node_name = node->op_def().name();
|
||||
if (node_name != "Placeholder" && node_name != "LegacyFedInput" &&
|
||||
node_name != FunctionLibraryDefinition::kArgOp) {
|
||||
// We do not handle the case where the input node has multple outputs
|
||||
// We do not handle the case where the input node has multiple outputs
|
||||
if (node->num_outputs() > 1) {
|
||||
return errors::FailedPrecondition(absl::StrCat(
|
||||
"Input arrays can only have op with single output. Node op:",
|
||||
@ -1588,7 +1591,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
|
||||
if (specs.graph_as_function) {
|
||||
if (specs.prune_unused_nodes || !specs.inputs.empty() ||
|
||||
!specs.output_arrays.empty() || !specs.output_arrays_order.empty())
|
||||
!specs.outputs.empty())
|
||||
return errors::InvalidArgument(
|
||||
"Pruning of graph is currently unsupported when the main graph is "
|
||||
"converted to a function.");
|
||||
@ -1622,7 +1625,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
// TODO(prakalps): Refactor to keep attribute strings (tf.entry_function,
|
||||
// tf.versions) shared by importer and exporter in a centralized place.
|
||||
// Record the input and output mapping.
|
||||
if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
|
||||
if (!specs.inputs.empty() || !specs.outputs.empty()) {
|
||||
mlir::Builder b(context);
|
||||
std::string s;
|
||||
llvm::raw_string_ostream ss(s);
|
||||
@ -1632,7 +1635,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
",");
|
||||
auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
|
||||
s.clear();
|
||||
mlir::interleave(specs.output_arrays_order, ss, ",");
|
||||
mlir::interleave(specs.outputs, ss, ",");
|
||||
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
|
||||
|
||||
attrs.push_back(b.getNamedAttr("tf.entry_function",
|
||||
@ -1665,9 +1668,13 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
|
||||
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
|
||||
absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
|
||||
// Finds out all the input nodes and output nodes.
|
||||
if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
|
||||
absl::flat_hash_set<absl::string_view> output_node_names;
|
||||
for (const auto& output_tensor : specs.outputs) {
|
||||
output_node_names.insert(ParseTensorName(output_tensor).node());
|
||||
}
|
||||
if (!specs.inputs.empty() || !specs.outputs.empty()) {
|
||||
arg_nodes->resize(specs.inputs.size());
|
||||
ret_nodes->resize(specs.output_arrays_order.size());
|
||||
ret_nodes->resize(specs.outputs.size());
|
||||
|
||||
for (Node* n : GetOrderedNodes()) {
|
||||
// Handle inputs/arguments.
|
||||
@ -1677,17 +1684,17 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
|
||||
}
|
||||
|
||||
// Handle outputs/returns.
|
||||
if (specs.output_arrays.find(n->name()) != specs.output_arrays.end()) {
|
||||
for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
|
||||
if (output_node_names.contains(n->name())) {
|
||||
for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
|
||||
std::pair<std::string, std::string> name_and_port =
|
||||
absl::StrSplit(specs.output_arrays_order[i], ':');
|
||||
absl::StrSplit(specs.outputs[i], ':');
|
||||
auto name = name_and_port.first;
|
||||
if (name != n->name()) continue;
|
||||
int port = 0;
|
||||
if (!name_and_port.second.empty() &&
|
||||
!absl::SimpleAtoi(name_and_port.second, &port)) {
|
||||
return errors::InvalidArgument("Invalid port specification: ",
|
||||
specs.output_arrays_order[i]);
|
||||
specs.outputs[i]);
|
||||
}
|
||||
(*ret_nodes)[i] = {n, port};
|
||||
}
|
||||
@ -1726,10 +1733,10 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
ret_types.reserve(specs.output_arrays.size());
|
||||
for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
|
||||
ret_types.reserve(specs.outputs.size());
|
||||
for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
|
||||
if (ret_nodes->at(i).node == nullptr) {
|
||||
return errors::InvalidArgument("Output ", specs.output_arrays_order[i],
|
||||
return errors::InvalidArgument("Output ", specs.outputs[i],
|
||||
" was not found in graph");
|
||||
}
|
||||
}
|
||||
|
@ -33,19 +33,16 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
Status ParseOutputArrayInfo(absl::string_view array_names,
|
||||
absl::flat_hash_set<string>* array,
|
||||
std::vector<string>* order) {
|
||||
std::vector<string>* outputs) {
|
||||
std::vector<string> output_names = absl::StrSplit(array_names, ',');
|
||||
return ParseOutputArrayInfo(output_names, array, order);
|
||||
return ParseOutputArrayInfo(output_names, outputs);
|
||||
}
|
||||
|
||||
Status ParseOutputArrayInfo(const std::vector<string>& output_names,
|
||||
absl::flat_hash_set<string>* array,
|
||||
std::vector<string>* order) {
|
||||
std::vector<string>* outputs) {
|
||||
for (auto& output_name : output_names) {
|
||||
if (output_name.empty()) continue;
|
||||
array->insert(string(*absl::StrSplit(output_name, ':').begin()));
|
||||
order->push_back(output_name);
|
||||
outputs->push_back(output_name);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -40,11 +40,9 @@ struct GraphImportConfig {
|
||||
llvm::MapVector<string, ArrayInfo, llvm::StringMap<unsigned>>;
|
||||
// Maps input node names to node data types and shapes.
|
||||
InputArrays inputs;
|
||||
// Output node names.
|
||||
absl::flat_hash_set<string> output_arrays;
|
||||
// nodes:index strings for the output as specified on the command line.
|
||||
std::vector<string> output_arrays_order;
|
||||
// setting prune_unused_nodes to true, would prune unreachable nodes if
|
||||
// name:index strings for the output as specified on the command line.
|
||||
std::vector<string> outputs;
|
||||
// Setting prune_unused_nodes to true, would prune unreachable nodes if
|
||||
// output_arrays is specified.
|
||||
bool prune_unused_nodes = false;
|
||||
// If true, inputs of type LegacyFedInput are replaced with Placeholder ops.
|
||||
@ -73,12 +71,10 @@ struct GraphExportConfig {
|
||||
// Parses the command line flag strings to the specification of nodes in
|
||||
// the Graph.
|
||||
Status ParseOutputArrayInfo(absl::string_view array_names,
|
||||
absl::flat_hash_set<string>* array,
|
||||
std::vector<string>* order);
|
||||
std::vector<string>* outputs);
|
||||
|
||||
Status ParseOutputArrayInfo(const std::vector<string>& output_names,
|
||||
absl::flat_hash_set<string>* array,
|
||||
std::vector<string>* order);
|
||||
std::vector<string>* outputs);
|
||||
|
||||
// Parses the command line flag strings to the specification of nodes in
|
||||
// the Graph. `data_types` input string can be empty since the flag is optional.
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
@ -63,16 +64,18 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
||||
specs.upgrade_legacy = upgrade_legacy;
|
||||
TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
|
||||
input_shapes, &specs.inputs));
|
||||
TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.output_arrays,
|
||||
&specs.output_arrays_order));
|
||||
TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs));
|
||||
// TODO(b/142828368): Pruning should not be needed when TF import
|
||||
// supports importing graphs w/ unregistered ops natively.
|
||||
GraphDef pruned_graph_def;
|
||||
if (specs.prune_unused_nodes) {
|
||||
std::vector<string> terminal_nodes(specs.output_arrays.begin(),
|
||||
specs.output_arrays.end());
|
||||
for (const auto entry : specs.inputs) {
|
||||
terminal_nodes.push_back(entry.first);
|
||||
std::vector<std::string> terminal_nodes;
|
||||
terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size());
|
||||
for (const auto& output : specs.outputs) {
|
||||
terminal_nodes.push_back(std::string(ParseTensorName(output).node()));
|
||||
}
|
||||
for (const auto& input : specs.inputs) {
|
||||
terminal_nodes.push_back(input.first);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
|
||||
graphdef, &pruned_graph_def, terminal_nodes));
|
||||
|
@ -36,7 +36,7 @@ xla::StatusOr<xla::Shape> TestShapeRepresentation(const TensorShape& shape,
|
||||
return xla_shape;
|
||||
}
|
||||
|
||||
TEST(CompileSerializedMlirToXlaHloTest, InvalidSerliazedMlirModule) {
|
||||
TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
|
||||
string invalid_mlir_module = "totally @invalid MLIR module {here} <-";
|
||||
std::vector<TensorShape> arg_shapes;
|
||||
XlaCompiler::CompilationResult compilation_result;
|
||||
@ -101,7 +101,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
|
||||
xla::ShapeUtil::MakeTupleShape({output_shape});
|
||||
EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);
|
||||
|
||||
// Expect exactly 1 OutputDescrpition.
|
||||
// Expect exactly 1 OutputDescription.
|
||||
EXPECT_EQ(compilation_result.outputs.size(), 1);
|
||||
const XlaCompiler::OutputDescription& output_desc =
|
||||
compilation_result.outputs.front();
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/ToolUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/TranslateClParser.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
@ -40,6 +41,13 @@ static llvm::cl::opt<std::string> output_filename(
|
||||
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> splitInputFile(
|
||||
"split-input-file",
|
||||
llvm::cl::desc("Split the input file into pieces and process each chunk "
|
||||
"independently"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> import_saved_model(
|
||||
"savedmodel-to-mlir",
|
||||
@ -85,13 +93,12 @@ int main(int argc, char** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
mlir::MLIRContext context;
|
||||
|
||||
if (import_saved_model) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
mlir::MLIRContext context;
|
||||
|
||||
auto module = tensorflow::SavedModelToMlirImport(
|
||||
input_filename, tags, absl::Span<std::string>(exported_names),
|
||||
@ -107,12 +114,23 @@ int main(int argc, char** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
llvm::SourceMgr source_mgr;
|
||||
source_mgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
|
||||
mlir::SourceMgrDiagnosticHandler diagnostic_handler(source_mgr, &context);
|
||||
// Processes the memory buffer with a new MLIRContext.
|
||||
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
|
||||
llvm::raw_ostream& os) {
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
|
||||
mlir::MLIRContext context;
|
||||
mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context);
|
||||
return (*requested_translation)(sourceMgr, os, &context);
|
||||
};
|
||||
|
||||
if (failed((*requested_translation)(source_mgr, output->os(), &context)))
|
||||
return 1;
|
||||
if (splitInputFile) {
|
||||
if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer,
|
||||
output->os())))
|
||||
return 1;
|
||||
} else {
|
||||
if (failed(processBuffer(std::move(input), output->os()))) return 1;
|
||||
}
|
||||
}
|
||||
|
||||
output->keep();
|
||||
|
@ -404,6 +404,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:matrix",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
|
||||
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
||||
@ -82,3 +83,4 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
|
||||
// Converts the given elements attr to the specified elements type.
|
||||
@ -27,5 +28,6 @@ namespace xla {
|
||||
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
||||
mlir::Type new_type);
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
|
||||
|
@ -47,13 +47,11 @@ limitations under the License.
|
||||
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
|
||||
|
||||
namespace mlir {
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::xla_hlo;
|
||||
namespace xla_hlo {
|
||||
|
||||
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
|
||||
Attribute value, Type type,
|
||||
@ -160,7 +158,7 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) {
|
||||
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
|
||||
value.isa<IntegerAttr>()) {
|
||||
// All XLA types must be tensor types. In the build() method, we want to
|
||||
// provide more flexiblity by allowing attributes of scalar types. But we
|
||||
// provide more flexibility by allowing attributes of scalar types. But we
|
||||
// need to wrap it up with ElementsAttr to construct valid XLA constants.
|
||||
type = RankedTensorType::get(/*shape=*/{}, value.getType());
|
||||
value = DenseElementsAttr::get(type.cast<TensorType>(), value);
|
||||
@ -212,9 +210,9 @@ void AbsOp::build(Builder* builder, OperationState& result, Value* operand) {
|
||||
new_type = operand->getType();
|
||||
} else if (shaped_type.hasRank()) {
|
||||
new_type =
|
||||
mlir::RankedTensorType::get(shaped_type.getShape(), operand->getType());
|
||||
RankedTensorType::get(shaped_type.getShape(), operand->getType());
|
||||
} else {
|
||||
new_type = mlir::UnrankedTensorType::get(operand->getType());
|
||||
new_type = UnrankedTensorType::get(operand->getType());
|
||||
}
|
||||
|
||||
return AbsOp::build(builder, result, new_type, operand);
|
||||
@ -241,8 +239,8 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
// If the operand is constant, we can do the conversion now.
|
||||
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
|
||||
return ::xla::ConvertElementsAttr(elementsAttr,
|
||||
getElementTypeOrSelf(getResult()));
|
||||
return xla::ConvertElementsAttr(elementsAttr,
|
||||
getElementTypeOrSelf(getResult()));
|
||||
}
|
||||
|
||||
return {};
|
||||
@ -436,7 +434,7 @@ static LogicalResult Verify(ClampOp op) {
|
||||
void ComplexOp::build(Builder* builder, OperationState& state, Value* lhs,
|
||||
Value* rhs) {
|
||||
auto type = lhs->getType();
|
||||
auto element_ty = mlir::ComplexType::get(getElementTypeOrSelf(type));
|
||||
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
|
||||
Type result_ty;
|
||||
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
||||
result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
|
||||
@ -843,6 +841,70 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
|
||||
return RankedTensorType::get(shape, ranked_ty.getElementType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SortOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void SortOp::build(Builder* builder, OperationState& state,
|
||||
ArrayRef<Value*> operands, int64_t dimension,
|
||||
bool is_stable) {
|
||||
state.addOperands(operands);
|
||||
state.addAttribute("dimension", builder->getI64IntegerAttr(dimension));
|
||||
state.addAttribute("is_stable", builder->getBoolAttr(dimension));
|
||||
|
||||
SmallVector<Type, 2> element_types;
|
||||
element_types.reserve(operands.size());
|
||||
for (Value* operand : operands) element_types.push_back(operand->getType());
|
||||
state.addTypes(builder->getTupleType(element_types));
|
||||
|
||||
state.addRegion();
|
||||
}
|
||||
|
||||
static LogicalResult Verify(SortOp op) {
|
||||
Operation::operand_range operands = op.operands();
|
||||
if (operands.empty()) return op.emitOpError("requires at least one input");
|
||||
|
||||
// TODO(antiagainst): verify partionally dynamic shapes
|
||||
if (llvm::all_of(operands, [](Value* operand) {
|
||||
return operand->getType().cast<ShapedType>().hasRank();
|
||||
})) {
|
||||
ArrayRef<int64_t> input_shape =
|
||||
(*operands.begin())->getType().cast<ShapedType>().getShape();
|
||||
|
||||
if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value* operand) {
|
||||
return operand->getType().cast<ShapedType>().getShape() !=
|
||||
input_shape;
|
||||
}))
|
||||
return op.emitOpError("requires all inputs to have the same dimensions");
|
||||
|
||||
if (op.dimension().getSExtValue() >= input_shape.size())
|
||||
return op.emitOpError(
|
||||
"dimension attribute value must be less than input rank");
|
||||
}
|
||||
|
||||
Block& block = op.comparator().front();
|
||||
size_t num_operands = op.getOperation()->getNumOperands();
|
||||
if (block.getNumArguments() != 2 * num_operands)
|
||||
return op.emitOpError("comparator block should have ")
|
||||
<< 2 * num_operands << " arguments";
|
||||
|
||||
for (auto indexed_operand : llvm::enumerate(operands)) {
|
||||
int index = indexed_operand.index();
|
||||
Type element_type =
|
||||
indexed_operand.value()->getType().cast<ShapedType>().getElementType();
|
||||
Type tensor_type = RankedTensorType::get({}, element_type);
|
||||
for (int i : {2 * index, 2 * index + 1}) {
|
||||
Type arg_type = block.getArgument(i)->getType();
|
||||
if (arg_type != tensor_type)
|
||||
return op.emitOpError("comparator block argument #")
|
||||
<< i << " should be of type " << tensor_type << " but got "
|
||||
<< arg_type;
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransposeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -938,6 +1000,15 @@ void TupleOp::build(Builder* builder, OperationState& result,
|
||||
build(builder, result, builder->getTupleType(types), values);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UnaryEinsumOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void UnaryEinsumOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<UnaryEinsumToEinsum>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CompareOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -990,3 +1061,6 @@ XlaHloDialect::XlaHloDialect(MLIRContext* context)
|
||||
// Support unknown operations because not all XLA operations are registered.
|
||||
// allowUnknownOperations();
|
||||
}
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
|
@ -437,9 +437,6 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &results, "
|
||||
"Value* value, int32_t index">];
|
||||
|
||||
// GetTupleElementOp has special conversion logic to HLO.
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
|
||||
@ -730,6 +727,43 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral
|
||||
let results = (outs HLO_Tensor);
|
||||
}
|
||||
|
||||
def BASE_EinsumOp {
|
||||
string summary = "Einsum operator";
|
||||
|
||||
string description = [{
|
||||
Returns a tensor whose elements are defined by equation, which is written
|
||||
in a shorthand form inspired by the Einstein summation convention.
|
||||
}];
|
||||
}
|
||||
|
||||
def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$lhs,
|
||||
HLO_Tensor:$rhs,
|
||||
StrAttr:$einsum_config
|
||||
);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
// TODO(hinsu): Canonicalize to lower this client side HLO op to server
|
||||
// side HLO ops.
|
||||
}
|
||||
|
||||
def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
StrAttr:$einsum_config
|
||||
);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
// UnarayEinsumOp is unconditionally canonicalized to the binary EinsumOp so
|
||||
// the HLO converter shouldn't be invoked.
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
@ -834,6 +868,26 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp {
|
||||
let arguments = (ins
|
||||
Variadic<HLO_Tensor>:$operands,
|
||||
DefaultValuedAttr<I64Attr, "-1">:$dimension,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_stable
|
||||
);
|
||||
|
||||
let results = (outs HLO_TensorOrTuple);
|
||||
|
||||
let regions = (region SizedRegion<1>:$comparator);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &state, ArrayRef<Value *> operands, "
|
||||
"int64_t dimension, bool is_stable"
|
||||
>];
|
||||
|
||||
// TODO(b/129422361): SortOp has special conversion logic to HLO.
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_ReverseOp: HLO_Op<"reverse",
|
||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
|
||||
let arguments = (ins
|
||||
|
@ -708,7 +708,7 @@ class BASE_HLO_ClampOp {
|
||||
}
|
||||
|
||||
class BASE_HLO_ConcatenateOp {
|
||||
string summary = "XLA's concantenate op";
|
||||
string summary = "XLA's concatenate op";
|
||||
|
||||
string description = [{
|
||||
Concatenates a set of tensors along the specified dimension.
|
||||
@ -832,6 +832,17 @@ class BASE_HLO_SelectAndScatterOp {
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_SortOp {
|
||||
string summary = "Sort operator";
|
||||
|
||||
string description = [{
|
||||
Sorts the given `operands` at the given `dimension` with the given
|
||||
`comparator`.
|
||||
|
||||
See https://www.tensorflow.org/xla/operation_semantics#sort.
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_ReverseOp {
|
||||
string summary = "Reverse operator";
|
||||
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
|
||||
@ -51,5 +53,18 @@ DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x,
|
||||
return DenseIntElementsAttr::get(type, broadcastDimensions);
|
||||
}
|
||||
|
||||
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
|
||||
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
|
||||
|
||||
DenseElementsAttr attr;
|
||||
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
||||
APFloat value(float_ty.getFloatSemantics(), raw_value);
|
||||
return DenseElementsAttr::get(scalar_ty, value);
|
||||
}
|
||||
auto int_ty = ty.cast<IntegerType>();
|
||||
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
|
||||
return DenseElementsAttr::get(scalar_ty, value);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
||||
@ -48,6 +49,12 @@ static ElementsAttr getSplat(Builder* b, Value* val, T constant) {
|
||||
|
||||
return DenseElementsAttr::get(valType, elementAttr);
|
||||
}
|
||||
|
||||
// Returns DenseElementsAttr of rank zero with the given element type and the
|
||||
// value.
|
||||
// Requires `ty` to be either FloatType of IntegerType.
|
||||
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -18,20 +18,23 @@ limitations under the License.
|
||||
#ifndef HLO_UTILS
|
||||
#define HLO_UTILS
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
|
||||
|
||||
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
|
||||
|
||||
class ConstantSplat<string value> : NativeCodeCall<
|
||||
"getSplat(&$_builder, $0, " # value # ")">;
|
||||
"xla::getSplat(&$_builder, $0, " # value # ")">;
|
||||
|
||||
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||
|
||||
def BinBroadcastDimensions : NativeCodeCall<
|
||||
"getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
|
||||
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
|
||||
|
||||
// Here, the element type can be any integer or float type. But, note that only
|
||||
// 32 bit integers are supported for the value.
|
||||
class GetScalarOfType<int value> : NativeCodeCall<
|
||||
"xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
|
||||
|
||||
#endif // HLO_UTILS
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
@ -77,6 +78,10 @@ static double ConvertAPFloat(llvm::APFloat value) {
|
||||
return value.convertToDouble();
|
||||
}
|
||||
|
||||
static absl::string_view ConvertStringRef(mlir::StringRef value) {
|
||||
return {value.data(), value.size()};
|
||||
}
|
||||
|
||||
static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
|
||||
auto values = attr.getValues<int64>();
|
||||
return {values.begin(), values.end()};
|
||||
@ -494,13 +499,6 @@ LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(GetTupleElementOp op, OpLoweringContext ctx) {
|
||||
auto& value_map = *ctx.values;
|
||||
value_map[op] = xla::GetTupleElement(value_map[op.getOperand()],
|
||||
op.index().getSExtValue());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
|
||||
auto& value_map = *ctx.values;
|
||||
value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
|
||||
@ -626,12 +624,30 @@ LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
|
||||
xla::XlaComputation comparator;
|
||||
if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
|
||||
&comparator)))
|
||||
return failure();
|
||||
|
||||
auto& value_map = *ctx.values;
|
||||
value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator,
|
||||
op.dimension().getSExtValue(), op.is_stable());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
|
||||
auto& value_map = *ctx.values;
|
||||
value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
|
||||
// Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two
|
||||
// operands.
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
|
||||
xla::XlaComputation condition;
|
||||
xla::XlaComputation body;
|
||||
@ -773,7 +789,7 @@ LogicalResult ConvertToHloModule::LowerFunctionCall(
|
||||
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
|
||||
if (lowered_computation_.count(f)) return success();
|
||||
if (f.getBlocks().size() != 1) {
|
||||
return f.emitError("only single block Function suppored");
|
||||
return f.emitError("only single block Function supported");
|
||||
}
|
||||
|
||||
// Create a sub-builder if this is not the main function.
|
||||
|
@ -32,17 +32,20 @@ using llvm::raw_ostream;
|
||||
using llvm::RecordKeeper;
|
||||
using llvm::StringRef;
|
||||
using mlir::interleaveComma;
|
||||
using mlir::tblgen::Attribute;
|
||||
using mlir::tblgen::NamedAttribute;
|
||||
using mlir::tblgen::NamedTypeConstraint;
|
||||
using mlir::tblgen::Operator;
|
||||
|
||||
static std::string GetDefaultAttrExport(
|
||||
const mlir::tblgen::NamedAttribute& named_attr) {
|
||||
auto storage_type = named_attr.attr.getStorageType();
|
||||
Attribute attr = named_attr.attr;
|
||||
StringRef storage_type = attr.getStorageType();
|
||||
// For some attribute types we have a general conversion, so use that.
|
||||
if (storage_type.endswith("IntegerAttr") ||
|
||||
storage_type.endswith("FloatAttr")) {
|
||||
return "Convert" + named_attr.attr.getReturnType().str();
|
||||
if (!attr.isEnumAttr() && (storage_type.endswith("IntegerAttr") ||
|
||||
storage_type.endswith("FloatAttr") ||
|
||||
storage_type.endswith("StringAttr"))) {
|
||||
return "Convert" + attr.getReturnType().str();
|
||||
}
|
||||
return "Convert_" + named_attr.name.str();
|
||||
}
|
||||
|
@ -48,3 +48,11 @@ func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f
|
||||
// CHECK: return %arg0
|
||||
return %2 : tensor<4xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @unary_einsum
|
||||
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
|
||||
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
@ -180,6 +180,27 @@ func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
|
||||
return %0: tensor<?xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bitwise_or
|
||||
func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// CHECK-NEXT: xla_hlo.or
|
||||
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %0: tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bitwise_or_broadcast
|
||||
func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
|
||||
// CHECK-NEXT: xla_hlo.or
|
||||
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
|
||||
return %0: tensor<1x4xi8>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bitwise_or_dynamic
|
||||
func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
|
||||
// CHECK-NEXT: xla_hlo.or
|
||||
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
return %0: tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @pow
|
||||
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK-NEXT: xla_hlo.pow
|
||||
@ -194,6 +215,20 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
return %0: tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @einsum
|
||||
func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
|
||||
// CHECK: xla_hlo.einsum
|
||||
%0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
|
||||
return %0: tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unary_einsum
|
||||
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: xla_hlo.unary_einsum
|
||||
%0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
|
||||
return %0: tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @floordiv_broadcast_i32
|
||||
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0>
|
||||
@ -589,6 +624,17 @@ func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
|
||||
return %1 : tensor<6x9xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @padv2_i32_paddings
|
||||
func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
|
||||
%padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32>
|
||||
// CHECK: "xla_hlo.pad"(%arg0, %arg1) {
|
||||
// CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
|
||||
// CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
|
||||
// CHECK-SAME: interior_padding = dense<0> : tensor<2xi64>
|
||||
%1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<6x9xf32>
|
||||
return %1 : tensor<6x9xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Identity op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1888,3 +1934,42 @@ func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf
|
||||
// CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
|
||||
return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf.TopKV2 legalization
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: topk_v2_non_const_k
|
||||
func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
|
||||
// CHECK: tf.TopKV2
|
||||
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
|
||||
return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: topk_v2_unknown_input_last_dim
|
||||
func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
|
||||
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: tf.TopKV2
|
||||
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
|
||||
return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: topk_v2
|
||||
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
|
||||
func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
|
||||
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
|
||||
|
||||
// CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64}
|
||||
// CHECK-NEXT: %[[SORT:.*]] = "xla_hlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
|
||||
// CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
|
||||
// CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"}
|
||||
// CHECK-NEXT: "xla_hlo.return"(%[[CMP]])
|
||||
// CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
// CHECK-NEXT: %[[TUPL0:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[TUPL1:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[VAL:.*]] = "xla_hlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-NEXT: %[[IDX:.*]] = "xla_hlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-NEXT: return %[[VAL]], %[[IDX]]
|
||||
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
|
||||
return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s
|
||||
|
||||
#map0 = (d0, d1) -> (d0, d1)
|
||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], n_loop_types = [2, 0, 0], n_views = [2, 1]}
|
||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"], n_views = [2, 1]}
|
||||
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
||||
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%temp_result = alloc() {temp = true} : memref<2x2xf32>
|
||||
|
@ -416,3 +416,98 @@ func @constants() -> () {
|
||||
%3 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor<i32>} : () -> (tensor<*xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
// CHECK: xla_hlo.sort
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_no_operands() {
|
||||
// expected-error @+1 {{op requires at least one input}}
|
||||
%0 = "xla_hlo.sort"() ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : () -> tuple<>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
||||
// expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}}
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
|
||||
// expected-error @+1 {{op requires all inputs to have the same dimensions}}
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
// expected-error @+1 {{op dimension attribute value must be less than input rank}}
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
// expected-error @+1 {{op comparator block should have 4 arguments}}
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
// expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}}
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
@ -1,26 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
|
||||
%0 = "xla_hlo.all_reduce"(%arg0) ({
|
||||
// Perform max reduction inside the region
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
|
||||
%max = xla_hlo.max %lhs, %rhs : tensor<f32>
|
||||
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
|
||||
})
|
||||
{
|
||||
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
|
||||
channel_id = {
|
||||
handle = 5 : i64,
|
||||
type = 2 : i64
|
||||
}
|
||||
} : (tensor<10xf32>) -> tensor<10xf32>
|
||||
return %0 : tensor<10xf32>
|
||||
}
|
||||
|
||||
// CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
|
||||
// CHECK-SAME: channel_id=5
|
||||
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
|
||||
// CHECK-SAME: to_apply=%[[COMPUTATION]]
|
@ -1,15 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
|
||||
%0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
|
||||
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
|
||||
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
|
||||
// CHECK: [[VAL_4:%.*]] = f32[2] parameter(3)
|
||||
// CHECK: [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0
|
@ -1,13 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
|
||||
%0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
|
||||
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
|
||||
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3
|
@ -1,23 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
module {
|
||||
func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
// CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
|
||||
// CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
|
||||
// CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
|
||||
return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
}
|
@ -1,26 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: ENTRY %main.13 (Arg_0.1: s32[1,4], Arg_1.2: s32[2,4], Arg_2.3: s32[2,3,4]) -> s32[2,3,4] {
|
||||
func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
|
||||
// Same rank degenerate broadcast
|
||||
// CHECK-NEXT: %Arg_0.1 = s32[1,4] parameter(0)
|
||||
// CHECK-NEXT: %reshape.4 = s32[4] reshape(s32[1,4] %Arg_0.1)
|
||||
// CHECK-NEXT: %broadcast.5 = s32[2,4] broadcast(s32[4] %reshape.4)
|
||||
// CHECK-NEXT: %Arg_1.2 = s32[2,4] parameter(1)
|
||||
// CHECK-NEXT: %add.6 = s32[2,4] add(s32[2,4] %broadcast.5, s32[2,4] %Arg_1.2)
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
||||
|
||||
// Broadcast up rank
|
||||
// CHECK-NEXT: %broadcast.7 = s32[2,3,4] broadcast(s32[2,4] %Arg_1.2), dimensions={0,2}
|
||||
// CHECK-NEXT: %Arg_2.3 = s32[2,3,4] parameter(2)
|
||||
// CHECK-NEXT: %add.8 = s32[2,3,4] add(s32[2,3,4] %broadcast.7, s32[2,3,4] %Arg_2.3)
|
||||
%1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
|
||||
|
||||
// Broadcast up rank + degenerate broadcast
|
||||
// CHECK-NEXT: %broadcast.9 = s32[2,1,4] broadcast(s32[1,4] %Arg_0.1), dimensions={1,2}
|
||||
// CHECK-NEXT: %reshape.10 = s32[2,4] reshape(s32[2,1,4] %broadcast.9)
|
||||
// CHECK-NEXT: %broadcast.11 = s32[2,3,4] broadcast(s32[2,4] %reshape.10), dimensions={0,2}
|
||||
// CHECK-NEXT: ROOT %add.12 = s32[2,3,4] add(s32[2,3,4] %broadcast.11, s32[2,3,4] %Arg_2.3)
|
||||
%2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
|
||||
return %2 : tensor<2x3x4xi32>
|
||||
}
|
@ -1,9 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[1,2,3,4] {
|
||||
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
|
||||
// CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
|
||||
// CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
|
||||
%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
|
||||
return %0 : tensor<1x2x3x4xi32>
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
|
||||
%result = "xla_hlo.broadcast_in_dim"(%arg0) {
|
||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||
} : (tensor<1xf32>) -> tensor<1x10xf32>
|
||||
return %result : tensor<1x10xf32>
|
||||
}
|
||||
|
||||
// CHECK: ENTRY %main.3 ([[ARG0:.*]]: f32[1]) -> f32[1,10] {
|
||||
// CHECK: %[[ARG0]] = f32[1] parameter(0)
|
||||
// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] %[[ARG0]]), dimensions={0}
|
@ -1,30 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
|
||||
%0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %1 : tensor<4xi32>
|
||||
}
|
||||
|
||||
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK: [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s32[4], [[ARG_2:.*]]: s32[4]) -> s32[4] {
|
||||
// CHECK: %[[ARG_1]] = s32[4] parameter(0)
|
||||
// CHECK: %[[ARG_2]] = s32[4] parameter(1)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]])
|
||||
|
||||
// CHECK: [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] {
|
||||
// CHECK: %[[ARG_3]] = s32[4] parameter(0)
|
||||
// CHECK: %[[ARG_4]] = s32[4] parameter(1)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]])
|
||||
|
||||
// CHECK: ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] {
|
||||
// CHECK: %[[ARG]] = s32[4] parameter(0)
|
||||
// CHECK: [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]]
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]]
|
@ -1,24 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
|
||||
%0:2 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>)
|
||||
return %0#0, %0#1 : tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
|
||||
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %0, %1 : tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
|
||||
// Get name of callee computation
|
||||
// CHECK: [[CALLEE:%.*]] ({{.*}}) -> ({{.*}}) {
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK-SAME: [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) {
|
||||
// CHECK: %[[ARG]] = s32[4] parameter(0)
|
||||
// CHECK: [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]]
|
||||
// CHECK: [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0
|
||||
// CHECK: [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]])
|
@ -1,17 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0 : tensor<5x2xf32>,
|
||||
%arg1 : tensor<5x5xf32>,
|
||||
%arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
|
||||
%result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) {
|
||||
dimension = 1 : i64
|
||||
} : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
|
||||
return %result : tensor<5x14xf32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: main
|
||||
// CHECK: %[[ARG0:.*]] = f32[5,2] parameter(0)
|
||||
// CHECK: %[[ARG1:.*]] = f32[5,5] parameter(1)
|
||||
// CHECK: %[[ARG2:.*]] = f32[5,7] parameter(2)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1}
|
@ -42,13 +42,13 @@ func @main(%arg0: tensor<f32>) -> tuple<tensor<f32>> {
|
||||
|
||||
// CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]]
|
||||
%2 = "xla_hlo.conditional"(%0, %1, %1) ( {
|
||||
^bb0(%arg1: tuple<tensor<f32>>): // no predecessors
|
||||
^bb0(%arg1: tuple<tensor<f32>>):
|
||||
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
|
||||
%7 = "xla_hlo.log"(%6) : (tensor<f32>) -> tensor<f32>
|
||||
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
|
||||
"xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
|
||||
}, {
|
||||
^bb0(%arg1: tuple<tensor<f32>>): // no predecessors
|
||||
^bb0(%arg1: tuple<tensor<f32>>):
|
||||
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
|
||||
%7 = "xla_hlo.exp"(%6) : (tensor<f32>) -> tensor<f32>
|
||||
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
|
||||
|
@ -1,30 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: ENTRY %main
|
||||
func @main() -> tensor<2x2x1x1xf32> {
|
||||
// CHECK: constant.{{.*}} = s64[] constant(1)
|
||||
%cst = constant dense<1> : tensor<i64>
|
||||
// CHECK: constant.{{.*}} = f32[2,2,1,1]
|
||||
// CHECK-SAME: { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } }
|
||||
%cst_0 = constant dense<
|
||||
[[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]]
|
||||
> : tensor<2x2x1x1xf32>
|
||||
|
||||
// CHECK: s32[1] constant({1})
|
||||
%cst_1 = constant dense<1> : tensor<1xi32>
|
||||
|
||||
// CHECK: %[[C:.*]] = s32[] constant(1)
|
||||
// CHECK: s32[10] broadcast(s32[] %[[C]])
|
||||
%cst_2 = constant dense<1> : tensor<10xi32>
|
||||
|
||||
// CHECK: s32[4] constant({1, 2, 3, 4})
|
||||
%cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
|
||||
|
||||
// CHECK: s32[2,2] constant({ { 1, 2 }, { 3, 4 } })
|
||||
%cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
|
||||
|
||||
// CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
|
||||
%cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
|
||||
|
||||
return %cst_0 : tensor<2x2x1x1xf32>
|
||||
}
|
@ -1,31 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
|
||||
%result = "xla_hlo.conv"(%arg0, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 3 : i64,
|
||||
kernel_output_feature_dimension = 2 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
},
|
||||
feature_group_count = 1 : i64,
|
||||
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||
padding = dense<2> : tensor<2x2xi64>,
|
||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||
window_strides = dense<1> : tensor<2xi64>
|
||||
} : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32>
|
||||
return %result : tensor<100x28x28x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: main
|
||||
// CHECK: %[[ARG0:.*]] = f32[100,26,26,32] parameter(0)
|
||||
// CHECK: %[[ARG1:.*]] = f32[3,3,1,32] parameter(1)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]),
|
||||
// CHECK-SAME: window={size=3x3 pad=2_2x2_2},
|
||||
// CHECK-SAME: dim_labels=b01f_01oi->b01f
|
@ -1,10 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
|
||||
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK: ENTRY %main
|
||||
// CHECK: %[[ARG:.*]] = s32[2] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]])
|
@ -1,10 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
%0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK: ENTRY %main
|
||||
// CHECK: [[ARG:%.*]] = s32[2] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]])
|
@ -1,16 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
|
||||
%0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
|
||||
%1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
|
||||
return %1 : tensor<10xf32>
|
||||
}
|
||||
|
||||
// CHECK: %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[]
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
|
||||
|
||||
// CHECK: ENTRY %main
|
||||
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
|
||||
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
|
||||
// CHECK-SAME: to_apply=%[[SUM_COMPUTATION]]
|
640
tensorflow/compiler/mlir/xla/tests/translate/export.mlir
Normal file
640
tensorflow/compiler/mlir/xla/tests/translate/export.mlir
Normal file
@ -0,0 +1,640 @@
|
||||
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
|
||||
%0 = "xla_hlo.all_reduce"(%arg0) ({
|
||||
// Perform max reduction inside the region
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
|
||||
%max = xla_hlo.max %lhs, %rhs : tensor<f32>
|
||||
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
|
||||
})
|
||||
{
|
||||
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
|
||||
channel_id = {
|
||||
handle = 5 : i64,
|
||||
type = 2 : i64
|
||||
}
|
||||
} : (tensor<10xf32>) -> tensor<10xf32>
|
||||
return %0 : tensor<10xf32>
|
||||
}
|
||||
|
||||
// CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
|
||||
// CHECK-SAME: channel_id=5
|
||||
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
|
||||
// CHECK-SAME: to_apply=%[[COMPUTATION]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
|
||||
%0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
|
||||
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
|
||||
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
|
||||
// CHECK: [[VAL_4:%.*]] = f32[2] parameter(3)
|
||||
// CHECK: [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
|
||||
%0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
|
||||
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
|
||||
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
// CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
|
||||
// CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
|
||||
// CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
|
||||
return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
|
||||
// Same rank degenerate broadcast
|
||||
// CHECK: [[ARG_0:%.*]] = s32[1,4] parameter(0)
|
||||
// CHECK-NEXT: [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]])
|
||||
// CHECK-NEXT: [[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]])
|
||||
// CHECK-NEXT: [[ARG_1:%.*]] = s32[2,4] parameter(1)
|
||||
// CHECK-NEXT: s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]])
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
||||
|
||||
// Broadcast up rank
|
||||
// CHECK-NEXT: [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2}
|
||||
// CHECK-NEXT: [[ARG_2:%.*]] = s32[2,3,4] parameter(2)
|
||||
// CHECK-NEXT: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]])
|
||||
%1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
|
||||
|
||||
// Broadcast up rank + degenerate broadcast
|
||||
// CHECK-NEXT: [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2}
|
||||
// CHECK-NEXT: [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]])
|
||||
// CHECK-NEXT: [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2}
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]])
|
||||
%2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
|
||||
return %2 : tensor<2x3x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
|
||||
// CHECK: [[ARG:%.*]] = s32[4] parameter(0)
|
||||
// CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] [[ARG]]), dimensions={3}
|
||||
%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
|
||||
return %0 : tensor<1x2x3x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
|
||||
%result = "xla_hlo.broadcast_in_dim"(%arg0) {
|
||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||
} : (tensor<1xf32>) -> tensor<1x10xf32>
|
||||
return %result : tensor<1x10xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[ARG:%.*]] = f32[1] parameter(0)
|
||||
// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] [[ARG]]), dimensions={0}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
|
||||
%0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %1 : tensor<4xi32>
|
||||
}
|
||||
|
||||
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK: [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s32[4], [[ARG_2:.*]]: s32[4]) -> s32[4] {
|
||||
// CHECK: %[[ARG_1]] = s32[4] parameter(0)
|
||||
// CHECK: %[[ARG_2]] = s32[4] parameter(1)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]])
|
||||
|
||||
// CHECK: [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] {
|
||||
// CHECK: %[[ARG_3]] = s32[4] parameter(0)
|
||||
// CHECK: %[[ARG_4]] = s32[4] parameter(1)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]])
|
||||
|
||||
// CHECK: ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] {
|
||||
// CHECK: %[[ARG]] = s32[4] parameter(0)
|
||||
// CHECK: [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]]
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
|
||||
%0:2 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>)
|
||||
return %0#0, %0#1 : tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
|
||||
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %0, %1 : tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
|
||||
// Get name of callee computation
|
||||
// CHECK: [[CALLEE:%.*]] ({{.*}}) -> ({{.*}}) {
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK-SAME: [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) {
|
||||
// CHECK: %[[ARG]] = s32[4] parameter(0)
|
||||
// CHECK: [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]]
|
||||
// CHECK: [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0
|
||||
// CHECK: [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0 : tensor<5x2xf32>,
|
||||
%arg1 : tensor<5x5xf32>,
|
||||
%arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
|
||||
%result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) {
|
||||
dimension = 1 : i64
|
||||
} : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
|
||||
return %result : tensor<5x14xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[5,2] parameter(0)
|
||||
// CHECK: %[[ARG1:.*]] = f32[5,5] parameter(1)
|
||||
// CHECK: %[[ARG2:.*]] = f32[5,7] parameter(2)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main() -> tensor<2x2x1x1xf32> {
|
||||
// CHECK: constant.{{.*}} = s64[] constant(1)
|
||||
%cst = constant dense<1> : tensor<i64>
|
||||
// CHECK: constant.{{.*}} = f32[2,2,1,1]
|
||||
// CHECK-SAME: { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } }
|
||||
%cst_0 = constant dense<
|
||||
[[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]]
|
||||
> : tensor<2x2x1x1xf32>
|
||||
|
||||
// CHECK: s32[1] constant({1})
|
||||
%cst_1 = constant dense<1> : tensor<1xi32>
|
||||
|
||||
// CHECK: %[[C:.*]] = s32[] constant(1)
|
||||
// CHECK: s32[10] broadcast(s32[] %[[C]])
|
||||
%cst_2 = constant dense<1> : tensor<10xi32>
|
||||
|
||||
// CHECK: s32[4] constant({1, 2, 3, 4})
|
||||
%cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
|
||||
|
||||
// CHECK: s32[2,2] constant({ { 1, 2 }, { 3, 4 } })
|
||||
%cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
|
||||
|
||||
// CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
|
||||
%cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
|
||||
|
||||
return %cst_0 : tensor<2x2x1x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
|
||||
%result = "xla_hlo.conv"(%arg0, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 3 : i64,
|
||||
kernel_output_feature_dimension = 2 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
},
|
||||
feature_group_count = 1 : i64,
|
||||
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||
padding = dense<2> : tensor<2x2xi64>,
|
||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||
window_strides = dense<1> : tensor<2xi64>
|
||||
} : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32>
|
||||
return %result : tensor<100x28x28x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[100,26,26,32] parameter(0)
|
||||
// CHECK: %[[ARG1:.*]] = f32[3,3,1,32] parameter(1)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]),
|
||||
// CHECK-SAME: window={size=3x3 pad=2_2x2_2},
|
||||
// CHECK-SAME: dim_labels=b01f_01oi->b01f
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
|
||||
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG:.*]] = s32[2] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
%0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[ARG:%.*]] = s32[2] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
|
||||
%0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
|
||||
%1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
|
||||
return %1 : tensor<10xf32>
|
||||
}
|
||||
|
||||
// CHECK: %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[]
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
|
||||
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
|
||||
// CHECK-SAME: to_apply=%[[SUM_COMPUTATION]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
|
||||
// Simple einsum is lowered to HLO dot op.
|
||||
// CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
%0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
|
||||
return %0 : tensor<3x5xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
|
||||
%0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
|
||||
// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
|
||||
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = (f32[], s32[]) parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main() -> tensor<1x10xf32> {
|
||||
%result = "xla_hlo.iota"() {
|
||||
iota_dimension = 1 : i64
|
||||
} : () -> tensor<1x10xf32>
|
||||
return %result : tensor<1x10xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[1,10] iota(), iota_dimension=1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
|
||||
%0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
|
||||
return %0 : tensor<13x19xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[ARG:%.*]] = f32[4,6] parameter(0)
|
||||
// CHECK: [[PADDING_VAL:%.*]] = f32[] parameter(1)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
|
||||
%result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
|
||||
^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
|
||||
%fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%imax = "xla_hlo.max"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
"xla_hlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
|
||||
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
|
||||
return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
|
||||
}
|
||||
|
||||
// CHECK: %[[REGION:region_[0-9]+]]
|
||||
// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[])
|
||||
// CHECK: %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]])
|
||||
// CHECK: %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]])
|
||||
// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]])
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG1:.*]]: s32[1,10], [[ARG2:.*]]: f32[], [[ARG3:.*]]: s32[]) -> (f32[1], s32[1])
|
||||
// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %[[ARG0]], s32[1,10] %[[ARG1]], f32[] %[[ARG2]], s32[] %[[ARG3]]), dimensions={1}, to_apply=%[[REGION]]
|
||||
// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0
|
||||
// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1
|
||||
// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> {
|
||||
%0 = xla_hlo.constant dense<-2147483648> : tensor<i32>
|
||||
%1 = "xla_hlo.reduce_window"(%arg0, %0) ( {
|
||||
^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
|
||||
%2 = xla_hlo.max %arg1, %arg2 : tensor<i32>
|
||||
"xla_hlo.return"(%2) : (tensor<i32>) -> ()
|
||||
}) {
|
||||
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
|
||||
window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>,
|
||||
padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>,
|
||||
base_dilations = dense<[1, 1, 1, 1]> : tensor<4xi64>,
|
||||
window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>
|
||||
} : (tensor<2x17x31x7xi32>, tensor<i32>) -> tensor<2x3x5x7xi32>
|
||||
return %1 : tensor<2x3x5x7xi32>
|
||||
}
|
||||
|
||||
// CHECK: %[[MAX_COMPUTATION:.*]] ([[ARG0:.*]]: s32[], [[ARG1:.*]]: s32[]) -> s32[]
|
||||
// CHECK: ROOT %[[RESULT:.*]] = s32[] maximum(s32[] %[[ARG0]], s32[] %[[ARG1]])
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK-DAG: %[[ARG0:.*]] = s32[2,17,31,7] parameter(0)
|
||||
// CHECK-DAG: %[[INIT:.*]] = s32[] constant(-2147483648)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = s32[2,5,8,7] reduce-window(s32[2,17,31,7] %[[ARG0]], s32[] %constant.2),
|
||||
// CHECK-SAME: window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1},
|
||||
// CHECK-SAME: to_apply=%[[MAX_COMPUTATION]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32>
|
||||
return %0 : tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[2] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[1,2] reshape(f32[2] %[[ARG0]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
|
||||
%result = "xla_hlo.reverse"(%arg0) {
|
||||
dimensions = dense<[1,2]> : tensor<2xi64>
|
||||
} : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32>
|
||||
return %result : tensor<10x11x12x13xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[10,11,12,13] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[10,11,12,13] reverse(f32[10,11,12,13] %[[ARG0]]), dimensions={1,2}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main() -> tensor<2x3x5xf32> {
|
||||
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%1 = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
%2 = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
|
||||
%3 = "xla_hlo.rng_uniform"(%0, %1, %2) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
|
||||
return %3 : tensor<2x3x5xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK-DAG: %[[A:.*]] = f32[] constant(0)
|
||||
// CHECK-DAG: %[[B:.*]] = f32[] constant(1)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[A]], f32[] %[[B]]), distribution=rng_uniform
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
|
||||
%0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
|
||||
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
|
||||
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
scatter_dimension_numbers = {
|
||||
update_window_dims = dense<[1]> : tensor<1xi64>,
|
||||
inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||
index_vector_dim = 1 : i64
|
||||
},
|
||||
indices_are_sorted = true,
|
||||
unique_indices = true
|
||||
} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32>
|
||||
return %0 : tensor<200x100x300xf32>
|
||||
}
|
||||
|
||||
// CHECK: [[COMPUTATION:%.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[VAL_1:%.*]] = f32[200,100,300] parameter(0)
|
||||
// CHECK: [[VAL_2:%.*]] = s32[10,2] parameter(1)
|
||||
// CHECK: [[VAL_3:%.*]] = f32[10,300] parameter(2)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[10,2] [[VAL_2]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK: %[[ARG0:.*]] = pred[] parameter(0)
|
||||
// CHECK: %[[COND:.*]] = pred[2,3] broadcast(pred[] %[[ARG0]]), dimensions={}
|
||||
// CHECK: %[[ARG1:.*]] = s32[2,3] parameter(1)
|
||||
// CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2)
|
||||
|
||||
// CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]])
|
||||
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
|
||||
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%1 = "xla_hlo.select_and_scatter"(%arg0, %arg1, %0) ( {
|
||||
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
|
||||
%2 = "xla_hlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%2) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
|
||||
%2 = xla_hlo.add %arg3, %arg4 : tensor<f32>
|
||||
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
|
||||
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
|
||||
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
|
||||
return %1 : tensor<10x24x24x64xf32>
|
||||
}
|
||||
|
||||
// CHECK: %[[SELECT_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] {
|
||||
// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE
|
||||
|
||||
// CHECK: %[[SCATTER_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] {
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[10,24,24,64] parameter(0)
|
||||
// CHECK: %[[ARG1:.*]] = f32[10,12,12,64] parameter(1)
|
||||
// CHECK: %[[INIT:.*]] = f32[] constant(0)
|
||||
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[10,24,24,64]
|
||||
// CHECK-SAME: select-and-scatter(f32[10,24,24,64] %[[ARG0]], f32[10,12,12,64] %[[ARG1]], f32[] %[[INIT]]),
|
||||
// CHECK-SAME: window={size=1x2x2x1 stride=1x2x2x1},
|
||||
// CHECK-SAME: select=%[[SELECT_COMPUTATION]], scatter=%[[SCATTER_COMPUTATION]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: s32[1,2] slice(s32[3,4] [[ARG]]), slice={[1:2:1], [0:4:2]}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||
// CHECK: [[ARG:%.*]] = s32[1,2,3,4] parameter(0)
|
||||
|
||||
// CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] [[ARG]]), dimensions={1,0,3,2}
|
||||
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||
return %0 : tensor<2x1x4x3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<f32>, %arg1 : tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> {
|
||||
%result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
|
||||
return %result : tuple<tensor<f32>, tensor<i32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[] parameter(0)
|
||||
// CHECK: %[[ARG1:.*]] = s32[] parameter(1)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = (f32[], s32[]) tuple(f32[] %[[ARG0]], s32[] %[[ARG1]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
// CHECK: [[ARG_F32:%.*]] = f32[4] parameter(0)
|
||||
// CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]])
|
||||
%expm1 = "xla_hlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]])
|
||||
%log1p = "xla_hlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK: [[ARG_I32:%.*]] = s32[4] parameter(1)
|
||||
// CHECK: [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]])
|
||||
%not = "xla_hlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK: [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]])
|
||||
%popcnt = "xla_hlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
||||
// CHECK: [[VAL_1:%.*]] = pred[4] parameter(0)
|
||||
// CHECK: [[VAL_2:%.*]] = pred[4] parameter(1)
|
||||
%0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1>
|
||||
// CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
|
||||
return %0 : tensor<4xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] {
|
||||
// CHECK: ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
|
||||
|
||||
// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) {
|
||||
// CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
|
@ -1,10 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
|
||||
%0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
|
||||
// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}
|
@ -1,10 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
|
||||
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: main
|
||||
// CHECK: %[[ARG0:.*]] = (f32[], s32[]) parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0
|
@ -1,11 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main() -> tensor<1x10xf32> {
|
||||
%result = "xla_hlo.iota"() {
|
||||
iota_dimension = 1 : i64
|
||||
} : () -> tensor<1x10xf32>
|
||||
return %result : tensor<1x10xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL:main
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[1,10] iota(), iota_dimension=1
|
@ -1,12 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
|
||||
%0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
|
||||
return %0 : tensor<13x19xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: [[ARG:%.*]] = f32[4,6] parameter(0)
|
||||
// CHECK: [[PADDING_VAL:%.*]] = f32[] parameter(1)
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1
|
@ -1,24 +0,0 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
|
||||
%result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
|
||||
^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
|
||||
%fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%imax = "xla_hlo.max"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
"xla_hlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
|
||||
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
|
||||
return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
|
||||
}
|
||||
|
||||
// CHECK: %[[REGION:region_[0-9]+]]
|
||||
// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[])
|
||||
// CHECK: %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]])
|
||||
// CHECK: %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]])
|
||||
// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]])
|
||||
|
||||
// CHECK: ENTRY %main
|
||||
// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG0:.*]]: s32[1,10], [[ARG0:.*]]: f32[], [[ARG0:.*]]: s32[]) -> (f32[1], s32[1])
|
||||
// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %Arg_0.1, s32[1,10] %Arg_1.2, f32[] %Arg_2.3, s32[] %Arg_3.4), dimensions={1}, to_apply=%[[REGION]]
|
||||
// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0
|
||||
// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1
|
||||
// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]])
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user