Merge branch 'master' into fix_minimum_maximum
This commit is contained in:
commit
811c2a08ff
88
.bazelrc
88
.bazelrc
@ -100,9 +100,9 @@ build --apple_platform_type=macos
|
|||||||
# iOS configs for each architecture and the fat binary builds.
|
# iOS configs for each architecture and the fat binary builds.
|
||||||
build:ios --apple_platform_type=ios
|
build:ios --apple_platform_type=ios
|
||||||
build:ios --apple_bitcode=embedded --copt=-fembed-bitcode
|
build:ios --apple_bitcode=embedded --copt=-fembed-bitcode
|
||||||
|
build:ios --copt=-Wno-c++11-narrowing
|
||||||
build:ios_armv7 --config=ios
|
build:ios_armv7 --config=ios
|
||||||
build:ios_armv7 --cpu=ios_armv7
|
build:ios_armv7 --cpu=ios_armv7
|
||||||
build:ios_armv7 --copt -Wno-c++11-narrowing
|
|
||||||
build:ios_arm64 --config=ios
|
build:ios_arm64 --config=ios
|
||||||
build:ios_arm64 --cpu=ios_arm64
|
build:ios_arm64 --cpu=ios_arm64
|
||||||
build:ios_i386 --config=ios
|
build:ios_i386 --config=ios
|
||||||
@ -111,7 +111,6 @@ build:ios_x86_64 --config=ios
|
|||||||
build:ios_x86_64 --cpu=ios_x86_64
|
build:ios_x86_64 --cpu=ios_x86_64
|
||||||
build:ios_fat --config=ios
|
build:ios_fat --config=ios
|
||||||
build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64
|
build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64
|
||||||
build:ios_fat --copt -Wno-c++11-narrowing
|
|
||||||
|
|
||||||
# Config to use a mostly-static build and disable modular op registration
|
# Config to use a mostly-static build and disable modular op registration
|
||||||
# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
|
# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
|
||||||
@ -202,18 +201,25 @@ build --define=allow_oversize_protos=true
|
|||||||
build --spawn_strategy=standalone
|
build --spawn_strategy=standalone
|
||||||
build -c opt
|
build -c opt
|
||||||
|
|
||||||
# By default, build TF in C++ 14 mode.
|
|
||||||
build --cxxopt=-std=c++14
|
|
||||||
build --host_cxxopt=-std=c++14
|
|
||||||
|
|
||||||
# Make Bazel print out all options from rc files.
|
# Make Bazel print out all options from rc files.
|
||||||
build --announce_rc
|
build --announce_rc
|
||||||
|
|
||||||
# Other build flags.
|
# Other build flags.
|
||||||
build --define=grpc_no_ares=true
|
build --define=grpc_no_ares=true
|
||||||
|
|
||||||
# Prevent regression of https://github.com/bazelbuild/bazel/issues/7362
|
# See https://github.com/bazelbuild/bazel/issues/7362 for information on what
|
||||||
build --incompatible_remove_legacy_whole_archive
|
# --incompatible_remove_legacy_whole_archive flag does.
|
||||||
|
# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate
|
||||||
|
# Tensorflow to the default, however test coverage wasn't enough to catch the
|
||||||
|
# errors.
|
||||||
|
# There is ongoing work on Bazel team's side to provide support for transitive
|
||||||
|
# shared libraries. As part of migrating to transitive shared libraries, we
|
||||||
|
# hope to provide a better mechanism for control over symbol exporting, and
|
||||||
|
# then tackle this issue again.
|
||||||
|
#
|
||||||
|
# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library
|
||||||
|
# archives in -whole_archive -no_whole_archive.
|
||||||
|
build --noincompatible_remove_legacy_whole_archive
|
||||||
|
|
||||||
# Modular TF build options
|
# Modular TF build options
|
||||||
build:dynamic_kernels --define=dynamic_loaded_kernels=true
|
build:dynamic_kernels --define=dynamic_loaded_kernels=true
|
||||||
@ -224,13 +230,55 @@ build:c++17 --cxxopt=-std=c++1z
|
|||||||
build:c++17 --cxxopt=-stdlib=libc++
|
build:c++17 --cxxopt=-stdlib=libc++
|
||||||
build:c++1z --config=c++17
|
build:c++1z --config=c++17
|
||||||
|
|
||||||
# Default paths for TF_SYSTEM_LIBS
|
# Enable using platform specific build settings
|
||||||
build --define=PREFIX=/usr
|
build --enable_platform_specific_config
|
||||||
build --define=LIBDIR=$(PREFIX)/lib
|
|
||||||
build --define=INCLUDEDIR=$(PREFIX)/include
|
|
||||||
|
|
||||||
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
|
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
|
||||||
build --copt=-w
|
build:linux --copt=-w
|
||||||
|
build:macos --copt=-w
|
||||||
|
build:windows --copt=/w
|
||||||
|
|
||||||
|
# Default paths for TF_SYSTEM_LIBS
|
||||||
|
build:linux --define=PREFIX=/usr
|
||||||
|
build:linux --define=LIBDIR=$(PREFIX)/lib
|
||||||
|
build:linux --define=INCLUDEDIR=$(PREFIX)/include
|
||||||
|
build:macos --define=PREFIX=/usr
|
||||||
|
build:macos --define=LIBDIR=$(PREFIX)/lib
|
||||||
|
build:macos --define=INCLUDEDIR=$(PREFIX)/include
|
||||||
|
# TF_SYSTEM_LIBS do not work on windows.
|
||||||
|
|
||||||
|
# By default, build TF in C++ 14 mode.
|
||||||
|
build:linux --cxxopt=-std=c++14
|
||||||
|
build:linux --host_cxxopt=-std=c++14
|
||||||
|
build:macos --cxxopt=-std=c++14
|
||||||
|
build:macos --host_cxxopt=-std=c++14
|
||||||
|
build:windows --cxxopt=/std:c++14
|
||||||
|
build:windows --host_cxxopt=/std:c++14
|
||||||
|
|
||||||
|
# On windows, we still link everything into a single DLL.
|
||||||
|
build:windows --config=monolithic
|
||||||
|
|
||||||
|
# Make sure to include as little of windows.h as possible
|
||||||
|
build:windows --copt=-DWIN32_LEAN_AND_MEAN
|
||||||
|
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
|
||||||
|
build:windows --copt=-DNOGDI
|
||||||
|
build:windows --host_copt=-DNOGDI
|
||||||
|
|
||||||
|
# Misc build options we need for windows.
|
||||||
|
build:windows --linkopt=/DEBUG
|
||||||
|
build:windows --host_linkopt=/DEBUG
|
||||||
|
build:windows --linkopt=/OPT:REF
|
||||||
|
build:windows --host_linkopt=/OPT:REF
|
||||||
|
build:windows --linkopt=/OPT:ICF
|
||||||
|
build:windows --host_linkopt=/OPT:ICF
|
||||||
|
build:windows --experimental_strict_action_env=true
|
||||||
|
build:windows --incompatible_windows_native_test_wrapper
|
||||||
|
|
||||||
|
# Verbose failure logs when something goes wrong
|
||||||
|
build:windows --verbose_failures
|
||||||
|
|
||||||
|
# On windows, we never cross compile
|
||||||
|
build:windows --distinct_host_configuration=false
|
||||||
|
|
||||||
# Suppress all warning messages.
|
# Suppress all warning messages.
|
||||||
build:short_logs --output_filter=DONT_MATCH_ANYTHING
|
build:short_logs --output_filter=DONT_MATCH_ANYTHING
|
||||||
@ -335,20 +383,6 @@ build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_
|
|||||||
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
||||||
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
|
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
|
||||||
|
|
||||||
# Misc build options we need for windows
|
|
||||||
build:rbe_win --copt=-DWIN32_LEAN_AND_MEAN
|
|
||||||
build:rbe_win --host_copt=-DWIN32_LEAN_AND_MEAN
|
|
||||||
build:rbe_win --copt=-DNOGDI
|
|
||||||
build:rbe_win --host_copt=-DNOGDI
|
|
||||||
build:rbe_win --linkopt=/DEBUG
|
|
||||||
build:rbe_win --host_linkopt=/DEBUG
|
|
||||||
build:rbe_win --linkopt=/OPT:REF
|
|
||||||
build:rbe_win --host_linkopt=/OPT:REF
|
|
||||||
build:rbe_win --linkopt=/OPT:ICF
|
|
||||||
build:rbe_win --host_linkopt=/OPT:ICF
|
|
||||||
build:rbe_win --config=monolithic
|
|
||||||
build:rbe_win --experimental_strict_action_env=true
|
|
||||||
build:rbe_win --incompatible_windows_native_test_wrapper
|
|
||||||
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
|
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
|
||||||
build:rbe_win --define=override_eigen_strong_inline=true
|
build:rbe_win --define=override_eigen_strong_inline=true
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ swift_rules_dependencies()
|
|||||||
# files, in case the parsing of those build files depends on the bazel
|
# files, in case the parsing of those build files depends on the bazel
|
||||||
# version we require here.
|
# version we require here.
|
||||||
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
|
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
|
||||||
check_bazel_version_at_least("0.19.0")
|
check_bazel_version_at_least("1.0.0")
|
||||||
|
|
||||||
load("//third_party/android:android_configure.bzl", "android_configure")
|
load("//third_party/android:android_configure.bzl", "android_configure")
|
||||||
android_configure(name="local_config_android")
|
android_configure(name="local_config_android")
|
||||||
|
16
configure.py
16
configure.py
@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
|
|||||||
_TF_WORKSPACE_ROOT = ''
|
_TF_WORKSPACE_ROOT = ''
|
||||||
_TF_BAZELRC = ''
|
_TF_BAZELRC = ''
|
||||||
_TF_CURRENT_BAZEL_VERSION = None
|
_TF_CURRENT_BAZEL_VERSION = None
|
||||||
_TF_MIN_BAZEL_VERSION = '0.27.1'
|
_TF_MIN_BAZEL_VERSION = '1.0.0'
|
||||||
_TF_MAX_BAZEL_VERSION = '1.1.0'
|
_TF_MAX_BAZEL_VERSION = '1.1.0'
|
||||||
|
|
||||||
NCCL_LIB_PATHS = [
|
NCCL_LIB_PATHS = [
|
||||||
@ -1232,20 +1232,6 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
|
|||||||
|
|
||||||
def set_windows_build_flags(environ_cp):
|
def set_windows_build_flags(environ_cp):
|
||||||
"""Set Windows specific build options."""
|
"""Set Windows specific build options."""
|
||||||
# The non-monolithic build is not supported yet
|
|
||||||
write_to_bazelrc('build --config monolithic')
|
|
||||||
# Suppress warning messages
|
|
||||||
write_to_bazelrc('build --copt=-w --host_copt=-w')
|
|
||||||
# Fix winsock2.h conflicts
|
|
||||||
write_to_bazelrc(
|
|
||||||
'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN '
|
|
||||||
'--copt=-DNOGDI --host_copt=-DNOGDI')
|
|
||||||
# Output more verbose information when something goes wrong
|
|
||||||
write_to_bazelrc('build --verbose_failures')
|
|
||||||
# The host and target platforms are the same in Windows build. So we don't
|
|
||||||
# have to distinct them. This avoids building the same targets twice.
|
|
||||||
write_to_bazelrc('build --distinct_host_configuration=false')
|
|
||||||
|
|
||||||
if is_reduced_optimize_huge_functions_available(environ_cp):
|
if is_reduced_optimize_huge_functions_available(environ_cp):
|
||||||
write_to_bazelrc(
|
write_to_bazelrc(
|
||||||
'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions'
|
'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions'
|
||||||
|
@ -2,10 +2,7 @@
|
|||||||
# TensorFlow is a computational framework, primarily for use in machine
|
# TensorFlow is a computational framework, primarily for use in machine
|
||||||
# learning applications.
|
# learning applications.
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "VERSION")
|
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl")
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
|
|
||||||
load(
|
load(
|
||||||
"//tensorflow/core/platform:build_config.bzl",
|
"//tensorflow/core/platform:build_config.bzl",
|
||||||
"tf_additional_binary_deps",
|
"tf_additional_binary_deps",
|
||||||
@ -450,6 +447,7 @@ config_setting(
|
|||||||
package_group(
|
package_group(
|
||||||
name = "internal",
|
name = "internal",
|
||||||
packages = [
|
packages = [
|
||||||
|
"//learning/brain/swift/x10/...",
|
||||||
"//perftools/accelerators/xprof/api/...",
|
"//perftools/accelerators/xprof/api/...",
|
||||||
"//tensorflow/...",
|
"//tensorflow/...",
|
||||||
"//tensorflow_estimator/python/estimator/...",
|
"//tensorflow_estimator/python/estimator/...",
|
||||||
|
@ -119,11 +119,11 @@ def _running_from_pip_package():
|
|||||||
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
||||||
|
|
||||||
if _running_from_pip_package():
|
if _running_from_pip_package():
|
||||||
for s in _site_packages_dirs:
|
for _s in _site_packages_dirs:
|
||||||
# TODO(gunan): Add sanity checks to loaded modules here.
|
# TODO(gunan): Add sanity checks to loaded modules here.
|
||||||
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
|
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||||
if _fi.file_exists(plugin_dir):
|
if _fi.file_exists(_plugin_dir):
|
||||||
_ll.load_library(plugin_dir)
|
_ll.load_library(_plugin_dir)
|
||||||
|
|
||||||
# Add module aliases
|
# Add module aliases
|
||||||
if hasattr(_current_module, 'keras'):
|
if hasattr(_current_module, 'keras'):
|
||||||
@ -136,3 +136,5 @@ if hasattr(_current_module, 'keras'):
|
|||||||
setattr(_current_module, "optimizers", optimizers)
|
setattr(_current_module, "optimizers", optimizers)
|
||||||
setattr(_current_module, "initializers", initializers)
|
setattr(_current_module, "initializers", initializers)
|
||||||
# pylint: enable=undefined-variable
|
# pylint: enable=undefined-variable
|
||||||
|
|
||||||
|
# __all__ PLACEHOLDER
|
||||||
|
@ -132,9 +132,10 @@ def _running_from_pip_package():
|
|||||||
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
||||||
|
|
||||||
if _running_from_pip_package():
|
if _running_from_pip_package():
|
||||||
for s in _site_packages_dirs:
|
for _s in _site_packages_dirs:
|
||||||
# TODO(gunan): Add sanity checks to loaded modules here.
|
# TODO(gunan): Add sanity checks to loaded modules here.
|
||||||
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
|
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||||
if _fi.file_exists(plugin_dir):
|
if _fi.file_exists(_plugin_dir):
|
||||||
_ll.load_library(plugin_dir)
|
_ll.load_library(_plugin_dir)
|
||||||
|
|
||||||
|
# __all__ PLACEHOLDER
|
||||||
|
@ -196,6 +196,12 @@ cc_library(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tf_status_headers",
|
||||||
|
hdrs = ["tf_status.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tf_file_statistics",
|
name = "tf_file_statistics",
|
||||||
hdrs = ["tf_file_statistics.h"],
|
hdrs = ["tf_file_statistics.h"],
|
||||||
|
@ -231,7 +231,14 @@ cc_library(
|
|||||||
srcs = ["shape_inference_helpers.cc"],
|
srcs = ["shape_inference_helpers.cc"],
|
||||||
hdrs = ["shape_inference_helpers.h"],
|
hdrs = ["shape_inference_helpers.h"],
|
||||||
visibility = [":friends"],
|
visibility = [":friends"],
|
||||||
deps = ["//tensorflow/core:graph"],
|
deps = select({
|
||||||
|
"//tensorflow:android": [
|
||||||
|
"//tensorflow/core:android_tensorflow_lib",
|
||||||
|
],
|
||||||
|
"//conditions:default": [
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
],
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Internal targets below this point.
|
# Internal targets below this point.
|
||||||
|
@ -880,11 +880,8 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
|||||||
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
|
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
|
||||||
|
|
||||||
// Parses output_arrays_order from command line option.
|
// Parses output_arrays_order from command line option.
|
||||||
absl::flat_hash_set<std::string> output_set;
|
std::vector<std::string> outputs;
|
||||||
std::vector<std::string> output_arrays_order;
|
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) {
|
||||||
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &output_set,
|
|
||||||
&output_arrays_order)
|
|
||||||
.ok()) {
|
|
||||||
return emitError(loc, "parsing output array info failed ")
|
return emitError(loc, "parsing output array info failed ")
|
||||||
<< output_arrays_string,
|
<< output_arrays_string,
|
||||||
nullptr;
|
nullptr;
|
||||||
@ -892,7 +889,7 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
|||||||
|
|
||||||
return tflite::FlatBufferToMlir(
|
return tflite::FlatBufferToMlir(
|
||||||
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
||||||
context, loc, output_arrays_order);
|
context, loc, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
|
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
|
||||||
|
@ -384,7 +384,7 @@ class Translator {
|
|||||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
|
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
|
||||||
|
|
||||||
// Returns opcode index for op identified by the op_name, if already
|
// Returns opcode index for op identified by the op_name, if already
|
||||||
// available. Otherwise, creates a new OperactorCode using the given `builtin`
|
// available. Otherwise, creates a new OperatorCode using the given `builtin`
|
||||||
// operator and associates it with `op_name`.
|
// operator and associates it with `op_name`.
|
||||||
uint32_t GetOpcodeIndex(const std::string& op_name,
|
uint32_t GetOpcodeIndex(const std::string& op_name,
|
||||||
tflite::BuiltinOperator builtin);
|
tflite::BuiltinOperator builtin);
|
||||||
|
@ -720,7 +720,8 @@ static LogicalResult Verify(PackOp op) {
|
|||||||
for (Value *operand : op.getOperands()) {
|
for (Value *operand : op.getOperands()) {
|
||||||
auto other_type = operand->getType().cast<ShapedType>();
|
auto other_type = operand->getType().cast<ShapedType>();
|
||||||
if (input_type != other_type)
|
if (input_type != other_type)
|
||||||
return op.emitOpError("operands should be of the same type");
|
return op.emitOpError("operands should be of the same type. got ")
|
||||||
|
<< input_type << ", " << other_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
@ -857,8 +858,8 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
//
|
//
|
||||||
// => Value [5, 8, 9]
|
// => Value [5, 8, 9]
|
||||||
// TODO(b/133341698): Move to tablegen when variadic is supported.
|
// TODO(b/133341698): Move to tablegen when variadic is supported.
|
||||||
struct RemoveRedunantUnpackPack : public RewritePattern {
|
struct RemoveRedundantUnpackPack : public RewritePattern {
|
||||||
explicit RemoveRedunantUnpackPack(MLIRContext *context)
|
explicit RemoveRedundantUnpackPack(MLIRContext *context)
|
||||||
: RewritePattern(PackOp::getOperationName(), 2, context) {}
|
: RewritePattern(PackOp::getOperationName(), 2, context) {}
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation *op,
|
PatternMatchResult matchAndRewrite(Operation *op,
|
||||||
@ -896,7 +897,7 @@ struct RemoveRedunantUnpackPack : public RewritePattern {
|
|||||||
|
|
||||||
void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
results.insert<RemoveRedunantUnpackPack>(context);
|
results.insert<RemoveRedundantUnpackPack>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1041,7 +1042,7 @@ struct DropFakeQuant : public RewritePattern {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||||
// Replace the matched FakeQuantOp by its primiary operand.
|
// Replace the matched FakeQuantOp by its primary operand.
|
||||||
rewriter.replaceOp(op, op->getOperand(0));
|
rewriter.replaceOp(op, op->getOperand(0));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -32,7 +32,7 @@ def TFL_Dialect : Dialect {
|
|||||||
Invariants:
|
Invariants:
|
||||||
|
|
||||||
* All values are of Tensor type (in particular, scalars are
|
* All values are of Tensor type (in particular, scalars are
|
||||||
represented using zero-dimentional tensors);
|
represented using zero-dimensional tensors);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let cppNamespace = "TFL";
|
let cppNamespace = "TFL";
|
||||||
@ -603,7 +603,7 @@ def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
|
|||||||
def TFL_FCWO_Shuffled4x16i8 : StrEnumAttrCase<"SHUFFLED4x16INT8">;
|
def TFL_FCWO_Shuffled4x16i8 : StrEnumAttrCase<"SHUFFLED4x16INT8">;
|
||||||
|
|
||||||
def TFL_FullyConnectedOptionsWeightFormatAttr :
|
def TFL_FullyConnectedOptionsWeightFormatAttr :
|
||||||
StrEnumAttr<"FullyConectedOptionsWeightsFormat",
|
StrEnumAttr<"FullyConnectedOptionsWeightsFormat",
|
||||||
"fully connected options weights format", [
|
"fully connected options weights format", [
|
||||||
TFL_FCWO_Default, TFL_FCWO_Shuffled4x16i8
|
TFL_FCWO_Default, TFL_FCWO_Shuffled4x16i8
|
||||||
]>;
|
]>;
|
||||||
@ -1873,9 +1873,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
|
|||||||
x -> max(0, x)
|
x -> max(0, x)
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins AnyTensor:$x);
|
let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
|
||||||
|
|
||||||
let results = (outs AnyTensor:$y);
|
let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
|
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
|
||||||
@ -1888,9 +1888,24 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
|
|||||||
x -> max(0, min(6, x))
|
x -> max(0, min(6, x))
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins AnyTensor:$x);
|
let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
|
||||||
|
|
||||||
let results = (outs AnyTensor:$y);
|
let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
|
||||||
|
SameOperandsAndResultShape,
|
||||||
|
SameOperandsAndResultsScale]> {
|
||||||
|
let summary = "Relu1 operator";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Element-wise Relu1 operator
|
||||||
|
x -> max(-1, min(1, x))
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
|
||||||
|
|
||||||
|
let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_ReshapeOp: TFL_Op<"reshape", [
|
def TFL_ReshapeOp: TFL_Op<"reshape", [
|
||||||
|
@ -237,8 +237,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
|||||||
// Parse output arrays.
|
// Parse output arrays.
|
||||||
std::vector<string> output_arrays(model_flags.output_arrays().begin(),
|
std::vector<string> output_arrays(model_flags.output_arrays().begin(),
|
||||||
model_flags.output_arrays().end());
|
model_flags.output_arrays().end());
|
||||||
TF_RETURN_IF_ERROR(tensorflow::ParseOutputArrayInfo(
|
TF_RETURN_IF_ERROR(
|
||||||
output_arrays, &specs.output_arrays, &specs.output_arrays_order));
|
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
|
||||||
|
|
||||||
// Other flags.
|
// Other flags.
|
||||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||||
|
@ -82,7 +82,7 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
|
|||||||
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
|
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// A method to retrive the name for the given op.
|
// A method to retrieve the name for the given op.
|
||||||
OperationToName op_to_name_;
|
OperationToName op_to_name_;
|
||||||
|
|
||||||
// We split the normal names and regex names, since the former can use hash
|
// We split the normal names and regex names, since the former can use hash
|
||||||
|
@ -102,7 +102,7 @@ class AffineOpCoefficient<int dim, int index> : NativeOpTrait<
|
|||||||
StrJoinInt<[dim, index]>.result,
|
StrJoinInt<[dim, index]>.result,
|
||||||
">::Impl")>;
|
">::Impl")>;
|
||||||
|
|
||||||
// Specify this trait if the op doesn't have quantizable ouput. We shouldn't
|
// Specify this trait if the op doesn't have quantizable output. We shouldn't
|
||||||
// apply quantization on this op.
|
// apply quantization on this op.
|
||||||
def NoQuantizableResult : NativeOpTrait<"quant::NoQuantizableResult">;
|
def NoQuantizableResult : NativeOpTrait<"quant::NoQuantizableResult">;
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ class SameOperandsAndResultsScale
|
|||||||
// OpTrait::quant::FixedResultUniformScale<
|
// OpTrait::quant::FixedResultUniformScale<
|
||||||
// 8, -128, 390625, -8, 0, 255, false>::Impl> {
|
// 8, -128, 390625, -8, 0, 255, false>::Impl> {
|
||||||
//
|
//
|
||||||
// TODO(fengliuai): create a better way to epxress floating point scale in the
|
// TODO(fengliuai): create a better way to express floating point scale in the
|
||||||
// template argument list.
|
// template argument list.
|
||||||
template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
|
template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
|
||||||
int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
|
int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
|
||||||
|
@ -133,10 +133,10 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
|
|||||||
// quantization parameters are annotated by the Q/DQ op pairs. Each
|
// quantization parameters are annotated by the Q/DQ op pairs. Each
|
||||||
// matched pattern are rewritten by its quantized alternatives.
|
// matched pattern are rewritten by its quantized alternatives.
|
||||||
//
|
//
|
||||||
// The concret pattern, extends from this base pattern, can specify whether it
|
// The concrete pattern, extends from this base pattern, can specify whether it
|
||||||
// allows "hybrid" operands or results. These "hybrid" operands and results
|
// allows "hybrid" operands or results. These "hybrid" operands and results
|
||||||
// don't have quantization parameters propagated to, so will be in float in the
|
// don't have quantization parameters propagated to, so will be in float in the
|
||||||
// quantized results. The concret pattern should define the following two
|
// quantized results. The concrete pattern should define the following two
|
||||||
// functions:
|
// functions:
|
||||||
//
|
//
|
||||||
// bool AllowHybridOperand() const
|
// bool AllowHybridOperand() const
|
||||||
|
@ -114,8 +114,8 @@ func @fakequant_notdropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @RemoveRedunantUnpackPack
|
// CHECK-LABEL: @RemoveRedundantUnpackPack
|
||||||
func @RemoveRedunantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
|
func @RemoveRedundantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
|
||||||
%0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
|
%0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
|
||||||
%1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>)
|
%1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>)
|
||||||
return %1: tensor<2x5xf32>
|
return %1: tensor<2x5xf32>
|
||||||
@ -125,8 +125,8 @@ func @RemoveRedunantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @RemoveRedunantPack
|
// CHECK-LABEL: @RemoveRedundantPack
|
||||||
func @RemoveRedunantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) {
|
func @RemoveRedundantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) {
|
||||||
%0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
|
%0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
|
||||||
%1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>)
|
%1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>)
|
||||||
return %1, %0#0: tensor<2x5xf32>, tensor<5xf32>
|
return %1, %0#0: tensor<2x5xf32>, tensor<5xf32>
|
||||||
|
@ -106,8 +106,8 @@ func @extractStackInputOutputOphint() {
|
|||||||
// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
||||||
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
|
||||||
// CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
|
// CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
|
||||||
// CHECK-DAG: %[[OUPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
// CHECK-DAG: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||||
// CHECK-DAG: %[[OUPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
// CHECK-DAG: %[[OUTPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||||
|
|
||||||
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
|
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
|
||||||
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
|
||||||
|
@ -1,22 +1,22 @@
|
|||||||
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure
|
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
func @addRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
|
func @addRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||||
%1 = "tf.Add"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
%1 = "tf.Add"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||||
%2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
|
%2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
%3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
%3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
%4 = "tf.Add"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
%4 = "tf.Add"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||||
%5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
|
%5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
%6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
%6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||||
%7 = "tf.Relu6"(%6) : (tensor<1xi32>) -> tensor<1xi32>
|
%7 = "tf.Relu6"(%6) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
return %7: tensor<1xi32>
|
return %7: tensor<1xf32>
|
||||||
|
|
||||||
// CHECK-LABEL: addRelu
|
// CHECK-LABEL: addRelu
|
||||||
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
|
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
|
||||||
// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
|
// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
|
||||||
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
|
// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
|
||||||
// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xi32>
|
// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xf32>
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,32 +244,32 @@ func @zeros_like(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
|||||||
// CHECK: "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
// CHECK: "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
func @divRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
|
func @divRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||||
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||||
%1 = "tf.Div"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
%1 = "tf.Div"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||||
%2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
|
%2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
%3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
%3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
%4 = "tf.Div"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
%4 = "tf.Div"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||||
%5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
|
%5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
return %5: tensor<1xi32>
|
return %5: tensor<1xf32>
|
||||||
|
|
||||||
// CHECK-LABEL: divRelu
|
// CHECK-LABEL: divRelu
|
||||||
// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
|
// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
|
||||||
// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
|
// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
|
||||||
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
|
// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
}
|
}
|
||||||
|
|
||||||
func @squaredDifferenceRelu(tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> {
|
func @squaredDifferenceRelu(tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> {
|
||||||
^bb0(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>):
|
^bb0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>):
|
||||||
%0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
%0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||||
%1 = "tf.Relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
|
%1 = "tf.Relu6"(%0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
return %1: tensor<1xi32>
|
return %1: tensor<1xf32>
|
||||||
|
|
||||||
// CHECK-LABEL: squaredDifferenceRelu
|
// CHECK-LABEL: squaredDifferenceRelu
|
||||||
// CHECK: tfl.squared_difference %arg0, %arg1 : tensor<1xi32>
|
// CHECK: tfl.squared_difference %arg0, %arg1 : tensor<1xf32>
|
||||||
// CHECK: %1 = "tfl.relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
|
// CHECK: %1 = "tfl.relu6"(%0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -434,7 +434,7 @@ func @testEluI32(%arg0: tensor<? x i32>) -> tensor<? x i32> {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||||
// CHECK: "NONE"
|
// CHECK: "NONE"
|
||||||
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4xi32>
|
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4xi32>
|
||||||
// CHECK: "RELU"
|
// CHECK: "RELU"
|
||||||
@ -452,7 +452,7 @@ func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||||
// expected-error @+1 {{attribute 'fused_activation_function' failed to satisfy constraint: fused activation enum}}
|
// expected-error @+1 {{attribute 'fused_activation_function' failed to satisfy constraint: fused activation enum}}
|
||||||
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "Relu6"} : tensor<4xi32>
|
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "Relu6"} : tensor<4xi32>
|
||||||
return %0: tensor<4xi32>
|
return %0: tensor<4xi32>
|
||||||
@ -1047,7 +1047,7 @@ func @testConcatInvalidOperandDimSize(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3x
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @testConcatInvalidOperandDimSizeComaredToPrevInput(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3xi32>) -> tensor<?x?xi32> {
|
func @testConcatInvalidOperandDimSizeComparedToPrevInput(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3xi32>) -> tensor<?x?xi32> {
|
||||||
// expected-error @+1 {{'tfl.concatenation' op dimension size of dimension #1 of operand #1 must be equal to dimension size of dimension #1 of operand #0, expected 2, got 3}}
|
// expected-error @+1 {{'tfl.concatenation' op dimension size of dimension #1 of operand #1 must be equal to dimension size of dimension #1 of operand #0, expected 2, got 3}}
|
||||||
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x3xi32>) -> tensor<?x?xi32>
|
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x3xi32>) -> tensor<?x?xi32>
|
||||||
return %0 : tensor<?x?xi32>
|
return %0 : tensor<?x?xi32>
|
||||||
|
@ -278,7 +278,7 @@ func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x
|
|||||||
%cst2 = constant dense<3.0> : tensor<112x2xf32>
|
%cst2 = constant dense<3.0> : tensor<112x2xf32>
|
||||||
|
|
||||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||||
// We cannot fuse this tfl.mul into the preceding conv op becuase %cst2 is not broadcast-compatible to %cst0.
|
// We cannot fuse this tfl.mul into the preceding conv op because %cst2 is not broadcast-compatible to %cst0.
|
||||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
|
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
|
||||||
|
|
||||||
return %1 : tensor<1x112x112x2xf32>
|
return %1 : tensor<1x112x112x2xf32>
|
||||||
@ -600,3 +600,25 @@ func @squeezeToReshape(%arg0: tensor<1x1x2xf32>) -> tensor<2xf32> {
|
|||||||
// CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor<1x1x2xf32>, tensor<1xi32>) -> tensor<2xf32>
|
// CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor<1x1x2xf32>, tensor<1xi32>) -> tensor<2xf32>
|
||||||
// CHECK: return %[[RESULT]]
|
// CHECK: return %[[RESULT]]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: Relu1
|
||||||
|
func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||||
|
%cst = constant dense<-1.0> : tensor<f32>
|
||||||
|
%cst1 = constant dense<1.0> : tensor<f32>
|
||||||
|
%0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||||
|
%1 = "tfl.minimum"(%0, %cst1) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||||
|
return %1 : tensor<2x3xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: Relu1_2
|
||||||
|
func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||||
|
%cst = constant dense<-1.0> : tensor<f32>
|
||||||
|
%cst1 = constant dense<1.0> : tensor<f32>
|
||||||
|
%0 = "tfl.minimum"(%arg0, %cst1) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||||
|
%1 = "tfl.maximum"(%0, %cst) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||||
|
return %1 : tensor<2x3xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
|
||||||
|
}
|
||||||
|
@ -30,7 +30,7 @@ opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
|
|||||||
opt<bool> use_splatted_constant(
|
opt<bool> use_splatted_constant(
|
||||||
"use-splatted-constant",
|
"use-splatted-constant",
|
||||||
llvm::cl::desc(
|
llvm::cl::desc(
|
||||||
"Replace constants with randonmly generated splatted tensors"),
|
"Replace constants with randomly generated splatted tensors"),
|
||||||
llvm::cl::init(false), llvm::cl::Hidden);
|
llvm::cl::init(false), llvm::cl::Hidden);
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
opt<bool> input_mlir(
|
opt<bool> input_mlir(
|
||||||
|
@ -282,7 +282,7 @@ struct OphintCompositeOp {
|
|||||||
// Since we have different aggregation strategies, e.g., "first", "last",
|
// Since we have different aggregation strategies, e.g., "first", "last",
|
||||||
// "stack". We don't somehow aggregated to get the outputs for the funcOp.
|
// "stack". We don't somehow aggregated to get the outputs for the funcOp.
|
||||||
// This function is simply compute the RankedTensorType (shape & element type)
|
// This function is simply compute the RankedTensorType (shape & element type)
|
||||||
std::map<int, Type> GetAggregatedOuputTypes(OpBuilder* builder) {
|
std::map<int, Type> GetAggregatedOutputTypes(OpBuilder* builder) {
|
||||||
std::map<int, Type> aggregated_output_types;
|
std::map<int, Type> aggregated_output_types;
|
||||||
for (const auto& kv : outputs) {
|
for (const auto& kv : outputs) {
|
||||||
const AggregatedOperand& operand = kv.second;
|
const AggregatedOperand& operand = kv.second;
|
||||||
@ -387,11 +387,12 @@ struct OphintCompositeOp {
|
|||||||
// inputs/outputs indicate edges) Assume the graph is acyclic. The preprocess
|
// inputs/outputs indicate edges) Assume the graph is acyclic. The preprocess
|
||||||
// does the following:
|
// does the following:
|
||||||
// Compute each operations's in-degress (how many input nodes they're taken)
|
// Compute each operations's in-degress (how many input nodes they're taken)
|
||||||
// Get all consumer operations for every operations. (operation_to_ouputs)
|
// Get all consumer operations for every operations. (operation_to_outputs)
|
||||||
// Get the init_queue (those operations will be processed first).
|
// Get the init_queue (those operations will be processed first).
|
||||||
void PreprocessTopoSortGraph(
|
void PreprocessTopoSortGraph(
|
||||||
Block* block, std::queue<Operation*>* init_queue,
|
Block* block, std::queue<Operation*>* init_queue,
|
||||||
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>>* operation_to_ouputs,
|
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>>*
|
||||||
|
operation_to_outputs,
|
||||||
llvm::DenseMap<Operation*, int>* operation_to_in_degrees) {
|
llvm::DenseMap<Operation*, int>* operation_to_in_degrees) {
|
||||||
for (auto& op : *block) {
|
for (auto& op : *block) {
|
||||||
if (&op == block->getTerminator()) continue;
|
if (&op == block->getTerminator()) continue;
|
||||||
@ -412,9 +413,9 @@ void PreprocessTopoSortGraph(
|
|||||||
}
|
}
|
||||||
operation_to_in_degrees->try_emplace(&op, input_ops.size());
|
operation_to_in_degrees->try_emplace(&op, input_ops.size());
|
||||||
for (auto* input_op : input_ops) {
|
for (auto* input_op : input_ops) {
|
||||||
auto preceeding_op_it = operation_to_ouputs->find(input_op);
|
auto preceeding_op_it = operation_to_outputs->find(input_op);
|
||||||
if (preceeding_op_it == operation_to_ouputs->end()) {
|
if (preceeding_op_it == operation_to_outputs->end()) {
|
||||||
auto result = operation_to_ouputs->try_emplace(
|
auto result = operation_to_outputs->try_emplace(
|
||||||
input_op, llvm::DenseSet<Operation*>());
|
input_op, llvm::DenseSet<Operation*>());
|
||||||
preceeding_op_it = result.first;
|
preceeding_op_it = result.first;
|
||||||
}
|
}
|
||||||
@ -442,19 +443,19 @@ bool IsSideEffectOp(Operation* op) {
|
|||||||
// Also assume the block has no arguments.
|
// Also assume the block has no arguments.
|
||||||
LogicalResult TopoSortOperations(OpBuilder* builder) {
|
LogicalResult TopoSortOperations(OpBuilder* builder) {
|
||||||
std::queue<Operation*> init_queue;
|
std::queue<Operation*> init_queue;
|
||||||
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>> operation_to_ouputs;
|
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>> operation_to_outputs;
|
||||||
llvm::DenseMap<Operation*, int> operation_to_in_degrees;
|
llvm::DenseMap<Operation*, int> operation_to_in_degrees;
|
||||||
std::vector<Operation*> sorted_ops;
|
std::vector<Operation*> sorted_ops;
|
||||||
|
|
||||||
PreprocessTopoSortGraph(builder->getBlock(), &init_queue,
|
PreprocessTopoSortGraph(builder->getBlock(), &init_queue,
|
||||||
&operation_to_ouputs, &operation_to_in_degrees);
|
&operation_to_outputs, &operation_to_in_degrees);
|
||||||
while (!init_queue.empty()) {
|
while (!init_queue.empty()) {
|
||||||
Operation* current_op = init_queue.front();
|
Operation* current_op = init_queue.front();
|
||||||
init_queue.pop();
|
init_queue.pop();
|
||||||
sorted_ops.push_back(current_op);
|
sorted_ops.push_back(current_op);
|
||||||
|
|
||||||
auto current_op_to_output_it = operation_to_ouputs.find(current_op);
|
auto current_op_to_output_it = operation_to_outputs.find(current_op);
|
||||||
if (current_op_to_output_it == operation_to_ouputs.end()) {
|
if (current_op_to_output_it == operation_to_outputs.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (Operation* output_op : current_op_to_output_it->second) {
|
for (Operation* output_op : current_op_to_output_it->second) {
|
||||||
@ -467,7 +468,7 @@ LogicalResult TopoSortOperations(OpBuilder* builder) {
|
|||||||
operation_to_in_degrees.erase(output_op_it);
|
operation_to_in_degrees.erase(output_op_it);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
operation_to_ouputs.erase(current_op_to_output_it);
|
operation_to_outputs.erase(current_op_to_output_it);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Before we performs the sort. We need to make sure we didn't mess the
|
// Before we performs the sort. We need to make sure we didn't mess the
|
||||||
@ -629,11 +630,11 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
|
|||||||
|
|
||||||
// Step 4, get aggregated output types.
|
// Step 4, get aggregated output types.
|
||||||
const std::map<int, Type>& aggregated_output_types =
|
const std::map<int, Type>& aggregated_output_types =
|
||||||
ophint_composite_op.GetAggregatedOuputTypes(builder);
|
ophint_composite_op.GetAggregatedOutputTypes(builder);
|
||||||
|
|
||||||
// Step 5, create & place the fused op and rewire the inputs.
|
// Step 5, create & place the fused op and rewire the inputs.
|
||||||
// Here we use a funcOp to represent the fused op. This "funcOp" will be
|
// Here we use a funcOp to represent the fused op. This "funcOp" will be
|
||||||
// coonverted to other ops (like UnidirectionalSequenceRNNOp) in the
|
// converted to other ops (like UnidirectionalSequenceRNNOp) in the
|
||||||
// legalization phase.
|
// legalization phase.
|
||||||
Operation* inserted_before_op = ophint_composite_op.GetFirstOutputOp();
|
Operation* inserted_before_op = ophint_composite_op.GetFirstOutputOp();
|
||||||
Operation* fused_op = BuildFusedFuncOp(
|
Operation* fused_op = BuildFusedFuncOp(
|
||||||
|
@ -191,7 +191,7 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ConvertTfLiteFusedOpIfAvaiable(StringRef func_name,
|
LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name,
|
||||||
FuncOp composite_func_op,
|
FuncOp composite_func_op,
|
||||||
CallOp call_op,
|
CallOp call_op,
|
||||||
OpBuilder* builder) {
|
OpBuilder* builder) {
|
||||||
@ -243,7 +243,7 @@ LogicalResult ConvertCallOps(llvm::StringMap<FuncOp>* composite_func_ops,
|
|||||||
StringRef func_name = composite_func_op.getAttr(kTfLiteFunctionName)
|
StringRef func_name = composite_func_op.getAttr(kTfLiteFunctionName)
|
||||||
.cast<StringAttr>()
|
.cast<StringAttr>()
|
||||||
.getValue();
|
.getValue();
|
||||||
if (failed(ConvertTfLiteFusedOpIfAvaiable(func_name, composite_func_op,
|
if (failed(ConvertTfLiteFusedOpIfAvailable(func_name, composite_func_op,
|
||||||
call_op, &builder)))
|
call_op, &builder)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -140,6 +140,8 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
|
|||||||
return ExpandTo4DForConvImpl(a, true);
|
return ExpandTo4DForConvImpl(a, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns shape of a ranked tensor.
|
||||||
|
// Precondition: output_val's is ranked tensor.
|
||||||
DenseElementsAttr GetShape(Value *output_val) {
|
DenseElementsAttr GetShape(Value *output_val) {
|
||||||
auto output_type = output_val->getType().cast<RankedTensorType>();
|
auto output_type = output_val->getType().cast<RankedTensorType>();
|
||||||
auto shape_vector = output_type.getShape();
|
auto shape_vector = output_type.getShape();
|
||||||
|
@ -267,9 +267,27 @@ multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
|
|||||||
foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
|
foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
|
||||||
in defm : FuseTileBroadcastIntoFollowingBinary<BroadcastingOp>;
|
in defm : FuseTileBroadcastIntoFollowingBinary<BroadcastingOp>;
|
||||||
|
|
||||||
|
// Returns shape of a ranked tensor.
|
||||||
|
// if called without a ranked tensor it will fail.
|
||||||
def GetShape: NativeCodeCall<"GetShape($0)">;
|
def GetShape: NativeCodeCall<"GetShape($0)">;
|
||||||
|
|
||||||
def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
|
def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
|
||||||
(TFL_ReshapeOp $input,
|
(TFL_ReshapeOp $input,
|
||||||
(ConstantOp (GetShape $squeeze_op))),
|
(ConstantOp (GetShape $squeeze_op))),
|
||||||
[(AnyStaticShapeTensor $squeeze_op)]>;
|
[(AnyStaticShapeTensor $squeeze_op)]>;
|
||||||
|
|
||||||
|
class ValueEquals<string val> : Constraint<CPred<
|
||||||
|
"$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
|
||||||
|
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
|
||||||
|
|
||||||
|
def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
|
||||||
|
(ConstantOp $NegOne)),
|
||||||
|
(ConstantOp $One)),
|
||||||
|
(TFL_Relu1Op $input),
|
||||||
|
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
|
||||||
|
|
||||||
|
def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
|
||||||
|
(ConstantOp $One)),
|
||||||
|
(ConstantOp $NegOne)),
|
||||||
|
(TFL_Relu1Op $input),
|
||||||
|
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
|
||||||
|
@ -68,11 +68,11 @@ class ConvertEmbeddedLookupFunc {
|
|||||||
if (func_.getNumArguments() != 2) {
|
if (func_.getNumArguments() != 2) {
|
||||||
return func_.emitError()
|
return func_.emitError()
|
||||||
<< "Invalid number of arguments in the embedding "
|
<< "Invalid number of arguments in the embedding "
|
||||||
"matmal composite function";
|
"matmul composite function";
|
||||||
}
|
}
|
||||||
if (func_.getType().getNumResults() != 1) {
|
if (func_.getType().getNumResults() != 1) {
|
||||||
return func_.emitError() << "Invalid number of results in the embedding "
|
return func_.emitError() << "Invalid number of results in the embedding "
|
||||||
"matmal composite function";
|
"matmul composite function";
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ limitations under the License.
|
|||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
static llvm::cl::list<std::string> quantize_whitelist(
|
static llvm::cl::list<std::string> quantize_whitelist(
|
||||||
"tfl-test-quantize-whitelist", llvm::cl::value_desc("list"),
|
"tfl-test-quantize-whitelist", llvm::cl::value_desc("list"),
|
||||||
llvm::cl::desc("comma seprarated list of whitelisted functions to be "
|
llvm::cl::desc("comma separated list of whitelisted functions to be "
|
||||||
"quantized. Only used in tests"),
|
"quantized. Only used in tests"),
|
||||||
llvm::cl::CommaSeparated);
|
llvm::cl::CommaSeparated);
|
||||||
|
|
||||||
|
@ -400,7 +400,7 @@ class ConvertTFDepthwiseConv2dNative
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// StridedSlice can have complicated atributes like begin_axis_mask,
|
// StridedSlice can have complicated attributes like begin_axis_mask,
|
||||||
// end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
|
// end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
|
||||||
// masks will complicate the strided_slice computation logic, we can simplify
|
// masks will complicate the strided_slice computation logic, we can simplify
|
||||||
// the logic by inserting a reshape op to pad the inputs so strided_slice can
|
// the logic by inserting a reshape op to pad the inputs so strided_slice can
|
||||||
|
@ -247,7 +247,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
|
if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
|
||||||
// Input dimensions must be compatible for multipication.
|
// Input dimensions must be compatible for multiplication.
|
||||||
return this->matchFailure();
|
return this->matchFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ class ConvertLSTMCellSimpleToFusedLSTM {
|
|||||||
Value* input2cell_;
|
Value* input2cell_;
|
||||||
Value* input2output_;
|
Value* input2output_;
|
||||||
|
|
||||||
// reccurrent -> cifg
|
// recurrent -> cifg
|
||||||
Value* rec2input_;
|
Value* rec2input_;
|
||||||
Value* rec2forget_;
|
Value* rec2forget_;
|
||||||
Value* rec2cell_;
|
Value* rec2cell_;
|
||||||
|
@ -28,7 +28,7 @@ config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm')
|
|||||||
config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR'])
|
config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR'])
|
||||||
config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'],
|
config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'],
|
||||||
'local_config_mlir')
|
'local_config_mlir')
|
||||||
# TODO(jpienaar): Replace with sufffices in build rule.
|
# TODO(jpienaar): Replace with suffices in build rule.
|
||||||
config.suffixes = ['.td', '.mlir', '.pbtxt']
|
config.suffixes = ['.td', '.mlir', '.pbtxt']
|
||||||
|
|
||||||
mlir_tf_tools_dirs = [
|
mlir_tf_tools_dirs = [
|
||||||
|
@ -217,6 +217,7 @@ cc_library(
|
|||||||
"transforms/shape_inference.cc",
|
"transforms/shape_inference.cc",
|
||||||
"transforms/shape_inference_pass.cc",
|
"transforms/shape_inference_pass.cc",
|
||||||
"transforms/sink_constant.cc",
|
"transforms/sink_constant.cc",
|
||||||
|
"transforms/test_side_effect_analysis.cc",
|
||||||
"transforms/tpu_cluster_formation.cc",
|
"transforms/tpu_cluster_formation.cc",
|
||||||
"transforms/tpu_merge_variables_with_execute.cc",
|
"transforms/tpu_merge_variables_with_execute.cc",
|
||||||
"transforms/tpu_rewrite_pass.cc",
|
"transforms/tpu_rewrite_pass.cc",
|
||||||
@ -239,6 +240,7 @@ cc_library(
|
|||||||
":error_util",
|
":error_util",
|
||||||
":export_tf_dialect_op",
|
":export_tf_dialect_op",
|
||||||
":mangling_util",
|
":mangling_util",
|
||||||
|
":side_effect_analysis",
|
||||||
":tensorflow",
|
":tensorflow",
|
||||||
":tensorflow_optimize_inc_gen",
|
":tensorflow_optimize_inc_gen",
|
||||||
":tpu_rewrite_device_util",
|
":tpu_rewrite_device_util",
|
||||||
@ -467,6 +469,7 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/platform:types",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
@ -669,6 +672,7 @@ cc_library(
|
|||||||
":import_utils",
|
":import_utils",
|
||||||
":mangling_util",
|
":mangling_util",
|
||||||
":mlir_roundtrip_flags",
|
":mlir_roundtrip_flags",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
"//tensorflow/core:lib_proto_parsing",
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
"//tensorflow/core:ops",
|
"//tensorflow/core:ops",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
@ -981,3 +985,19 @@ cc_library(
|
|||||||
"@local_config_mlir//:Pass",
|
"@local_config_mlir//:Pass",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "side_effect_analysis",
|
||||||
|
srcs = ["analysis/side_effect_analysis.cc"],
|
||||||
|
hdrs = ["analysis/side_effect_analysis.h"],
|
||||||
|
deps = [
|
||||||
|
":tensorflow",
|
||||||
|
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@llvm//:support",
|
||||||
|
"@local_config_mlir//:IR",
|
||||||
|
"@local_config_mlir//:Pass",
|
||||||
|
"@local_config_mlir//:Support",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -0,0 +1,374 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <initializer_list>
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/iterator_range.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/Block.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TF {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr int64_t kUnknownResourceId = -1;
|
||||||
|
|
||||||
|
// Returns if a VarHandleOp is anonymous, which means it always creates a new
|
||||||
|
// variable.
|
||||||
|
bool IsResourceHandleAnonymous(TF::VarHandleOp handle) {
|
||||||
|
return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a string unique identifier for a non-anonymous VarHandleOp.
|
||||||
|
std::string GetVarHandleStringId(TF::VarHandleOp handle) {
|
||||||
|
auto device = handle.getAttrOfType<StringAttr>("device");
|
||||||
|
return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(),
|
||||||
|
"/", device ? device.getValue().str() : std::string(""));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always
|
||||||
|
// creates a new ID; otherwise, tries to reuse the existing ID for the
|
||||||
|
// referenced variable if it exists, or creates a new one if not.
|
||||||
|
int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id,
|
||||||
|
llvm::StringMap<int64_t>* name_id_map) {
|
||||||
|
// Always create a new ID for anonymous handle.
|
||||||
|
if (IsResourceHandleAnonymous(handle)) return (*next_id)++;
|
||||||
|
|
||||||
|
auto name = GetVarHandleStringId(handle);
|
||||||
|
auto emplace_res = name_id_map->try_emplace(name, *next_id);
|
||||||
|
// New ID created, increment next_id.
|
||||||
|
if (emplace_res.second) ++(*next_id);
|
||||||
|
return emplace_res.first->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Runs the alias analysis when `op` is a FuncOp; for any other op the analysis
// is left empty.
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
  if (auto func_op = llvm::dyn_cast<FuncOp>(op)) {
    AnalyzeFunction(func_op);
  }
}
|
||||||
|
|
||||||
|
// Populates resource_value_to_ids_ for the resource-typed values in `func_op`:
//  * each resource-typed function argument gets its own fresh ID;
//  * a VarHandleOp result gets the ID of the variable it names (or a fresh ID
//    if the handle is anonymous);
//  * Identity/IdentityN results inherit the IDs of the matching operand;
//  * resource-typed results of any other op are marked as unknown.
void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
  // TODO(yuanzx): Pass variable aliasing information to functions so we can
  // properly resolve aliasing arguments.
  //
  // Before having that, we assume function arguments do not alias each other.
  int64_t next_unique_id = 0;
  for (auto arg : func_op.getArguments()) {
    if (!mlir::getElementTypeOrSelf(arg->getType()).isa<TF::ResourceType>())
      continue;
    resource_value_to_ids_[arg].insert(next_unique_id++);
  }

  // Maps the string key of every named VarHandleOp seen so far to its ID, so
  // that handles naming the same variable share one ID.
  llvm::StringMap<int64_t> var_handle_name_id_map;

  // Copies the IDs of `operand` into the ID set of `result` when `result` is
  // resource-typed.
  auto forward_input_to_output = [&](Value* operand, Value* result) {
    if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
      return;
    // Create the result entry *before* looking up the operand: operator[] may
    // insert into the map and grow it, which invalidates any iterator obtained
    // earlier. find() never mutates the map, so `result_ids` stays valid.
    auto& result_ids = resource_value_to_ids_[result];
    auto operand_it = resource_value_to_ids_.find(operand);
    assert(operand_it != resource_value_to_ids_.end() &&
           "A resource-type output does not have the corresponding "
           "resource-type input.");
    result_ids.insert(operand_it->getSecond().begin(),
                      operand_it->getSecond().end());
  };

  // TODO(yuanzx): Consider control-flow ops.
  func_op.walk([&](Operation* op) {
    if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
      resource_value_to_ids_[var_handle.resource()].insert(
          GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
                                    &var_handle_name_id_map));
    } else if (llvm::isa<TF::IdentityNOp>(op) ||
               llvm::isa<TF::IdentityOp>(op)) {
      // Identity ops pass operand i through to result i.
      for (auto operand_and_result :
           llvm::zip(op->getOperands(), op->getResults())) {
        forward_input_to_output(std::get<0>(operand_and_result),
                                std::get<1>(operand_and_result));
      }
    } else {
      // Any other op producing a resource cannot be resolved.
      for (auto result : op->getResults()) {
        if (!mlir::getElementTypeOrSelf(result->getType())
                 .isa<TF::ResourceType>())
          continue;
        resource_value_to_ids_[result].insert(kUnknownResourceId);
      }
    }
  });
}
|
||||||
|
|
||||||
|
bool ResourceAliasAnalysis::IsUnknownResource(const Value* resource) const {
|
||||||
|
auto it = resource_value_to_ids_.find(resource);
|
||||||
|
assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
|
||||||
|
// The set is sorted so we only need to check the first element since
|
||||||
|
// kUnknownResourceId < 0.
|
||||||
|
static_assert(kUnknownResourceId < 0,
|
||||||
|
"kUnknownResourceId should be negative");
|
||||||
|
return *it->getSecond().begin() == kUnknownResourceId;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the ID set recorded for `resource`. The value must have been seen by
// the analysis (checked by assertion).
const llvm::SmallSet<int64_t, 8>& ResourceAliasAnalysis::GetResourceUniqueIds(
    const Value* resource) const {
  auto found = resource_value_to_ids_.find(resource);
  assert(found != resource_value_to_ids_.end() &&
         "Unseen resource was queried");
  const auto& ids = found->getSecond();
  return ids;
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Returns a set that contains only kUnknownResourceId.
|
||||||
|
llvm::SmallDenseSet<int64_t, 8> UnknownResourceSet() {
|
||||||
|
llvm::SmallDenseSet<int64_t, 8> unknown_set;
|
||||||
|
unknown_set.insert(kUnknownResourceId);
|
||||||
|
return unknown_set;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns all resources that could be accessed by op, or UnknownResourceSet()
|
||||||
|
// if we cannot find all of them.
|
||||||
|
llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
|
||||||
|
Operation* op, const ResourceAliasAnalysis& alias_analysis) {
|
||||||
|
llvm::SmallDenseSet<int64_t, 8> resources;
|
||||||
|
|
||||||
|
for (auto operand : op->getOperands()) {
|
||||||
|
if (!mlir::getElementTypeOrSelf(operand->getType()).isa<TF::ResourceType>())
|
||||||
|
continue;
|
||||||
|
if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet();
|
||||||
|
const auto& ids = alias_analysis.GetResourceUniqueIds(operand);
|
||||||
|
resources.insert(ids.begin(), ids.end());
|
||||||
|
}
|
||||||
|
for (auto result : op->getResults()) {
|
||||||
|
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
|
||||||
|
continue;
|
||||||
|
if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet();
|
||||||
|
const auto& ids = alias_analysis.GetResourceUniqueIds(result);
|
||||||
|
resources.insert(ids.begin(), ids.end());
|
||||||
|
}
|
||||||
|
return resources;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns an XlaResourceOpInfo (or nullptr if it does not exist) that specifies
|
||||||
|
// the resource access type of the op. It tells whether the op is read only,
|
||||||
|
// etc.
|
||||||
|
//
|
||||||
|
// TODO(yuanzx): Define this information in a different place. Currently we use
|
||||||
|
// tensorflow/compiler/tf2xla/resource_operation_table.h.
|
||||||
|
const tensorflow::XlaResourceOpInfo* GetResourceInfoForOp(Operation* op) {
|
||||||
|
auto op_name = op->getName().getStringRef().str();
|
||||||
|
if (op->getName().getDialect() !=
|
||||||
|
TF::TensorFlowDialect::getDialectNamespace()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return tensorflow::GetResourceOpInfoForOp(
|
||||||
|
op->getName().getStringRef().split('.').second.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns whether `op` accesses resources and it is known to be read-only.
|
||||||
|
bool OpIsReadOnly(Operation* op) {
|
||||||
|
auto resource_op_info = GetResourceInfoForOp(op);
|
||||||
|
return resource_op_info &&
|
||||||
|
resource_op_info->kind() == tensorflow::XlaResourceOpKind::kRead;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns if `op` is a resource declaration.
|
||||||
|
bool OpIsDeclaration(Operation* op,
|
||||||
|
const ResourceAliasAnalysis& alias_analysis) {
|
||||||
|
// TODO(yuanzx): Add other types of resources.
|
||||||
|
return llvm::isa<TF::VarHandleOp>(op) ||
|
||||||
|
((llvm::isa<TF::IdentityNOp>(op) || llvm::isa<TF::IdentityOp>(op)) &&
|
||||||
|
!FindAccessedResources(op, alias_analysis).empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op,
|
||||||
|
bool read_only) {
|
||||||
|
if (resource_id == kUnknownResourceId) {
|
||||||
|
if (read_only) {
|
||||||
|
// New unknown read is not tracked by any known resource access.
|
||||||
|
for (auto& entry : per_resource_access_info_) {
|
||||||
|
entry.getSecond().tracked_last_unknown_read = false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown write can clear all other tracked information, since it acts
|
||||||
|
// like a barrier.
|
||||||
|
per_resource_access_info_.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto& info = per_resource_access_info_[resource_id];
|
||||||
|
if (read_only) {
|
||||||
|
info.reads_since_last_write.push_back(op);
|
||||||
|
// Resource read must have carried control dependencies of unknown write.
|
||||||
|
info.tracked_last_unknown_write = true;
|
||||||
|
} else {
|
||||||
|
// Resource write must have carried control dependencies of unknown access.
|
||||||
|
info.tracked_last_unknown_write = true;
|
||||||
|
info.tracked_last_unknown_read = true;
|
||||||
|
info.last_write = op;
|
||||||
|
info.reads_since_last_write.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id,
|
||||||
|
Operation* op,
|
||||||
|
bool read_only) {
|
||||||
|
auto it = per_resource_access_info_.find(resource_id);
|
||||||
|
if (it == per_resource_access_info_.end()) return;
|
||||||
|
const auto& access_info = it->getSecond();
|
||||||
|
auto& control_predecessors = control_predecessors_[op];
|
||||||
|
bool read_tracked = false;
|
||||||
|
if (!read_only) {
|
||||||
|
control_predecessors.insert(access_info.reads_since_last_write.begin(),
|
||||||
|
access_info.reads_since_last_write.end());
|
||||||
|
read_tracked = !access_info.reads_since_last_write.empty();
|
||||||
|
}
|
||||||
|
if (access_info.last_write && !read_tracked) {
|
||||||
|
control_predecessors.insert(access_info.last_write);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fills control_predecessors_ and control_successors_ by walking the body of
// `func_op` and tracking resource accesses in per_resource_access_info_.
void SideEffectAnalysis::AnalyzeFunction(
    FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) {
  // Returns whether an access to `resource` can skip control edges from
  // previous accesses to unknown resources, because earlier accesses to
  // `resource` already indirectly tracked those unknown accesses. `read_only`
  // specifies the type of access of the current op being considered.
  auto unknown_access_indirectly_tracked_by_resource = [&](int64_t resource,
                                                           bool read_only) {
    auto known_it = per_resource_access_info_.find(resource);
    if (known_it == per_resource_access_info_.end()) return false;

    const auto& known_info = known_it->getSecond();
    if (!known_info.tracked_last_unknown_write) return false;
    // A read only needs the last unknown write to be covered.
    if (read_only) return true;
    // A write also needs the unknown reads covered, or there to be none.
    if (known_info.tracked_last_unknown_read) return true;
    auto unknown_it = per_resource_access_info_.find(kUnknownResourceId);
    return unknown_it == per_resource_access_info_.end() ||
           unknown_it->getSecond().reads_since_last_write.empty();
  };

  func_op.walk([&](Operation* op) {
    // We do not need explicit control edges for declaration ops.
    if (OpIsDeclaration(op, alias_analysis)) return;

    const auto* resource_op_info = GetResourceInfoForOp(op);
    // Ops that are neither resource ops nor side-effecting are ignored.
    if (!resource_op_info && op->hasNoSideEffect()) return;

    // A side-effecting op without resource info is treated as accessing the
    // unknown resource.
    llvm::SmallDenseSet<int64_t, 8> resources;
    if (resource_op_info) {
      resources = FindAccessedResources(op, alias_analysis);
    } else {
      resources = UnknownResourceSet();
    }
    assert(!resources.empty());

    const bool is_unknown = resources.count(kUnknownResourceId) > 0;
    const bool read_only =
        resource_op_info &&
        resource_op_info->kind() == tensorflow::XlaResourceOpKind::kRead;
    bool unknown_already_covered = false;

    // First add edges from known resources.
    if (is_unknown) {
      // An unknown access depends on every known resource tracked so far.
      for (auto& entry : per_resource_access_info_) {
        const int64_t tracked_id = entry.getFirst();
        if (tracked_id == kUnknownResourceId) continue;
        AddPredecessorsForAccess(tracked_id, op, read_only);
        unknown_already_covered |=
            unknown_access_indirectly_tracked_by_resource(tracked_id,
                                                          read_only);
      }
    } else {
      for (int64_t resource : resources) {
        AddPredecessorsForAccess(resource, op, read_only);
        unknown_already_covered |=
            unknown_access_indirectly_tracked_by_resource(resource, read_only);
        // Update access info for known resources.
        TrackAccess(resource, op, read_only);
      }
    }

    // If not indirectly tracked, add edges from the unknown resource.
    if (!unknown_already_covered) {
      AddPredecessorsForAccess(kUnknownResourceId, op, read_only);
    }

    // Update access info for unknown resource.
    if (is_unknown) {
      TrackAccess(kUnknownResourceId, op, read_only);
    }
  });

  // Populate control_successors_ based on control_predecessors_.
  for (auto& entry : control_predecessors_) {
    Operation* successor = entry.getFirst();
    for (Operation* predecessor : entry.getSecond()) {
      control_successors_[predecessor].insert(successor);
    }
  }
}
|
||||||
|
|
||||||
|
// Shared implementation of DirectControlPredecessors() and
// DirectControlSuccessors(): returns the ops recorded for `op` in `edges` that
// pass `filter` (all of them when `filter` is null), ordered with
// Operation::isBeforeInBlock. Returns an empty vector when `op` has no entry.
static llvm::SmallVector<Operation*, 8> FilteredSortedNeighbors(
    const llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>&
        edges,
    Operation* op, llvm::function_ref<bool(Operation*)> filter) {
  llvm::SmallVector<Operation*, 8> result;
  auto it = edges.find(op);
  if (it == edges.end()) return result;
  result.reserve(it->getSecond().size());
  for (auto neighbor : it->getSecond()) {
    if (!filter || filter(neighbor)) result.push_back(neighbor);
  }
  // The pointer set has no meaningful iteration order, so sort explicitly.
  llvm::sort(result,
             [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
  return result;
}

// Returns the direct control predecessors of `op` that pass `filter` (all of
// them when `filter` is null), sorted in program order.
llvm::SmallVector<Operation*, 8> SideEffectAnalysis::DirectControlPredecessors(
    Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
  return FilteredSortedNeighbors(control_predecessors_, op, filter);
}

// Returns the direct control successors of `op` that pass `filter` (all of
// them when `filter` is null), sorted in program order.
llvm::SmallVector<Operation*, 8> SideEffectAnalysis::DirectControlSuccessors(
    Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
  return FilteredSortedNeighbors(control_successors_, op, filter);
}
|
||||||
|
|
||||||
|
// Runs the side-effect analysis when `op` is a FuncOp, using a resource alias
// analysis computed on the same op; for any other op the analysis is left
// empty.
SideEffectAnalysis::SideEffectAnalysis(Operation* op) {
  if (auto func_op = llvm::dyn_cast<FuncOp>(op)) {
    ResourceAliasAnalysis alias_analysis(op);
    AnalyzeFunction(func_op, alias_analysis);
  }
}
|
||||||
|
|
||||||
|
} // namespace TF
|
||||||
|
} // namespace mlir
|
@ -0,0 +1,125 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringMap.h"
|
||||||
|
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/Region.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TF {
|
||||||
|
|
||||||
|
// An analysis that runs on a function and maps each resource-type value to a
|
||||||
|
// set of unique int64_t IDs representing the possible resources it could alias.
|
||||||
|
class ResourceAliasAnalysis {
|
||||||
|
public:
|
||||||
|
explicit ResourceAliasAnalysis(Operation* op);
|
||||||
|
~ResourceAliasAnalysis() = default;
|
||||||
|
ResourceAliasAnalysis(ResourceAliasAnalysis&&) = default;
|
||||||
|
|
||||||
|
// Returns if the analysis fails to resolve a resource-type value.
|
||||||
|
bool IsUnknownResource(const Value* resource) const;
|
||||||
|
|
||||||
|
// Returns the set unique IDs which `resource` could alias. Requires that
|
||||||
|
// IsUnknownResource(resource) == true.
|
||||||
|
const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(
|
||||||
|
const Value* resource) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
ResourceAliasAnalysis() = default;
|
||||||
|
|
||||||
|
// Runs the analysis on `func_op` and populates resource_value_to_ids_.
|
||||||
|
void AnalyzeFunction(FuncOp func_op);
|
||||||
|
|
||||||
|
// Maps each resource-type value to a set of unique IDs that it could alias.
|
||||||
|
llvm::SmallDenseMap<const Value*, llvm::SmallSet<int64_t, 8>, 8>
|
||||||
|
resource_value_to_ids_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// An analysis that runs on a function and infers the control predecessors and
|
||||||
|
// successors for each op, based on side-effects on known and unknown resources.
|
||||||
|
// Side-effecting ops on uknown resources are conservatively treated as
|
||||||
|
// interfering with all known resource op accesses. It distinguishes accesses
|
||||||
|
// based on whether they are read-only, and read-only ops do not interfer with
|
||||||
|
// each other.
|
||||||
|
class SideEffectAnalysis {
|
||||||
|
public:
|
||||||
|
explicit SideEffectAnalysis(Operation* op);
|
||||||
|
SideEffectAnalysis(SideEffectAnalysis&& other) = default;
|
||||||
|
~SideEffectAnalysis() = default;
|
||||||
|
|
||||||
|
// Returns a vector of ops that are direct control predecessors of `op`,
|
||||||
|
// sorted in program order. If `filter` is provided, only predecessors that
|
||||||
|
// pass the filter (returning true) will be included.
|
||||||
|
llvm::SmallVector<Operation*, 8> DirectControlPredecessors(
|
||||||
|
Operation* op,
|
||||||
|
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
|
||||||
|
|
||||||
|
// Returns a vector of ops that are direct control successors of `op`, sorted
|
||||||
|
// in program order. If `filter` is provided, only successors that pass the
|
||||||
|
// filter (returning true) will be included.
|
||||||
|
llvm::SmallVector<Operation*, 8> DirectControlSuccessors(
|
||||||
|
Operation* op,
|
||||||
|
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Runs the analysis on `func_op` and populates control_predecessors_ and
|
||||||
|
// control_successors_.
|
||||||
|
void AnalyzeFunction(FuncOp func_op,
|
||||||
|
const ResourceAliasAnalysis& alias_analysis);
|
||||||
|
|
||||||
|
// Updates control_predecessors_ for `op` that is being visted, on the given
|
||||||
|
// `resource_id`.
|
||||||
|
void AddPredecessorsForAccess(int64_t resource_id, Operation* op,
|
||||||
|
bool read_only);
|
||||||
|
|
||||||
|
// Adds op's access to per_resource_access_info_.
|
||||||
|
void TrackAccess(int64_t resource_id, Operation* op, bool read_only);
|
||||||
|
|
||||||
|
// Maps from an op to its control predecessors.
|
||||||
|
llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>
|
||||||
|
control_predecessors_;
|
||||||
|
// Maps from an op to its control successors.
|
||||||
|
llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>
|
||||||
|
control_successors_;
|
||||||
|
|
||||||
|
// Internal per-resource data structure when we build the dependencies.
|
||||||
|
struct PerResourceAcessInfo {
|
||||||
|
// Last op that writes the resource before the current op being analyzed.
|
||||||
|
Operation* last_write = nullptr;
|
||||||
|
// Read ops since last_write before the current op being analyzed.
|
||||||
|
llvm::SmallVector<Operation*, 8> reads_since_last_write;
|
||||||
|
// Whether previous accesses of this resource already tracked last unknown
|
||||||
|
// read/write.
|
||||||
|
bool tracked_last_unknown_read = false;
|
||||||
|
bool tracked_last_unknown_write = false;
|
||||||
|
};
|
||||||
|
llvm::SmallDenseMap<int64_t, PerResourceAcessInfo, 8>
|
||||||
|
per_resource_access_info_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace TF
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
|
@ -26,7 +26,7 @@ static DialectRegistration<TFControlFlow::TFControlFlowDialect>
|
|||||||
tf_control_flow_ops;
|
tf_control_flow_ops;
|
||||||
static DialectRegistration<TF::TensorFlowDialect> tf_ops;
|
static DialectRegistration<TF::TensorFlowDialect> tf_ops;
|
||||||
static DialectRegistration<tf_executor::TensorFlowExecutorDialect>
|
static DialectRegistration<tf_executor::TensorFlowExecutorDialect>
|
||||||
tf_excutor_dialect;
|
tf_executor_dialect;
|
||||||
static DialectRegistration<tf_device::TensorFlowDeviceDialect>
|
static DialectRegistration<tf_device::TensorFlowDeviceDialect>
|
||||||
tf_device_dialect;
|
tf_device_dialect;
|
||||||
static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect>
|
static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect>
|
||||||
|
@ -26,11 +26,13 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/SMLoc.h"
|
#include "llvm/Support/SMLoc.h"
|
||||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
|
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
|
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||||
|
@ -576,6 +576,44 @@ endian orderings will give different results.
|
|||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tf.BitwiseOr: elementwise bitwise OR of two integer tensors `x` and `y`,
// producing integer tensor `z`. Declared broadcastable, commutative and free of
// side effects; attribute T is derived from the type of operand 0.
def TF_BitwiseOrOp : TF_Op<"BitwiseOr", [Broadcastable, Commutative, NoSideEffect]>,
|
||||||
|
WithBroadcastableBinOpBuilder {
|
||||||
|
let summary = "Elementwise computes the bitwise OR of `x` and `y`.";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
The result will have those bits set, that are set in `x`, `y` or both. The
|
||||||
|
computation is performed on the underlying representations of `x` and `y`.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.ops import bitwise_ops
|
||||||
|
dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64,
|
||||||
|
tf.uint8, tf.uint16, tf.uint32, tf.uint64]
|
||||||
|
|
||||||
|
for dtype in dtype_list:
|
||||||
|
lhs = tf.constant([0, 5, 3, 14], dtype=dtype)
|
||||||
|
rhs = tf.constant([5, 0, 7, 11], dtype=dtype)
|
||||||
|
exp = tf.constant([5, 5, 7, 15], dtype=tf.float32)
|
||||||
|
|
||||||
|
res = bitwise_ops.bitwise_or(lhs, rhs)
|
||||||
|
tf.assert_equal(tf.cast(res, tf.float32), exp) # TRUE
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_IntTensor:$x,
|
||||||
|
TF_IntTensor:$y
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_IntTensor:$z
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
}
|
||||||
|
|
||||||
def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect]> {
|
def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Return the reduction indices for computing gradients of s0 op s1 with broadcast.
|
Return the reduction indices for computing gradients of s0 op s1 with broadcast.
|
||||||
@ -1320,6 +1358,10 @@ Comparison with `numpy.einsum`:
|
|||||||
|
|
||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||||
|
|
||||||
|
let verifier = [{
|
||||||
|
return Verify(*this);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_EluOp : TF_Op<"Elu", [NoSideEffect, SameOperandsAndResultType]> {
|
def TF_EluOp : TF_Op<"Elu", [NoSideEffect, SameOperandsAndResultType]> {
|
||||||
@ -5726,6 +5768,8 @@ If two elements are equal, the lower-index element appears first.
|
|||||||
);
|
);
|
||||||
|
|
||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
|
||||||
|
let verifier = [{ return Verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
|
def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
|
||||||
|
@ -39,7 +39,7 @@ This dialect maps to TensorFlow operations.
|
|||||||
Invariants:
|
Invariants:
|
||||||
|
|
||||||
* All values are of Tensor type (in particular, scalars are
|
* All values are of Tensor type (in particular, scalars are
|
||||||
represented using zero-dimentional tensors);
|
represented using zero-dimensional tensors);
|
||||||
|
|
||||||
TODO: Make invariants more structured so that we can reference them in ops.
|
TODO: Make invariants more structured so that we can reference them in ops.
|
||||||
}];
|
}];
|
||||||
|
@ -513,7 +513,7 @@ void ConstOp::build(Builder *builder, OperationState &result, Attribute value) {
|
|||||||
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
|
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
|
||||||
value.isa<IntegerAttr>()) {
|
value.isa<IntegerAttr>()) {
|
||||||
// All TensorFlow types must be tensor types. In the build() method,
|
// All TensorFlow types must be tensor types. In the build() method,
|
||||||
// we want to provide more flexiblity by allowing attributes of scalar
|
// we want to provide more flexibility by allowing attributes of scalar
|
||||||
// types. But we need to wrap it up with ElementsAttr to construct
|
// types. But we need to wrap it up with ElementsAttr to construct
|
||||||
// valid TensorFlow constants.
|
// valid TensorFlow constants.
|
||||||
type = RankedTensorType::get(/*shape=*/{}, value.getType());
|
type = RankedTensorType::get(/*shape=*/{}, value.getType());
|
||||||
@ -674,6 +674,21 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
results.insert<DivWithSqrtDivisor>(context);
|
results.insert<DivWithSqrtDivisor>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// EinsumOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Verifies that,
|
||||||
|
// * Arity of the op is at most two.
|
||||||
|
//
|
||||||
|
// TODO(hinsu): Verify einsum equation attribute.
|
||||||
|
static LogicalResult Verify(EinsumOp op) {
|
||||||
|
if (op.N() > 2) {
|
||||||
|
return op.emitOpError("supports at most two operands");
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// EmptyTensorListOp
|
// EmptyTensorListOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1683,6 +1698,21 @@ static LogicalResult Verify(TensorListStackOp op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TopKV2Op
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Checks the operands of TopKV2Op: `input` must pass HasRankAtLeast(..., 1)
// and `k` must pass IsOfRankOrUnranked(..., 0). The first failing check emits
// an op error.
static LogicalResult Verify(TopKV2Op op) {
  const bool input_rank_ok = HasRankAtLeast(op.input(), 1);
  if (!input_rank_ok) {
    return op.emitOpError(
        "requires input operand to have at least 1 dimension");
  }

  const bool k_rank_ok = IsOfRankOrUnranked(op.k(), 0);
  if (!k_rank_ok) {
    return op.emitOpError("requires k operand to be 0D tensor");
  }

  return success();
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TransposeOp
|
// TransposeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -41,7 +41,7 @@ class TensorFlowDialect : public Dialect {
|
|||||||
|
|
||||||
static StringRef getDialectNamespace() { return "tf"; }
|
static StringRef getDialectNamespace() { return "tf"; }
|
||||||
|
|
||||||
// Gradient attribute ("tf.gradient") in the list of NamedAttibutes in a
|
// Gradient attribute ("tf.gradient") in the list of NamedAttributes in a
|
||||||
// function references to its gradient function. This attribute in TensorFlow
|
// function references to its gradient function. This attribute in TensorFlow
|
||||||
// Dialect is used to model TF GradientDef. GetGradientAttrName() returns the
|
// Dialect is used to model TF GradientDef. GetGradientAttrName() returns the
|
||||||
// string description of gradient attribute.
|
// string description of gradient attribute.
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
|
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||||
|
@ -101,6 +101,17 @@ func @testDifferentCastType(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64
|
|||||||
// CHECK: return %0, %1
|
// CHECK: return %0, %1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: testCompatibleCastType
|
||||||
|
func @testCompatibleCastType(%arg0: tensor<?xf32>) -> (tensor<10xf32>, tensor<10xf32>) {
|
||||||
|
%0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<?xf32>) -> tensor<10xf32>
|
||||||
|
%1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<?xf32>) -> tensor<10xf32>
|
||||||
|
return %0, %1: tensor<10xf32>, tensor<10xf32>
|
||||||
|
|
||||||
|
// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<?xf32>) -> tensor<10xf32>
|
||||||
|
// CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<?xf32>) -> tensor<10xf32>
|
||||||
|
// CHECK: return %0, %1
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: testSameCastTypeAcrossBasicBlocks
|
// CHECK-LABEL: testSameCastTypeAcrossBasicBlocks
|
||||||
func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> {
|
func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> {
|
||||||
^bb0(%arg0: tensor<8x16x32x64xf32>):
|
^bb0(%arg0: tensor<8x16x32x64xf32>):
|
||||||
|
@ -232,12 +232,12 @@ module {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Single device with non-continous instructions in original block.
|
// Single device with non-continuous instructions in original block.
|
||||||
|
|
||||||
module {
|
module {
|
||||||
// CHECK-LABEL: func @noncontinoussinglecluster
|
// CHECK-LABEL: func @noncontinuoussinglecluster
|
||||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
|
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
|
||||||
func @noncontinoussinglecluster(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
func @noncontinuoussinglecluster(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||||
%0 = tf_executor.graph {
|
%0 = tf_executor.graph {
|
||||||
%1:2 = tf_executor.island {
|
%1:2 = tf_executor.island {
|
||||||
|
|
||||||
|
@ -0,0 +1,237 @@
|
|||||||
|
// RUN: tf-opt -split-input-file -tf-test-side-effect-analysis -verify-diagnostics %s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// Tests that the pass tracks control dependencies for reads/writes on the same
|
||||||
|
// resource.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @non_aliasing_reads_writes
|
||||||
|
func @non_aliasing_reads_writes(
|
||||||
|
// expected-remark@above {{ID: 13}}
|
||||||
|
// expected-remark@above {{Predecessors: {12}}}
|
||||||
|
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||||
|
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||||
|
%arg2: tensor<32xf32>) -> (tensor<32xf32>) {
|
||||||
|
%graph = tf_executor.graph {
|
||||||
|
// expected-remark@above {{ID: 11}}
|
||||||
|
// expected-remark@above {{Predecessors: {10}}}
|
||||||
|
// expected-remark@above {{Successors: {12}}}
|
||||||
|
// CHECK: tf_executor.island
|
||||||
|
%island:2 = tf_executor.island {
|
||||||
|
// expected-remark@above {{ID: 9}}
|
||||||
|
// expected-remark@above {{Predecessors: {8}}}
|
||||||
|
// expected-remark@above {{Successors: {10}}}
|
||||||
|
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 0}}
|
||||||
|
// expected-remark@above {{Successors: {1}}}
|
||||||
|
"tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||||
|
// expected-remark@above {{ID: 1}}
|
||||||
|
// expected-remark@above {{Predecessors: {0}}}
|
||||||
|
// expected-remark@above {{Successors: {6}}}
|
||||||
|
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 2}}
|
||||||
|
// expected-remark@above {{Successors: {5}}}
|
||||||
|
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||||
|
// expected-remark@above {{ID: 3}}
|
||||||
|
%read2 = "tf.ReadVariableOp"(%var_handle) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 4}}
|
||||||
|
// expected-remark@above {{Successors: {8}}}
|
||||||
|
"tf.AssignVariableOp"(%arg1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||||
|
// expected-remark@above {{ID: 5}}
|
||||||
|
// expected-remark@above {{Predecessors: {2}}}
|
||||||
|
// expected-remark@above {{Successors: {8}}}
|
||||||
|
"tf.AssignVariableOp"(%arg0, %read2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||||
|
// expected-remark@above {{ID: 6}}
|
||||||
|
// expected-remark@above {{Predecessors: {1}}}
|
||||||
|
// expected-remark@above {{Successors: {7}}}
|
||||||
|
%read3 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 7}}
|
||||||
|
// expected-remark@above {{Predecessors: {6}}}
|
||||||
|
// expected-remark@above {{Successors: {8}}}
|
||||||
|
tf_executor.yield %read3 : tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 8}}
|
||||||
|
// expected-remark@above {{Predecessors: {4,5,7}}}
|
||||||
|
// expected-remark@above {{Successors: {9}}}
|
||||||
|
}
|
||||||
|
tf_executor.fetch %island#0 : tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 10}}
|
||||||
|
// expected-remark@above {{Predecessors: {9}}}
|
||||||
|
// expected-remark@above {{Successors: {11}}}
|
||||||
|
}
|
||||||
|
return %graph : tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 12}}
|
||||||
|
// expected-remark@above {{Predecessors: {11}}}
|
||||||
|
// expected-remark@above {{Successors: {13}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Tests that the pass tracks control dependencies for reads/writes on the two
|
||||||
|
// resource handles that refer to the same variable.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @aliasing_reads_writes
|
||||||
|
func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () {
|
||||||
|
// expected-remark@above {{ID: 14}}
|
||||||
|
// expected-remark@above {{Predecessors: {13}}}
|
||||||
|
tf_executor.graph {
|
||||||
|
// expected-remark@above {{ID: 12}}
|
||||||
|
// expected-remark@above {{Predecessors: {11}}}
|
||||||
|
// expected-remark@above {{Successors: {13}}}
|
||||||
|
// CHECK: tf_executor.island
|
||||||
|
%island = tf_executor.island {
|
||||||
|
// expected-remark@above {{ID: 10}}
|
||||||
|
// expected-remark@above {{Predecessors: {9}}}
|
||||||
|
// expected-remark@above {{Successors: {11}}}
|
||||||
|
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||||
|
// expected-remark@above {{ID: 0}}
|
||||||
|
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||||
|
// expected-remark@above {{ID: 1}}
|
||||||
|
%vh1_id:2 = "tf.IdentityN"(%vh1, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>)
|
||||||
|
// expected-remark@above {{ID: 2}}
|
||||||
|
%read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 3}}
|
||||||
|
// expected-remark@above {{Successors: {4}}}
|
||||||
|
"tf.AssignVariableOp"(%vh1_id#0, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||||
|
// expected-remark@above {{ID: 4}}
|
||||||
|
// expected-remark@above {{Predecessors: {3}}}
|
||||||
|
// expected-remark@above {{Successors: {5,6}}}
|
||||||
|
%read1 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 5}}
|
||||||
|
// expected-remark@above {{Predecessors: {4}}}
|
||||||
|
// expected-remark@above {{Successors: {7}}}
|
||||||
|
%read2 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 6}}
|
||||||
|
// expected-remark@above {{Predecessors: {4}}}
|
||||||
|
// expected-remark@above {{Successors: {7}}}
|
||||||
|
"tf.AssignVariableOp"(%vh0, %read2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||||
|
// expected-remark@above {{ID: 7}}
|
||||||
|
// expected-remark@above {{Predecessors: {5,6}}}
|
||||||
|
// expected-remark@above {{Successors: {8}}}
|
||||||
|
"tf.AssignVariableOp"(%vh1_id#0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||||
|
// expected-remark@above {{ID: 8}}
|
||||||
|
// expected-remark@above {{Predecessors: {7}}}
|
||||||
|
// expected-remark@above {{Successors: {9}}}
|
||||||
|
tf_executor.yield
|
||||||
|
// expected-remark@above {{ID: 9}}
|
||||||
|
// expected-remark@above {{Predecessors: {8}}}
|
||||||
|
// expected-remark@above {{Successors: {10}}}
|
||||||
|
}
|
||||||
|
tf_executor.fetch %island : !tf_executor.control
|
||||||
|
// expected-remark@above {{ID: 11}}
|
||||||
|
// expected-remark@above {{Predecessors: {10}}}
|
||||||
|
// expected-remark@above {{Successors: {12}}}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
// expected-remark@above {{ID: 13}}
|
||||||
|
// expected-remark@above {{Predecessors: {12}}}
|
||||||
|
// expected-remark@above {{Successors: {14}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Tests that the pass tracks control dependencies for side-effecting ops on
|
||||||
|
// unknown resources.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @unknown_side_effecting_op
|
||||||
|
func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
|
||||||
|
// expected-remark@above {{ID: 13}}
|
||||||
|
// expected-remark@above {{Predecessors: {12}}}
|
||||||
|
tf_executor.graph {
|
||||||
|
// expected-remark@above {{ID: 11}}
|
||||||
|
// expected-remark@above {{Predecessors: {10}}}
|
||||||
|
// expected-remark@above {{Successors: {12}}}
|
||||||
|
// CHECK: tf_executor.island
|
||||||
|
%island = tf_executor.island {
|
||||||
|
// expected-remark@above {{ID: 9}}
|
||||||
|
// expected-remark@above {{Predecessors: {8}}}
|
||||||
|
// expected-remark@above {{Successors: {10}}}
|
||||||
|
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||||
|
// expected-remark@above {{ID: 0}}
|
||||||
|
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||||
|
// expected-remark@above {{ID: 1}}
|
||||||
|
%read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 2}}
|
||||||
|
// expected-remark@above {{Successors: {4}}}
|
||||||
|
"tf.AssignVariableOp"(%vh1, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||||
|
// expected-remark@above {{ID: 3}}
|
||||||
|
// expected-remark@above {{Successors: {4}}}
|
||||||
|
"tf._UnknownSideEffectingOp_"() : () -> ()
|
||||||
|
// expected-remark@above {{ID: 4}}
|
||||||
|
// expected-remark@above {{Predecessors: {2,3}}}
|
||||||
|
// expected-remark@above {{Successors: {5,6}}}
|
||||||
|
%read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 5}}
|
||||||
|
// expected-remark@above {{Predecessors: {4}}}
|
||||||
|
// expected-remark@above {{Successors: {7}}}
|
||||||
|
"tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||||
|
// expected-remark@above {{ID: 6}}
|
||||||
|
// expected-remark@above {{Predecessors: {4}}}
|
||||||
|
// expected-remark@above {{Successors: {8}}}
|
||||||
|
"tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||||
|
// expected-remark@above {{ID: 7}}
|
||||||
|
// expected-remark@above {{Predecessors: {5}}}
|
||||||
|
// expected-remark@above {{Successors: {8}}}
|
||||||
|
tf_executor.yield
|
||||||
|
// expected-remark@above {{ID: 8}}
|
||||||
|
// expected-remark@above {{Predecessors: {6,7}}}
|
||||||
|
// expected-remark@above {{Successors: {9}}}
|
||||||
|
}
|
||||||
|
tf_executor.fetch %island : !tf_executor.control
|
||||||
|
// expected-remark@above {{ID: 10}}
|
||||||
|
// expected-remark@above {{Predecessors: {9}}}
|
||||||
|
// expected-remark@above {{Successors: {11}}}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
// expected-remark@above {{ID: 12}}
|
||||||
|
// expected-remark@above {{Predecessors: {11}}}
|
||||||
|
// expected-remark@above {{Successors: {13}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Tests that the pass tracks control dependencies for read-only ops on unknown
|
||||||
|
// resources.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @read_only_unknown_resource
|
||||||
|
func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () {
|
||||||
|
// expected-remark@above {{ID: 10}}
|
||||||
|
// expected-remark@above {{Predecessors: {9}}}
|
||||||
|
tf_executor.graph {
|
||||||
|
// expected-remark@above {{ID: 8}}
|
||||||
|
// expected-remark@above {{Predecessors: {7}}}
|
||||||
|
// expected-remark@above {{Successors: {9}}}
|
||||||
|
// CHECK: tf_executor.island
|
||||||
|
%island = tf_executor.island {
|
||||||
|
// expected-remark@above {{ID: 6}}
|
||||||
|
// expected-remark@above {{Predecessors: {5}}}
|
||||||
|
// expected-remark@above {{Successors: {7}}}
|
||||||
|
%vh0 = "tf._UnknownSideEffectingOp_"() : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||||
|
// expected-remark@above {{ID: 0}}
|
||||||
|
// expected-remark@above {{Successors: {2,3}}}
|
||||||
|
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||||
|
// expected-remark@above {{ID: 1}}
|
||||||
|
%read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 2}}
|
||||||
|
// expected-remark@above {{Predecessors: {0}}}
|
||||||
|
// expected-remark@above {{Successors: {4}}}
|
||||||
|
%read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||||
|
// expected-remark@above {{ID: 3}}
|
||||||
|
// expected-remark@above {{Predecessors: {0}}}
|
||||||
|
// expected-remark@above {{Successors: {4}}}
|
||||||
|
"tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||||
|
// expected-remark@above {{ID: 4}}
|
||||||
|
// expected-remark@above {{Predecessors: {2,3}}}
|
||||||
|
// expected-remark@above {{Successors: {5}}}
|
||||||
|
tf_executor.yield
|
||||||
|
// expected-remark@above {{ID: 5}}
|
||||||
|
// expected-remark@above {{Predecessors: {4}}}
|
||||||
|
// expected-remark@above {{Successors: {6}}}
|
||||||
|
}
|
||||||
|
tf_executor.fetch %island : !tf_executor.control
|
||||||
|
// expected-remark@above {{ID: 7}}
|
||||||
|
// expected-remark@above {{Predecessors: {6}}}
|
||||||
|
// expected-remark@above {{Successors: {8}}}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
// expected-remark@above {{ID: 9}}
|
||||||
|
// expected-remark@above {{Predecessors: {8}}}
|
||||||
|
// expected-remark@above {{Successors: {10}}}
|
||||||
|
}
|
@ -1650,3 +1650,27 @@ func @testSplitSmallSplitDim(%input: tensor<4x8xf32>) {
|
|||||||
%0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x8xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
|
%0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x8xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @testTernaryEinsum(%arg0: tensor<2x3xf32>){
|
||||||
|
// expected-error @+1 {{supports at most two operands}}
|
||||||
|
%0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @testTopKV2WrongInputRank(%input: tensor<f32>, %k: tensor<i32>) {
|
||||||
|
// expected-error @+1 {{op requires input operand to have at least 1 dimension}}
|
||||||
|
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<f32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) {
|
||||||
|
// expected-error @+1 {{op requires k operand to be 0D tensor}}
|
||||||
|
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -523,7 +523,7 @@ func @invalid_merge(%arg0: tensor<*x!tf.resource>, %arg1: tensor<4x!tf.resource>
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Check that if result is a ref type, all operands need to be ref too.
|
// Check that if result is a ref type, all operands need to be ref too.
|
||||||
func @inavlid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> {
|
func @invalid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> {
|
||||||
%result = tf_executor.graph {
|
%result = tf_executor.graph {
|
||||||
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
|
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
|
||||||
// expected-error@-1 {{'tf_executor.Merge' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
|
// expected-error@-1 {{'tf_executor.Merge' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
|
||||||
|
@ -68,14 +68,17 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
|
|||||||
namespace TF {
|
namespace TF {
|
||||||
|
|
||||||
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
|
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
|
||||||
bool enable_logging) {
|
bool enable_logging,
|
||||||
|
bool enable_inliner) {
|
||||||
PassManager bridge(module.getContext());
|
PassManager bridge(module.getContext());
|
||||||
|
|
||||||
// Add logger to bridge passmanager.
|
// Add logger to bridge passmanager.
|
||||||
if (enable_logging)
|
if (enable_logging)
|
||||||
bridge.addInstrumentation(std::make_unique<tensorflow::BridgeLogger>());
|
bridge.addInstrumentation(std::make_unique<tensorflow::BridgeLogger>());
|
||||||
|
|
||||||
CreateTFStandardPipeline(bridge);
|
StandardPipelineOptions pipeline_options;
|
||||||
|
pipeline_options.enable_inliner.setValue(enable_inliner);
|
||||||
|
CreateTFStandardPipeline(bridge, pipeline_options);
|
||||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||||
LogicalResult result = bridge.run(module);
|
LogicalResult result = bridge.run(module);
|
||||||
(void)result;
|
(void)result;
|
||||||
|
@ -31,11 +31,13 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging);
|
|||||||
|
|
||||||
namespace TF {
|
namespace TF {
|
||||||
|
|
||||||
// Run all passes involved in transforming or optimizing an MLIR graph without
|
// Runs all passes involved in transforming or optimizing an MLIR graph without
|
||||||
// any target specialization. When enable_logging is true, enables
|
// any target specialization. When enable_logging is true, enables
|
||||||
// tensorflow::BridgeLogger.
|
// tensorflow::BridgeLogger. When enable_inliner is true, enables the inliner
|
||||||
|
// pass.
|
||||||
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
|
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
|
||||||
bool enable_logging);
|
bool enable_logging,
|
||||||
|
bool enable_inliner);
|
||||||
|
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
|
|
||||||
|
@ -29,10 +29,4 @@ mlir::PassPipelineRegistration<> tpu_pipeline(
|
|||||||
"that it is suitable for targeting TPUs.",
|
"that it is suitable for targeting TPUs.",
|
||||||
mlir::TFTPU::CreateTPUBridge);
|
mlir::TFTPU::CreateTPUBridge);
|
||||||
|
|
||||||
mlir::PassPipelineRegistration<> standard_pipeline(
|
|
||||||
"tf-standard-bridge",
|
|
||||||
"Run all passes involved in transforming or optimizing an MLIR graph"
|
|
||||||
"without any target specialization.",
|
|
||||||
mlir::TF::CreateTFStandardPipeline);
|
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
@ -22,6 +22,9 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
|||||||
def SingleResultAndOperandHaveSameElementType : Constraint<
|
def SingleResultAndOperandHaveSameElementType : Constraint<
|
||||||
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
|
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
|
||||||
|
|
||||||
|
def SingleResultAndOperandHaveSameType : Constraint<
|
||||||
|
CPred<"$0->getType() == $1->getType()">>;
|
||||||
|
|
||||||
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;
|
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -75,8 +78,7 @@ def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
|
|||||||
|
|
||||||
def CastSameType : Pat<(TF_CastOp:$res $arg, $truncate),
|
def CastSameType : Pat<(TF_CastOp:$res $arg, $truncate),
|
||||||
(replaceWithValue $arg),
|
(replaceWithValue $arg),
|
||||||
[(SingleResultAndOperandHaveSameElementType $res,
|
[(SingleResultAndOperandHaveSameType $res, $arg)]>;
|
||||||
$arg)]>;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Conj op patterns.
|
// Conj op patterns.
|
||||||
|
@ -45,7 +45,8 @@ struct TFOptimizePass : public FunctionPass<TFOptimizePass> {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
|
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
|
||||||
void CreateTFStandardPipeline(OpPassManager &pm) {
|
void CreateTFStandardPipeline(OpPassManager &pm,
|
||||||
|
const StandardPipelineOptions &options) {
|
||||||
OpPassManager &func_pm = pm.nest<FuncOp>();
|
OpPassManager &func_pm = pm.nest<FuncOp>();
|
||||||
|
|
||||||
// First operates on the executor dialect:
|
// First operates on the executor dialect:
|
||||||
@ -59,8 +60,12 @@ void CreateTFStandardPipeline(OpPassManager &pm) {
|
|||||||
// Hopefully there is a single island left, or there wasn't any to begin with.
|
// Hopefully there is a single island left, or there wasn't any to begin with.
|
||||||
// We now run the optimizer which operates mostly inside islands.
|
// We now run the optimizer which operates mostly inside islands.
|
||||||
func_pm.addPass(createCanonicalizerPass());
|
func_pm.addPass(createCanonicalizerPass());
|
||||||
func_pm.addPass(CreateTFOptimizePass());
|
if (options.enable_inliner) {
|
||||||
func_pm.addPass(createCSEPass());
|
pm.addPass(createInlinerPass());
|
||||||
|
}
|
||||||
|
pm.addNestedPass<FuncOp>(CreateTFShapeInferencePass());
|
||||||
|
pm.addNestedPass<FuncOp>(CreateTFOptimizePass());
|
||||||
|
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass() {
|
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass() {
|
||||||
@ -70,7 +75,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass() {
|
|||||||
static PassRegistration<TFOptimizePass> pass("tf-optimize", "Optimizes TF.");
|
static PassRegistration<TFOptimizePass> pass("tf-optimize", "Optimizes TF.");
|
||||||
|
|
||||||
// Registers a pipeline builder function for the default canonicalize/optimizer.
|
// Registers a pipeline builder function for the default canonicalize/optimizer.
|
||||||
static mlir::PassPipelineRegistration<> pipeline(
|
static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
|
||||||
"tf-standard-pipeline",
|
"tf-standard-pipeline",
|
||||||
"Run all the passes involved in transforming/optimizing the graph after "
|
"Run all the passes involved in transforming/optimizing the graph after "
|
||||||
"importing into MLIR, without any target specialization.",
|
"importing into MLIR, without any target specialization.",
|
||||||
|
@ -46,10 +46,17 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateTFShapeInferencePass();
|
|||||||
// Optimizes Tensorflow graph.
|
// Optimizes Tensorflow graph.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass();
|
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass();
|
||||||
|
|
||||||
|
// Options consumed by CreateTFStandardPipeline and exposed on the
// "tf-standard-pipeline" pass pipeline registration.
struct StandardPipelineOptions : public PassOptions<StandardPipelineOptions> {
  // "enable-inliner": when true, CreateTFStandardPipeline adds the inliner
  // pass to the pipeline. Defaults to false.
  Option<bool> enable_inliner{
      *this, "enable-inliner", llvm::cl::desc("Enable inliner."),
      llvm::cl::init(false)};
};
|
||||||
|
|
||||||
// Populates the pass manager with the passes involved in transforming or
|
// Populates the pass manager with the passes involved in transforming or
|
||||||
// optimizing an MLIR graph without any target specialization.
|
// optimizing an MLIR graph without any target specialization.
|
||||||
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
|
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
|
||||||
void CreateTFStandardPipeline(OpPassManager& pm);
|
void CreateTFStandardPipeline(OpPassManager& pm,
|
||||||
|
const StandardPipelineOptions& options);
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
|
|
||||||
namespace TFControlFlow {
|
namespace TFControlFlow {
|
||||||
|
@ -0,0 +1,77 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tf_executor {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// A pass that adds "Predecessors" and "Successors" remarks for each op based on
|
||||||
|
// SideEffectAnalysis result. For testing purpose only.
|
||||||
|
struct TestSideEffectAnalysis
|
||||||
|
: public mlir::FunctionPass<TestSideEffectAnalysis> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
int64_t next_id = 0;
|
||||||
|
llvm::SmallDenseMap<Operation*, int64_t, 8> ids;
|
||||||
|
getFunction().walk([&](Operation* op) {
|
||||||
|
ids[op] = next_id++;
|
||||||
|
op->emitRemark("ID: ") << ids[op];
|
||||||
|
});
|
||||||
|
auto join_ids = [&](const llvm::ArrayRef<Operation*> ops) {
|
||||||
|
llvm::SmallVector<std::string, 8> id_vec;
|
||||||
|
id_vec.reserve(ops.size());
|
||||||
|
for (auto op : ops) id_vec.push_back(std::to_string(ids[op]));
|
||||||
|
return llvm::join(id_vec, ",");
|
||||||
|
};
|
||||||
|
auto& analysis = getAnalysis<TF::SideEffectAnalysis>();
|
||||||
|
getFunction().walk([&](Operation* op) {
|
||||||
|
if (!analysis.DirectControlPredecessors(op).empty()) {
|
||||||
|
op->emitRemark("Predecessors: ")
|
||||||
|
<< "{" << join_ids(analysis.DirectControlPredecessors(op)) << "}";
|
||||||
|
}
|
||||||
|
if (!analysis.DirectControlSuccessors(op).empty()) {
|
||||||
|
op->emitRemark("Successors: ")
|
||||||
|
<< "{" << join_ids(analysis.DirectControlSuccessors(op)) << "}";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static mlir::PassRegistration<TestSideEffectAnalysis> pass(
|
||||||
|
"tf-test-side-effect-analysis",
|
||||||
|
"Add remarks based on side-effect analysis result, for testing purpose.");
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
} // namespace tf_executor
|
||||||
|
} // namespace mlir
|
@ -141,7 +141,7 @@ static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
|
|||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
static llvm::cl::list<std::string> cl_pass_list(
|
static llvm::cl::list<std::string> cl_pass_list(
|
||||||
"graph-passes", llvm::cl::value_desc("list"),
|
"graph-passes", llvm::cl::value_desc("list"),
|
||||||
llvm::cl::desc("comma seprarated list of GraphOptimizationPass to run."),
|
llvm::cl::desc("comma separated list of GraphOptimizationPass to run."),
|
||||||
llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory));
|
llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory));
|
||||||
|
|
||||||
class GraphOptByNamePass : public GraphOptPass {
|
class GraphOptByNamePass : public GraphOptPass {
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "absl/strings/escaping.h"
|
#include "absl/strings/escaping.h"
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/numbers.h"
|
||||||
@ -75,6 +76,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/tensor_id.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
@ -500,7 +502,8 @@ Status ImporterBase::GetInputOutputNodes(
|
|||||||
TF_RETURN_IF_ERROR(add_node(input.first));
|
TF_RETURN_IF_ERROR(add_node(input.first));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& output_node_name : specs_.output_arrays) {
|
for (const auto& output : specs_.outputs) {
|
||||||
|
auto output_node_name = std::string(ParseTensorName(output).first);
|
||||||
TF_RETURN_IF_ERROR(add_node(output_node_name));
|
TF_RETURN_IF_ERROR(add_node(output_node_name));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -535,7 +538,7 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
|||||||
auto node_name = node->op_def().name();
|
auto node_name = node->op_def().name();
|
||||||
if (node_name != "Placeholder" && node_name != "LegacyFedInput" &&
|
if (node_name != "Placeholder" && node_name != "LegacyFedInput" &&
|
||||||
node_name != FunctionLibraryDefinition::kArgOp) {
|
node_name != FunctionLibraryDefinition::kArgOp) {
|
||||||
// We do not handle the case where the input node has multple outputs
|
// We do not handle the case where the input node has multiple outputs
|
||||||
if (node->num_outputs() > 1) {
|
if (node->num_outputs() > 1) {
|
||||||
return errors::FailedPrecondition(absl::StrCat(
|
return errors::FailedPrecondition(absl::StrCat(
|
||||||
"Input arrays can only have op with single output. Node op:",
|
"Input arrays can only have op with single output. Node op:",
|
||||||
@ -1588,7 +1591,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
|||||||
llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
|
llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
|
||||||
if (specs.graph_as_function) {
|
if (specs.graph_as_function) {
|
||||||
if (specs.prune_unused_nodes || !specs.inputs.empty() ||
|
if (specs.prune_unused_nodes || !specs.inputs.empty() ||
|
||||||
!specs.output_arrays.empty() || !specs.output_arrays_order.empty())
|
!specs.outputs.empty())
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Pruning of graph is currently unsupported when the main graph is "
|
"Pruning of graph is currently unsupported when the main graph is "
|
||||||
"converted to a function.");
|
"converted to a function.");
|
||||||
@ -1622,7 +1625,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
|||||||
// TODO(prakalps): Refactor to keep attribute strings (tf.entry_function,
|
// TODO(prakalps): Refactor to keep attribute strings (tf.entry_function,
|
||||||
// tf.versions) shared by importer and exporter in a centralized place.
|
// tf.versions) shared by importer and exporter in a centralized place.
|
||||||
// Record the input and output mapping.
|
// Record the input and output mapping.
|
||||||
if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
|
if (!specs.inputs.empty() || !specs.outputs.empty()) {
|
||||||
mlir::Builder b(context);
|
mlir::Builder b(context);
|
||||||
std::string s;
|
std::string s;
|
||||||
llvm::raw_string_ostream ss(s);
|
llvm::raw_string_ostream ss(s);
|
||||||
@ -1632,7 +1635,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
|||||||
",");
|
",");
|
||||||
auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
|
auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
|
||||||
s.clear();
|
s.clear();
|
||||||
mlir::interleave(specs.output_arrays_order, ss, ",");
|
mlir::interleave(specs.outputs, ss, ",");
|
||||||
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
|
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
|
||||||
|
|
||||||
attrs.push_back(b.getNamedAttr("tf.entry_function",
|
attrs.push_back(b.getNamedAttr("tf.entry_function",
|
||||||
@ -1665,9 +1668,13 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
|
|||||||
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
|
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
|
||||||
absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
|
absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
|
||||||
// Finds out all the input nodes and output nodes.
|
// Finds out all the input nodes and output nodes.
|
||||||
if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
|
absl::flat_hash_set<absl::string_view> output_node_names;
|
||||||
|
for (const auto& output_tensor : specs.outputs) {
|
||||||
|
output_node_names.insert(ParseTensorName(output_tensor).node());
|
||||||
|
}
|
||||||
|
if (!specs.inputs.empty() || !specs.outputs.empty()) {
|
||||||
arg_nodes->resize(specs.inputs.size());
|
arg_nodes->resize(specs.inputs.size());
|
||||||
ret_nodes->resize(specs.output_arrays_order.size());
|
ret_nodes->resize(specs.outputs.size());
|
||||||
|
|
||||||
for (Node* n : GetOrderedNodes()) {
|
for (Node* n : GetOrderedNodes()) {
|
||||||
// Handle inputs/arguments.
|
// Handle inputs/arguments.
|
||||||
@ -1677,17 +1684,17 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle outputs/returns.
|
// Handle outputs/returns.
|
||||||
if (specs.output_arrays.find(n->name()) != specs.output_arrays.end()) {
|
if (output_node_names.contains(n->name())) {
|
||||||
for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
|
for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
|
||||||
std::pair<std::string, std::string> name_and_port =
|
std::pair<std::string, std::string> name_and_port =
|
||||||
absl::StrSplit(specs.output_arrays_order[i], ':');
|
absl::StrSplit(specs.outputs[i], ':');
|
||||||
auto name = name_and_port.first;
|
auto name = name_and_port.first;
|
||||||
if (name != n->name()) continue;
|
if (name != n->name()) continue;
|
||||||
int port = 0;
|
int port = 0;
|
||||||
if (!name_and_port.second.empty() &&
|
if (!name_and_port.second.empty() &&
|
||||||
!absl::SimpleAtoi(name_and_port.second, &port)) {
|
!absl::SimpleAtoi(name_and_port.second, &port)) {
|
||||||
return errors::InvalidArgument("Invalid port specification: ",
|
return errors::InvalidArgument("Invalid port specification: ",
|
||||||
specs.output_arrays_order[i]);
|
specs.outputs[i]);
|
||||||
}
|
}
|
||||||
(*ret_nodes)[i] = {n, port};
|
(*ret_nodes)[i] = {n, port};
|
||||||
}
|
}
|
||||||
@ -1726,10 +1733,10 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
|
|||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||||
ret_types.reserve(specs.output_arrays.size());
|
ret_types.reserve(specs.outputs.size());
|
||||||
for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
|
for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
|
||||||
if (ret_nodes->at(i).node == nullptr) {
|
if (ret_nodes->at(i).node == nullptr) {
|
||||||
return errors::InvalidArgument("Output ", specs.output_arrays_order[i],
|
return errors::InvalidArgument("Output ", specs.outputs[i],
|
||||||
" was not found in graph");
|
" was not found in graph");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -33,19 +33,16 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
Status ParseOutputArrayInfo(absl::string_view array_names,
|
Status ParseOutputArrayInfo(absl::string_view array_names,
|
||||||
absl::flat_hash_set<string>* array,
|
std::vector<string>* outputs) {
|
||||||
std::vector<string>* order) {
|
|
||||||
std::vector<string> output_names = absl::StrSplit(array_names, ',');
|
std::vector<string> output_names = absl::StrSplit(array_names, ',');
|
||||||
return ParseOutputArrayInfo(output_names, array, order);
|
return ParseOutputArrayInfo(output_names, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ParseOutputArrayInfo(const std::vector<string>& output_names,
|
Status ParseOutputArrayInfo(const std::vector<string>& output_names,
|
||||||
absl::flat_hash_set<string>* array,
|
std::vector<string>* outputs) {
|
||||||
std::vector<string>* order) {
|
|
||||||
for (auto& output_name : output_names) {
|
for (auto& output_name : output_names) {
|
||||||
if (output_name.empty()) continue;
|
if (output_name.empty()) continue;
|
||||||
array->insert(string(*absl::StrSplit(output_name, ':').begin()));
|
outputs->push_back(output_name);
|
||||||
order->push_back(output_name);
|
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -40,11 +40,9 @@ struct GraphImportConfig {
|
|||||||
llvm::MapVector<string, ArrayInfo, llvm::StringMap<unsigned>>;
|
llvm::MapVector<string, ArrayInfo, llvm::StringMap<unsigned>>;
|
||||||
// Maps input node names to node data types and shapes.
|
// Maps input node names to node data types and shapes.
|
||||||
InputArrays inputs;
|
InputArrays inputs;
|
||||||
// Output node names.
|
// name:index strings for the output as specified on the command line.
|
||||||
absl::flat_hash_set<string> output_arrays;
|
std::vector<string> outputs;
|
||||||
// nodes:index strings for the output as specified on the command line.
|
// Setting prune_unused_nodes to true, would prune unreachable nodes if
|
||||||
std::vector<string> output_arrays_order;
|
|
||||||
// setting prune_unused_nodes to true, would prune unreachable nodes if
|
|
||||||
// output_arrays is specified.
|
// `outputs` is specified.
|
||||||
bool prune_unused_nodes = false;
|
bool prune_unused_nodes = false;
|
||||||
// If true, inputs of type LegacyFedInput are replaced with Placeholder ops.
|
// If true, inputs of type LegacyFedInput are replaced with Placeholder ops.
|
||||||
@ -73,12 +71,10 @@ struct GraphExportConfig {
|
|||||||
// Parses the command line flag strings to the specification of nodes in
|
// Parses the command line flag strings to the specification of nodes in
|
||||||
// the Graph.
|
// the Graph.
|
||||||
Status ParseOutputArrayInfo(absl::string_view array_names,
|
Status ParseOutputArrayInfo(absl::string_view array_names,
|
||||||
absl::flat_hash_set<string>* array,
|
std::vector<string>* outputs);
|
||||||
std::vector<string>* order);
|
|
||||||
|
|
||||||
Status ParseOutputArrayInfo(const std::vector<string>& output_names,
|
Status ParseOutputArrayInfo(const std::vector<string>& output_names,
|
||||||
absl::flat_hash_set<string>* array,
|
std::vector<string>* outputs);
|
||||||
std::vector<string>* order);
|
|
||||||
|
|
||||||
// Parses the command line flag strings to the specification of nodes in
|
// Parses the command line flag strings to the specification of nodes in
|
||||||
// the Graph. `data_types` input string can be empty since the flag is optional.
|
// the Graph. `data_types` input string can be empty since the flag is optional.
|
||||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
|
#include "tensorflow/core/graph/tensor_id.h"
|
||||||
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||||
@ -63,16 +64,18 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
|||||||
specs.upgrade_legacy = upgrade_legacy;
|
specs.upgrade_legacy = upgrade_legacy;
|
||||||
TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
|
TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
|
||||||
input_shapes, &specs.inputs));
|
input_shapes, &specs.inputs));
|
||||||
TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.output_arrays,
|
TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs));
|
||||||
&specs.output_arrays_order));
|
|
||||||
// TODO(b/142828368): Pruning should not be needed when TF import
|
// TODO(b/142828368): Pruning should not be needed when TF import
|
||||||
// supports importing graphs w/ unregistered ops natively.
|
// supports importing graphs w/ unregistered ops natively.
|
||||||
GraphDef pruned_graph_def;
|
GraphDef pruned_graph_def;
|
||||||
if (specs.prune_unused_nodes) {
|
if (specs.prune_unused_nodes) {
|
||||||
std::vector<string> terminal_nodes(specs.output_arrays.begin(),
|
std::vector<std::string> terminal_nodes;
|
||||||
specs.output_arrays.end());
|
terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size());
|
||||||
for (const auto entry : specs.inputs) {
|
for (const auto& output : specs.outputs) {
|
||||||
terminal_nodes.push_back(entry.first);
|
terminal_nodes.push_back(std::string(ParseTensorName(output).node()));
|
||||||
|
}
|
||||||
|
for (const auto& input : specs.inputs) {
|
||||||
|
terminal_nodes.push_back(input.first);
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
|
TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
|
||||||
graphdef, &pruned_graph_def, terminal_nodes));
|
graphdef, &pruned_graph_def, terminal_nodes));
|
||||||
|
@ -36,7 +36,7 @@ xla::StatusOr<xla::Shape> TestShapeRepresentation(const TensorShape& shape,
|
|||||||
return xla_shape;
|
return xla_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CompileSerializedMlirToXlaHloTest, InvalidSerliazedMlirModule) {
|
TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
|
||||||
string invalid_mlir_module = "totally @invalid MLIR module {here} <-";
|
string invalid_mlir_module = "totally @invalid MLIR module {here} <-";
|
||||||
std::vector<TensorShape> arg_shapes;
|
std::vector<TensorShape> arg_shapes;
|
||||||
XlaCompiler::CompilationResult compilation_result;
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
@ -101,7 +101,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
|
|||||||
xla::ShapeUtil::MakeTupleShape({output_shape});
|
xla::ShapeUtil::MakeTupleShape({output_shape});
|
||||||
EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);
|
EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);
|
||||||
|
|
||||||
// Expect exactly 1 OutputDescrpition.
|
// Expect exactly 1 OutputDescription.
|
||||||
EXPECT_EQ(compilation_result.outputs.size(), 1);
|
EXPECT_EQ(compilation_result.outputs.size(), 1);
|
||||||
const XlaCompiler::OutputDescription& output_desc =
|
const XlaCompiler::OutputDescription& output_desc =
|
||||||
compilation_result.outputs.front();
|
compilation_result.outputs.front();
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Support/ToolUtilities.h" // TF:local_config_mlir
|
||||||
#include "mlir/Support/TranslateClParser.h" // TF:local_config_mlir
|
#include "mlir/Support/TranslateClParser.h" // TF:local_config_mlir
|
||||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||||
@ -40,6 +41,13 @@ static llvm::cl::opt<std::string> output_filename(
|
|||||||
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
||||||
llvm::cl::init("-"));
|
llvm::cl::init("-"));
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
static llvm::cl::opt<bool> splitInputFile(
|
||||||
|
"split-input-file",
|
||||||
|
llvm::cl::desc("Split the input file into pieces and process each chunk "
|
||||||
|
"independently"),
|
||||||
|
llvm::cl::init(false));
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
static llvm::cl::opt<bool> import_saved_model(
|
static llvm::cl::opt<bool> import_saved_model(
|
||||||
"savedmodel-to-mlir",
|
"savedmodel-to-mlir",
|
||||||
@ -85,13 +93,12 @@ int main(int argc, char** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
|
||||||
|
|
||||||
if (import_saved_model) {
|
if (import_saved_model) {
|
||||||
std::unordered_set<std::string> tags =
|
std::unordered_set<std::string> tags =
|
||||||
absl::StrSplit(saved_model_tags, ',');
|
absl::StrSplit(saved_model_tags, ',');
|
||||||
std::vector<std::string> exported_names =
|
std::vector<std::string> exported_names =
|
||||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
|
||||||
auto module = tensorflow::SavedModelToMlirImport(
|
auto module = tensorflow::SavedModelToMlirImport(
|
||||||
input_filename, tags, absl::Span<std::string>(exported_names),
|
input_filename, tags, absl::Span<std::string>(exported_names),
|
||||||
@ -107,12 +114,23 @@ int main(int argc, char** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SourceMgr source_mgr;
|
// Processes the memory buffer with a new MLIRContext.
|
||||||
source_mgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
|
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
|
||||||
mlir::SourceMgrDiagnosticHandler diagnostic_handler(source_mgr, &context);
|
llvm::raw_ostream& os) {
|
||||||
|
llvm::SourceMgr sourceMgr;
|
||||||
|
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context);
|
||||||
|
return (*requested_translation)(sourceMgr, os, &context);
|
||||||
|
};
|
||||||
|
|
||||||
if (failed((*requested_translation)(source_mgr, output->os(), &context)))
|
if (splitInputFile) {
|
||||||
|
if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer,
|
||||||
|
output->os())))
|
||||||
return 1;
|
return 1;
|
||||||
|
} else {
|
||||||
|
if (failed(processBuffer(std::move(input), output->os()))) return 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
output->keep();
|
output->keep();
|
||||||
|
@ -404,6 +404,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:matrix",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"@llvm//:support",
|
"@llvm//:support",
|
||||||
"@local_config_mlir//:Analysis",
|
"@local_config_mlir//:Analysis",
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
||||||
@ -82,3 +83,4 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
// Converts the given elements attr to the specified elements type.
|
// Converts the given elements attr to the specified elements type.
|
||||||
@ -27,5 +28,6 @@ namespace xla {
|
|||||||
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
||||||
mlir::Type new_type);
|
mlir::Type new_type);
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
|
||||||
|
@ -47,13 +47,11 @@ limitations under the License.
|
|||||||
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
|
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
|
||||||
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
|
||||||
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
|
||||||
} // namespace mlir
|
namespace xla_hlo {
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace mlir::xla_hlo;
|
|
||||||
|
|
||||||
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
|
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
|
||||||
Attribute value, Type type,
|
Attribute value, Type type,
|
||||||
@ -160,7 +158,7 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) {
|
|||||||
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
|
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
|
||||||
value.isa<IntegerAttr>()) {
|
value.isa<IntegerAttr>()) {
|
||||||
// All XLA types must be tensor types. In the build() method, we want to
|
// All XLA types must be tensor types. In the build() method, we want to
|
||||||
// provide more flexiblity by allowing attributes of scalar types. But we
|
// provide more flexibility by allowing attributes of scalar types. But we
|
||||||
// need to wrap it up with ElementsAttr to construct valid XLA constants.
|
// need to wrap it up with ElementsAttr to construct valid XLA constants.
|
||||||
type = RankedTensorType::get(/*shape=*/{}, value.getType());
|
type = RankedTensorType::get(/*shape=*/{}, value.getType());
|
||||||
value = DenseElementsAttr::get(type.cast<TensorType>(), value);
|
value = DenseElementsAttr::get(type.cast<TensorType>(), value);
|
||||||
@ -212,9 +210,9 @@ void AbsOp::build(Builder* builder, OperationState& result, Value* operand) {
|
|||||||
new_type = operand->getType();
|
new_type = operand->getType();
|
||||||
} else if (shaped_type.hasRank()) {
|
} else if (shaped_type.hasRank()) {
|
||||||
new_type =
|
new_type =
|
||||||
mlir::RankedTensorType::get(shaped_type.getShape(), operand->getType());
|
RankedTensorType::get(shaped_type.getShape(), operand->getType());
|
||||||
} else {
|
} else {
|
||||||
new_type = mlir::UnrankedTensorType::get(operand->getType());
|
new_type = UnrankedTensorType::get(operand->getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
return AbsOp::build(builder, result, new_type, operand);
|
return AbsOp::build(builder, result, new_type, operand);
|
||||||
@ -241,7 +239,7 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
|
|
||||||
// If the operand is constant, we can do the conversion now.
|
// If the operand is constant, we can do the conversion now.
|
||||||
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
|
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
|
||||||
return ::xla::ConvertElementsAttr(elementsAttr,
|
return xla::ConvertElementsAttr(elementsAttr,
|
||||||
getElementTypeOrSelf(getResult()));
|
getElementTypeOrSelf(getResult()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -436,7 +434,7 @@ static LogicalResult Verify(ClampOp op) {
|
|||||||
void ComplexOp::build(Builder* builder, OperationState& state, Value* lhs,
|
void ComplexOp::build(Builder* builder, OperationState& state, Value* lhs,
|
||||||
Value* rhs) {
|
Value* rhs) {
|
||||||
auto type = lhs->getType();
|
auto type = lhs->getType();
|
||||||
auto element_ty = mlir::ComplexType::get(getElementTypeOrSelf(type));
|
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
|
||||||
Type result_ty;
|
Type result_ty;
|
||||||
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
||||||
result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
|
result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
|
||||||
@ -843,6 +841,70 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
|
|||||||
return RankedTensorType::get(shape, ranked_ty.getElementType());
|
return RankedTensorType::get(shape, ranked_ty.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SortOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Populates `state` for a SortOp over `operands`.
//
// Arguments:
//   builder:   used to create the attributes and the tuple result type.
//   state:     operation state that receives operands, attributes, the
//              result type and one region.
//   operands:  the values to sort; their types, in order, become the element
//              types of the single tuple result.
//   dimension: stored as the 64-bit integer attribute "dimension".
//   is_stable: stored as the boolean attribute "is_stable".
void SortOp::build(Builder* builder, OperationState& state,
                   ArrayRef<Value*> operands, int64_t dimension,
                   bool is_stable) {
  state.addOperands(operands);
  state.addAttribute("dimension", builder->getI64IntegerAttr(dimension));
  // The stability flag comes from `is_stable`, not from `dimension`: using
  // `dimension` would implicitly convert it to bool and mark every sort along
  // a non-zero dimension as stable regardless of what the caller asked for.
  state.addAttribute("is_stable", builder->getBoolAttr(is_stable));

  // The op produces one tuple holding a sorted value per operand.
  SmallVector<Type, 2> element_types;
  element_types.reserve(operands.size());
  for (Value* operand : operands) element_types.push_back(operand->getType());
  state.addTypes(builder->getTupleType(element_types));

  // Adds the op's single region, left empty here for the caller to fill in
  // (presumably the comparator body -- confirm against the op definition).
  state.addRegion();
}
|
||||||
|
|
||||||
|
// Verifies the structural invariants of a SortOp:
//   * there is at least one operand;
//   * if every operand is ranked, all operands share the same shape and the
//     `dimension` attribute is less than that rank;
//   * the first block of the comparator region takes exactly two arguments
//     per operand, and arguments 2*i and 2*i+1 are rank-0 tensors of operand
//     i's element type.
// Emits an op error and returns failure on the first violated invariant.
static LogicalResult Verify(SortOp op) {
  Operation::operand_range inputs = op.operands();
  if (inputs.empty()) return op.emitOpError("requires at least one input");

  // Every operand type is accessed through its ShapedType view.
  auto shaped_type_of = [](Value* value) {
    return value->getType().cast<ShapedType>();
  };

  // TODO(antiagainst): verify partially dynamic shapes
  const bool all_ranked = llvm::all_of(
      inputs, [&](Value* input) { return shaped_type_of(input).hasRank(); });
  if (all_ranked) {
    ArrayRef<int64_t> first_shape = shaped_type_of(*inputs.begin()).getShape();

    // All remaining operands must match the first operand's shape exactly.
    for (Value* other : llvm::drop_begin(inputs, 1)) {
      if (shaped_type_of(other).getShape() != first_shape)
        return op.emitOpError(
            "requires all inputs to have the same dimensions");
    }

    if (op.dimension().getSExtValue() >= first_shape.size())
      return op.emitOpError(
          "dimension attribute value must be less than input rank");
  }

  Block& comparator_block = op.comparator().front();
  size_t expected_num_args = 2 * op.getOperation()->getNumOperands();
  if (comparator_block.getNumArguments() != expected_num_args)
    return op.emitOpError("comparator block should have ")
           << expected_num_args << " arguments";

  // Walk the block arguments in pairs, one pair per operand.
  int arg_index = 0;
  for (Value* input : inputs) {
    Type expected_type =
        RankedTensorType::get({}, shaped_type_of(input).getElementType());
    for (int pair_end = arg_index + 2; arg_index != pair_end; ++arg_index) {
      Type actual_type = comparator_block.getArgument(arg_index)->getType();
      if (actual_type != expected_type)
        return op.emitOpError("comparator block argument #")
               << arg_index << " should be of type " << expected_type
               << " but got " << actual_type;
    }
  }

  return success();
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TransposeOp
|
// TransposeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -938,6 +1000,15 @@ void TupleOp::build(Builder* builder, OperationState& result,
|
|||||||
build(builder, result, builder->getTupleType(types), values);
|
build(builder, result, builder->getTupleType(types), values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// UnaryEinsumOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Registers the canonicalization patterns for UnaryEinsumOp by inserting the
// UnaryEinsumToEinsum pattern, constructed with `context`, into `results`.
void UnaryEinsumOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<UnaryEinsumToEinsum>(context);
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// CompareOp
|
// CompareOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -990,3 +1061,6 @@ XlaHloDialect::XlaHloDialect(MLIRContext* context)
|
|||||||
// Support unknown operations because not all XLA operations are registered.
|
// Support unknown operations because not all XLA operations are registered.
|
||||||
// allowUnknownOperations();
|
// allowUnknownOperations();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
||||||
|
@ -437,9 +437,6 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO
|
|||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"Builder *builder, OperationState &results, "
|
"Builder *builder, OperationState &results, "
|
||||||
"Value* value, int32_t index">];
|
"Value* value, int32_t index">];
|
||||||
|
|
||||||
// GetTupleElementOp has special conversion logic to HLO.
|
|
||||||
let hasCustomHLOConverter = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
|
def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
|
||||||
@ -730,6 +727,43 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral
|
|||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def BASE_EinsumOp {
|
||||||
|
string summary = "Einsum operator";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns a tensor whose elements are defined by equation, which is written
|
||||||
|
in a shorthand form inspired by the Einstein summation convention.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]> {
|
||||||
|
let arguments = (ins
|
||||||
|
HLO_Tensor:$lhs,
|
||||||
|
HLO_Tensor:$rhs,
|
||||||
|
StrAttr:$einsum_config
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs HLO_Tensor);
|
||||||
|
|
||||||
|
// TODO(hinsu): Canonicalize to lower this client side HLO op to server
|
||||||
|
// side HLO ops.
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]> {
|
||||||
|
let arguments = (ins
|
||||||
|
HLO_Tensor:$operand,
|
||||||
|
StrAttr:$einsum_config
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs HLO_Tensor);
|
||||||
|
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
|
||||||
|
// UnaryEinsumOp is unconditionally canonicalized to the binary EinsumOp so
|
||||||
|
// the HLO converter shouldn't be invoked.
|
||||||
|
let hasCustomHLOConverter = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
|
def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$operand,
|
HLO_Tensor:$operand,
|
||||||
@ -834,6 +868,26 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
|
|||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Variadic<HLO_Tensor>:$operands,
|
||||||
|
DefaultValuedAttr<I64Attr, "-1">:$dimension,
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$is_stable
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs HLO_TensorOrTuple);
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$comparator);
|
||||||
|
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"Builder *builder, OperationState &state, ArrayRef<Value *> operands, "
|
||||||
|
"int64_t dimension, bool is_stable"
|
||||||
|
>];
|
||||||
|
|
||||||
|
// TODO(b/129422361): SortOp has special conversion logic to HLO.
|
||||||
|
let hasCustomHLOConverter = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def HLO_ReverseOp: HLO_Op<"reverse",
|
def HLO_ReverseOp: HLO_Op<"reverse",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
@ -708,7 +708,7 @@ class BASE_HLO_ClampOp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class BASE_HLO_ConcatenateOp {
|
class BASE_HLO_ConcatenateOp {
|
||||||
string summary = "XLA's concantenate op";
|
string summary = "XLA's concatenate op";
|
||||||
|
|
||||||
string description = [{
|
string description = [{
|
||||||
Concatenates a set of tensors along the specified dimension.
|
Concatenates a set of tensors along the specified dimension.
|
||||||
@ -832,6 +832,17 @@ class BASE_HLO_SelectAndScatterOp {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BASE_HLO_SortOp {
|
||||||
|
string summary = "Sort operator";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Sorts the given `operands` at the given `dimension` with the given
|
||||||
|
`comparator`.
|
||||||
|
|
||||||
|
See https://www.tensorflow.org/xla/operation_semantics#sort.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
class BASE_HLO_ReverseOp {
|
class BASE_HLO_ReverseOp {
|
||||||
string summary = "Reverse operator";
|
string summary = "Reverse operator";
|
||||||
|
|
||||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -51,5 +53,18 @@ DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x,
|
|||||||
return DenseIntElementsAttr::get(type, broadcastDimensions);
|
return DenseIntElementsAttr::get(type, broadcastDimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
|
||||||
|
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
|
||||||
|
|
||||||
|
DenseElementsAttr attr;
|
||||||
|
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
||||||
|
APFloat value(float_ty.getFloatSemantics(), raw_value);
|
||||||
|
return DenseElementsAttr::get(scalar_ty, value);
|
||||||
|
}
|
||||||
|
auto int_ty = ty.cast<IntegerType>();
|
||||||
|
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
|
||||||
|
return DenseElementsAttr::get(scalar_ty, value);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||||
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
||||||
@ -48,6 +49,12 @@ static ElementsAttr getSplat(Builder* b, Value* val, T constant) {
|
|||||||
|
|
||||||
return DenseElementsAttr::get(valType, elementAttr);
|
return DenseElementsAttr::get(valType, elementAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns DenseElementsAttr of rank zero with the given element type and the
|
||||||
|
// value.
|
||||||
|
// Requires `ty` to be either FloatType or IntegerType.
|
||||||
|
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
@ -18,20 +18,23 @@ limitations under the License.
|
|||||||
#ifndef HLO_UTILS
|
#ifndef HLO_UTILS
|
||||||
#define HLO_UTILS
|
#define HLO_UTILS
|
||||||
|
|
||||||
#ifndef OP_BASE
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
#endif // OP_BASE
|
|
||||||
|
|
||||||
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
|
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
|
||||||
|
|
||||||
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
|
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
|
||||||
|
|
||||||
class ConstantSplat<string value> : NativeCodeCall<
|
class ConstantSplat<string value> : NativeCodeCall<
|
||||||
"getSplat(&$_builder, $0, " # value # ")">;
|
"xla::getSplat(&$_builder, $0, " # value # ")">;
|
||||||
|
|
||||||
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||||
|
|
||||||
def BinBroadcastDimensions : NativeCodeCall<
|
def BinBroadcastDimensions : NativeCodeCall<
|
||||||
"getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
|
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
|
||||||
|
|
||||||
|
// Here, the element type can be any integer or float type. But, note that only
|
||||||
|
// 32 bit integers are supported for the value.
|
||||||
|
class GetScalarOfType<int value> : NativeCodeCall<
|
||||||
|
"xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
|
||||||
|
|
||||||
#endif // HLO_UTILS
|
#endif // HLO_UTILS
|
||||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
@ -77,6 +78,10 @@ static double ConvertAPFloat(llvm::APFloat value) {
|
|||||||
return value.convertToDouble();
|
return value.convertToDouble();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static absl::string_view ConvertStringRef(mlir::StringRef value) {
|
||||||
|
return {value.data(), value.size()};
|
||||||
|
}
|
||||||
|
|
||||||
static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
|
static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
|
||||||
auto values = attr.getValues<int64>();
|
auto values = attr.getValues<int64>();
|
||||||
return {values.begin(), values.end()};
|
return {values.begin(), values.end()};
|
||||||
@ -494,13 +499,6 @@ LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ExportXlaOp(GetTupleElementOp op, OpLoweringContext ctx) {
|
|
||||||
auto& value_map = *ctx.values;
|
|
||||||
value_map[op] = xla::GetTupleElement(value_map[op.getOperand()],
|
|
||||||
op.index().getSExtValue());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
|
LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
|
||||||
auto& value_map = *ctx.values;
|
auto& value_map = *ctx.values;
|
||||||
value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
|
value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
|
||||||
@ -626,12 +624,30 @@ LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
|
||||||
|
xla::XlaComputation comparator;
|
||||||
|
if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
|
||||||
|
&comparator)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto& value_map = *ctx.values;
|
||||||
|
value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator,
|
||||||
|
op.dimension().getSExtValue(), op.is_stable());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
|
LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
|
||||||
auto& value_map = *ctx.values;
|
auto& value_map = *ctx.values;
|
||||||
value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx));
|
value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
|
||||||
|
// Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two
|
||||||
|
// operands.
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
|
LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
|
||||||
xla::XlaComputation condition;
|
xla::XlaComputation condition;
|
||||||
xla::XlaComputation body;
|
xla::XlaComputation body;
|
||||||
@ -773,7 +789,7 @@ LogicalResult ConvertToHloModule::LowerFunctionCall(
|
|||||||
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
|
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
|
||||||
if (lowered_computation_.count(f)) return success();
|
if (lowered_computation_.count(f)) return success();
|
||||||
if (f.getBlocks().size() != 1) {
|
if (f.getBlocks().size() != 1) {
|
||||||
return f.emitError("only single block Function suppored");
|
return f.emitError("only single block Function supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a sub-builder if this is not the main function.
|
// Create a sub-builder if this is not the main function.
|
||||||
|
@ -32,17 +32,20 @@ using llvm::raw_ostream;
|
|||||||
using llvm::RecordKeeper;
|
using llvm::RecordKeeper;
|
||||||
using llvm::StringRef;
|
using llvm::StringRef;
|
||||||
using mlir::interleaveComma;
|
using mlir::interleaveComma;
|
||||||
|
using mlir::tblgen::Attribute;
|
||||||
using mlir::tblgen::NamedAttribute;
|
using mlir::tblgen::NamedAttribute;
|
||||||
using mlir::tblgen::NamedTypeConstraint;
|
using mlir::tblgen::NamedTypeConstraint;
|
||||||
using mlir::tblgen::Operator;
|
using mlir::tblgen::Operator;
|
||||||
|
|
||||||
static std::string GetDefaultAttrExport(
|
static std::string GetDefaultAttrExport(
|
||||||
const mlir::tblgen::NamedAttribute& named_attr) {
|
const mlir::tblgen::NamedAttribute& named_attr) {
|
||||||
auto storage_type = named_attr.attr.getStorageType();
|
Attribute attr = named_attr.attr;
|
||||||
|
StringRef storage_type = attr.getStorageType();
|
||||||
// For some attribute types we have a general conversion, so use that.
|
// For some attribute types we have a general conversion, so use that.
|
||||||
if (storage_type.endswith("IntegerAttr") ||
|
if (!attr.isEnumAttr() && (storage_type.endswith("IntegerAttr") ||
|
||||||
storage_type.endswith("FloatAttr")) {
|
storage_type.endswith("FloatAttr") ||
|
||||||
return "Convert" + named_attr.attr.getReturnType().str();
|
storage_type.endswith("StringAttr"))) {
|
||||||
|
return "Convert" + attr.getReturnType().str();
|
||||||
}
|
}
|
||||||
return "Convert_" + named_attr.name.str();
|
return "Convert_" + named_attr.name.str();
|
||||||
}
|
}
|
||||||
|
@ -48,3 +48,11 @@ func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f
|
|||||||
// CHECK: return %arg0
|
// CHECK: return %arg0
|
||||||
return %2 : tensor<4xcomplex<f32>>
|
return %2 : tensor<4xcomplex<f32>>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @unary_einsum
|
||||||
|
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||||
|
// CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
|
||||||
|
// CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
|
||||||
|
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
|
||||||
|
return %0 : tensor<2x2xf32>
|
||||||
|
}
|
||||||
|
@ -180,6 +180,27 @@ func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
|
|||||||
return %0: tensor<?xi1>
|
return %0: tensor<?xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @bitwise_or
|
||||||
|
func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
// CHECK-NEXT: xla_hlo.or
|
||||||
|
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
return %0: tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @bitwise_or_broadcast
|
||||||
|
func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
|
||||||
|
// CHECK-NEXT: xla_hlo.or
|
||||||
|
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
|
||||||
|
return %0: tensor<1x4xi8>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @bitwise_or_dynamic
|
||||||
|
func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
|
||||||
|
// CHECK-NEXT: xla_hlo.or
|
||||||
|
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||||
|
return %0: tensor<?xi32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @pow
|
// CHECK-LABEL: func @pow
|
||||||
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
// CHECK-NEXT: xla_hlo.pow
|
// CHECK-NEXT: xla_hlo.pow
|
||||||
@ -194,6 +215,20 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
|||||||
return %0: tensor<?xf32>
|
return %0: tensor<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @einsum
|
||||||
|
func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
|
||||||
|
// CHECK: xla_hlo.einsum
|
||||||
|
%0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
|
||||||
|
return %0: tensor<2x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @unary_einsum
|
||||||
|
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||||
|
// CHECK: xla_hlo.unary_einsum
|
||||||
|
%0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
|
||||||
|
return %0: tensor<2x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @floordiv_broadcast_i32
|
// CHECK-LABEL: func @floordiv_broadcast_i32
|
||||||
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
|
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
|
||||||
// CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0>
|
// CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0>
|
||||||
@ -589,6 +624,17 @@ func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
|
|||||||
return %1 : tensor<6x9xf32>
|
return %1 : tensor<6x9xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @padv2_i32_paddings
|
||||||
|
func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
|
||||||
|
%padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32>
|
||||||
|
// CHECK: "xla_hlo.pad"(%arg0, %arg1) {
|
||||||
|
// CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
|
||||||
|
// CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
|
||||||
|
// CHECK-SAME: interior_padding = dense<0> : tensor<2xi64>
|
||||||
|
%1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<6x9xf32>
|
||||||
|
return %1 : tensor<6x9xf32>
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Identity op legalizations.
|
// Identity op legalizations.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1888,3 +1934,42 @@ func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf
|
|||||||
// CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
|
// CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
|
||||||
return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
|
return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// tf.TopKV2 legalization
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// CHECK-LABEL: topk_v2_non_const_k
|
||||||
|
func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
|
||||||
|
// CHECK: tf.TopKV2
|
||||||
|
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
|
||||||
|
return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: topk_v2_unknown_input_last_dim
|
||||||
|
func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
|
||||||
|
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
// CHECK: tf.TopKV2
|
||||||
|
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
|
||||||
|
return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: topk_v2
|
||||||
|
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
|
||||||
|
func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
|
||||||
|
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
|
||||||
|
// CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64}
|
||||||
|
// CHECK-NEXT: %[[SORT:.*]] = "xla_hlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
|
||||||
|
// CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
|
||||||
|
// CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"}
|
||||||
|
// CHECK-NEXT: "xla_hlo.return"(%[[CMP]])
|
||||||
|
// CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||||
|
// CHECK-NEXT: %[[TUPL0:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32}
|
||||||
|
// CHECK-NEXT: %[[TUPL1:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32}
|
||||||
|
// CHECK-NEXT: %[[VAL:.*]] = "xla_hlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||||
|
// CHECK-NEXT: %[[IDX:.*]] = "xla_hlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||||
|
// CHECK-NEXT: return %[[VAL]], %[[IDX]]
|
||||||
|
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
|
||||||
|
return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
|
||||||
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s
|
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s
|
||||||
|
|
||||||
#map0 = (d0, d1) -> (d0, d1)
|
#map0 = (d0, d1) -> (d0, d1)
|
||||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], n_loop_types = [2, 0, 0], n_views = [2, 1]}
|
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"], n_views = [2, 1]}
|
||||||
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
||||||
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
%temp_result = alloc() {temp = true} : memref<2x2xf32>
|
%temp_result = alloc() {temp = true} : memref<2x2xf32>
|
||||||
|
@ -416,3 +416,98 @@ func @constants() -> () {
|
|||||||
%3 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor<i32>} : () -> (tensor<*xi32>)
|
%3 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor<i32>} : () -> (tensor<*xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||||
|
// CHECK: xla_hlo.sort
|
||||||
|
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||||
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
|
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @sort_no_operands() {
|
||||||
|
// expected-error @+1 {{op requires at least one input}}
|
||||||
|
%0 = "xla_hlo.sort"() ( {
|
||||||
|
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
|
||||||
|
%7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
|
}) {dimension = 1 : i64, is_stable = true} : () -> tuple<>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
||||||
|
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||||
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
|
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
||||||
|
// expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}}
|
||||||
|
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
|
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
|
||||||
|
// expected-error @+1 {{op requires all inputs to have the same dimensions}}
|
||||||
|
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||||
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
|
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||||
|
// expected-error @+1 {{op dimension attribute value must be less than input rank}}
|
||||||
|
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||||
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
|
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
|
}) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||||
|
// expected-error @+1 {{op comparator block should have 4 arguments}}
|
||||||
|
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||||
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
|
||||||
|
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||||
|
// expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}}
|
||||||
|
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||||
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
|
||||||
|
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -1,26 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
|
|
||||||
%0 = "xla_hlo.all_reduce"(%arg0) ({
|
|
||||||
// Perform max reduction inside the region
|
|
||||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
|
|
||||||
%max = xla_hlo.max %lhs, %rhs : tensor<f32>
|
|
||||||
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
|
|
||||||
})
|
|
||||||
{
|
|
||||||
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
|
|
||||||
channel_id = {
|
|
||||||
handle = 5 : i64,
|
|
||||||
type = 2 : i64
|
|
||||||
}
|
|
||||||
} : (tensor<10xf32>) -> tensor<10xf32>
|
|
||||||
return %0 : tensor<10xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
|
|
||||||
// CHECK-LABEL: ENTRY
|
|
||||||
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
|
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
|
|
||||||
// CHECK-SAME: channel_id=5
|
|
||||||
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
|
|
||||||
// CHECK-SAME: to_apply=%[[COMPUTATION]]
|
|
@ -1,15 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
|
|
||||||
%0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
|
||||||
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: ENTRY
|
|
||||||
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
|
|
||||||
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
|
|
||||||
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
|
|
||||||
// CHECK: [[VAL_4:%.*]] = f32[2] parameter(3)
|
|
||||||
// CHECK: [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4)
|
|
||||||
// CHECK-LABEL: ROOT
|
|
||||||
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0
|
|
@ -1,13 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
|
|
||||||
%0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
|
||||||
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: ENTRY
|
|
||||||
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
|
|
||||||
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
|
|
||||||
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
|
|
||||||
// CHECK-LABEL: ROOT
|
|
||||||
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3
|
|
@ -1,23 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
module {
|
|
||||||
func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
|
||||||
// CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
|
|
||||||
// CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
|
|
||||||
// CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
|
||||||
%0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
|
|
||||||
|
|
||||||
// CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
|
||||||
%1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
|
|
||||||
|
|
||||||
// CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
|
||||||
%2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
|
|
||||||
|
|
||||||
// CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
|
||||||
%3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
|
|
||||||
|
|
||||||
// CHECK-LABEL: ROOT
|
|
||||||
// CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
|
|
||||||
return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,26 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
// CHECK-LABEL: ENTRY %main.13 (Arg_0.1: s32[1,4], Arg_1.2: s32[2,4], Arg_2.3: s32[2,3,4]) -> s32[2,3,4] {
|
|
||||||
func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
|
|
||||||
// Same rank degenerate broadcast
|
|
||||||
// CHECK-NEXT: %Arg_0.1 = s32[1,4] parameter(0)
|
|
||||||
// CHECK-NEXT: %reshape.4 = s32[4] reshape(s32[1,4] %Arg_0.1)
|
|
||||||
// CHECK-NEXT: %broadcast.5 = s32[2,4] broadcast(s32[4] %reshape.4)
|
|
||||||
// CHECK-NEXT: %Arg_1.2 = s32[2,4] parameter(1)
|
|
||||||
// CHECK-NEXT: %add.6 = s32[2,4] add(s32[2,4] %broadcast.5, s32[2,4] %Arg_1.2)
|
|
||||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
|
||||||
|
|
||||||
// Broadcast up rank
|
|
||||||
// CHECK-NEXT: %broadcast.7 = s32[2,3,4] broadcast(s32[2,4] %Arg_1.2), dimensions={0,2}
|
|
||||||
// CHECK-NEXT: %Arg_2.3 = s32[2,3,4] parameter(2)
|
|
||||||
// CHECK-NEXT: %add.8 = s32[2,3,4] add(s32[2,3,4] %broadcast.7, s32[2,3,4] %Arg_2.3)
|
|
||||||
%1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
|
|
||||||
|
|
||||||
// Broadcast up rank + degenerate broadcast
|
|
||||||
// CHECK-NEXT: %broadcast.9 = s32[2,1,4] broadcast(s32[1,4] %Arg_0.1), dimensions={1,2}
|
|
||||||
// CHECK-NEXT: %reshape.10 = s32[2,4] reshape(s32[2,1,4] %broadcast.9)
|
|
||||||
// CHECK-NEXT: %broadcast.11 = s32[2,3,4] broadcast(s32[2,4] %reshape.10), dimensions={0,2}
|
|
||||||
// CHECK-NEXT: ROOT %add.12 = s32[2,3,4] add(s32[2,3,4] %broadcast.11, s32[2,3,4] %Arg_2.3)
|
|
||||||
%2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
|
|
||||||
return %2 : tensor<2x3x4xi32>
|
|
||||||
}
|
|
@ -1,9 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
// CHECK-LABEL: ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[1,2,3,4] {
|
|
||||||
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
|
|
||||||
// CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
|
|
||||||
// CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
|
|
||||||
%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
|
|
||||||
return %0 : tensor<1x2x3x4xi32>
|
|
||||||
}
|
|
@ -1,12 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
|
|
||||||
%result = "xla_hlo.broadcast_in_dim"(%arg0) {
|
|
||||||
broadcast_dimensions = dense<0> : tensor<1xi64>
|
|
||||||
} : (tensor<1xf32>) -> tensor<1x10xf32>
|
|
||||||
return %result : tensor<1x10xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK: ENTRY %main.3 ([[ARG0:.*]]: f32[1]) -> f32[1,10] {
|
|
||||||
// CHECK: %[[ARG0]] = f32[1] parameter(0)
|
|
||||||
// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] %[[ARG0]]), dimensions={0}
|
|
@ -1,30 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
|
|
||||||
%0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
|
||||||
%1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
|
||||||
return %1 : tensor<4xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
|
||||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
|
||||||
return %0 : tensor<4xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK: [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s32[4], [[ARG_2:.*]]: s32[4]) -> s32[4] {
|
|
||||||
// CHECK: %[[ARG_1]] = s32[4] parameter(0)
|
|
||||||
// CHECK: %[[ARG_2]] = s32[4] parameter(1)
|
|
||||||
// CHECK-LABEL: ROOT
|
|
||||||
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]])
|
|
||||||
|
|
||||||
// CHECK: [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] {
|
|
||||||
// CHECK: %[[ARG_3]] = s32[4] parameter(0)
|
|
||||||
// CHECK: %[[ARG_4]] = s32[4] parameter(1)
|
|
||||||
// CHECK-LABEL: ROOT
|
|
||||||
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]])
|
|
||||||
|
|
||||||
// CHECK: ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] {
|
|
||||||
// CHECK: %[[ARG]] = s32[4] parameter(0)
|
|
||||||
// CHECK: [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]]
|
|
||||||
// CHECK-LABEL: ROOT
|
|
||||||
// CHECK-SAME: s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]]
|
|
@ -1,24 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
|
|
||||||
%0:2 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>)
|
|
||||||
return %0#0, %0#1 : tensor<4xi32>, tensor<4xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
|
|
||||||
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
|
||||||
%1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
|
||||||
return %0, %1 : tensor<4xi32>, tensor<4xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get name of callee computation
|
|
||||||
// CHECK: [[CALLEE:%.*]] ({{.*}}) -> ({{.*}}) {
|
|
||||||
|
|
||||||
// CHECK-LABEL: ENTRY
|
|
||||||
// CHECK-SAME: [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) {
|
|
||||||
// CHECK: %[[ARG]] = s32[4] parameter(0)
|
|
||||||
// CHECK: [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]]
|
|
||||||
// CHECK: [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0
|
|
||||||
// CHECK: [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1
|
|
||||||
// CHECK-LABEL: ROOT
|
|
||||||
// CHECK-SAME: (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]])
|
|
@ -1,17 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0 : tensor<5x2xf32>,
|
|
||||||
%arg1 : tensor<5x5xf32>,
|
|
||||||
%arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
|
|
||||||
%result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) {
|
|
||||||
dimension = 1 : i64
|
|
||||||
} : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
|
|
||||||
return %result : tensor<5x14xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: main
|
|
||||||
// CHECK: %[[ARG0:.*]] = f32[5,2] parameter(0)
|
|
||||||
// CHECK: %[[ARG1:.*]] = f32[5,5] parameter(1)
|
|
||||||
// CHECK: %[[ARG2:.*]] = f32[5,7] parameter(2)
|
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1}
|
|
@ -42,13 +42,13 @@ func @main(%arg0: tensor<f32>) -> tuple<tensor<f32>> {
|
|||||||
|
|
||||||
// CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]]
|
// CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]]
|
||||||
%2 = "xla_hlo.conditional"(%0, %1, %1) ( {
|
%2 = "xla_hlo.conditional"(%0, %1, %1) ( {
|
||||||
^bb0(%arg1: tuple<tensor<f32>>): // no predecessors
|
^bb0(%arg1: tuple<tensor<f32>>):
|
||||||
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
|
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
|
||||||
%7 = "xla_hlo.log"(%6) : (tensor<f32>) -> tensor<f32>
|
%7 = "xla_hlo.log"(%6) : (tensor<f32>) -> tensor<f32>
|
||||||
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
|
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
|
||||||
"xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
|
"xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
|
||||||
}, {
|
}, {
|
||||||
^bb0(%arg1: tuple<tensor<f32>>): // no predecessors
|
^bb0(%arg1: tuple<tensor<f32>>):
|
||||||
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
|
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
|
||||||
%7 = "xla_hlo.exp"(%6) : (tensor<f32>) -> tensor<f32>
|
%7 = "xla_hlo.exp"(%6) : (tensor<f32>) -> tensor<f32>
|
||||||
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
|
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
|
||||||
|
@ -1,30 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure
|
|
||||||
|
|
||||||
// CHECK-LABEL: ENTRY %main
|
|
||||||
func @main() -> tensor<2x2x1x1xf32> {
|
|
||||||
// CHECK: constant.{{.*}} = s64[] constant(1)
|
|
||||||
%cst = constant dense<1> : tensor<i64>
|
|
||||||
// CHECK: constant.{{.*}} = f32[2,2,1,1]
|
|
||||||
// CHECK-SAME: { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } }
|
|
||||||
%cst_0 = constant dense<
|
|
||||||
[[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]]
|
|
||||||
> : tensor<2x2x1x1xf32>
|
|
||||||
|
|
||||||
// CHECK: s32[1] constant({1})
|
|
||||||
%cst_1 = constant dense<1> : tensor<1xi32>
|
|
||||||
|
|
||||||
// CHECK: %[[C:.*]] = s32[] constant(1)
|
|
||||||
// CHECK: s32[10] broadcast(s32[] %[[C]])
|
|
||||||
%cst_2 = constant dense<1> : tensor<10xi32>
|
|
||||||
|
|
||||||
// CHECK: s32[4] constant({1, 2, 3, 4})
|
|
||||||
%cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
|
|
||||||
|
|
||||||
// CHECK: s32[2,2] constant({ { 1, 2 }, { 3, 4 } })
|
|
||||||
%cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
|
|
||||||
|
|
||||||
// CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
|
|
||||||
%cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
|
|
||||||
|
|
||||||
return %cst_0 : tensor<2x2x1x1xf32>
|
|
||||||
}
|
|
@ -1,31 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
|
|
||||||
%result = "xla_hlo.conv"(%arg0, %arg1) {
|
|
||||||
batch_group_count = 1 : i64,
|
|
||||||
dimension_numbers = {
|
|
||||||
input_batch_dimension = 0 : i64,
|
|
||||||
input_feature_dimension = 3 : i64,
|
|
||||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
|
||||||
kernel_input_feature_dimension = 3 : i64,
|
|
||||||
kernel_output_feature_dimension = 2 : i64,
|
|
||||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
|
||||||
output_batch_dimension = 0 : i64,
|
|
||||||
output_feature_dimension = 3 : i64,
|
|
||||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
|
||||||
},
|
|
||||||
feature_group_count = 1 : i64,
|
|
||||||
lhs_dilation = dense<1> : tensor<2xi64>,
|
|
||||||
padding = dense<2> : tensor<2x2xi64>,
|
|
||||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
|
||||||
window_strides = dense<1> : tensor<2xi64>
|
|
||||||
} : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32>
|
|
||||||
return %result : tensor<100x28x28x1xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: main
|
|
||||||
// CHECK: %[[ARG0:.*]] = f32[100,26,26,32] parameter(0)
|
|
||||||
// CHECK: %[[ARG1:.*]] = f32[3,3,1,32] parameter(1)
|
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]),
|
|
||||||
// CHECK-SAME: window={size=3x3 pad=2_2x2_2},
|
|
||||||
// CHECK-SAME: dim_labels=b01f_01oi->b01f
|
|
@ -1,10 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
|
|
||||||
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
|
|
||||||
return %0 : tensor<2xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK: ENTRY %main
|
|
||||||
// CHECK: %[[ARG:.*]] = s32[2] parameter(0)
|
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]])
|
|
@ -1,10 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
|
||||||
%0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
|
|
||||||
return %0 : tensor<2xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK: ENTRY %main
|
|
||||||
// CHECK: [[ARG:%.*]] = s32[2] parameter(0)
|
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]])
|
|
@ -1,16 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
|
|
||||||
%0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
|
|
||||||
%1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
|
|
||||||
return %1 : tensor<10xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK: %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[]
|
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
|
|
||||||
|
|
||||||
// CHECK: ENTRY %main
|
|
||||||
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
|
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
|
|
||||||
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
|
|
||||||
// CHECK-SAME: to_apply=%[[SUM_COMPUTATION]]
|
|
640
tensorflow/compiler/mlir/xla/tests/translate/export.mlir
Normal file
640
tensorflow/compiler/mlir/xla/tests/translate/export.mlir
Normal file
@ -0,0 +1,640 @@
|
|||||||
|
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
|
||||||
|
%0 = "xla_hlo.all_reduce"(%arg0) ({
|
||||||
|
// Perform max reduction inside the region
|
||||||
|
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
|
||||||
|
%max = xla_hlo.max %lhs, %rhs : tensor<f32>
|
||||||
|
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
|
||||||
|
})
|
||||||
|
{
|
||||||
|
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
|
||||||
|
channel_id = {
|
||||||
|
handle = 5 : i64,
|
||||||
|
type = 2 : i64
|
||||||
|
}
|
||||||
|
} : (tensor<10xf32>) -> tensor<10xf32>
|
||||||
|
return %0 : tensor<10xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
|
||||||
|
// CHECK-SAME: channel_id=5
|
||||||
|
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
|
||||||
|
// CHECK-SAME: to_apply=%[[COMPUTATION]]
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
|
||||||
|
%0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||||
|
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
|
||||||
|
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
|
||||||
|
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
|
||||||
|
// CHECK: [[VAL_4:%.*]] = f32[2] parameter(3)
|
||||||
|
// CHECK: [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4)
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
|
||||||
|
%0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||||
|
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
|
||||||
|
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
|
||||||
|
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||||
|
// CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
|
||||||
|
// CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
|
||||||
|
// CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||||
|
%0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
|
||||||
|
|
||||||
|
// CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||||
|
%1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
|
||||||
|
|
||||||
|
// CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||||
|
%2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
|
||||||
|
|
||||||
|
// CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||||
|
%3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
|
||||||
|
return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
|
||||||
|
// Same rank degenerate broadcast
|
||||||
|
// CHECK: [[ARG_0:%.*]] = s32[1,4] parameter(0)
|
||||||
|
// CHECK-NEXT: [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]])
|
||||||
|
// CHECK-NEXT: [[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]])
|
||||||
|
// CHECK-NEXT: [[ARG_1:%.*]] = s32[2,4] parameter(1)
|
||||||
|
// CHECK-NEXT: s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]])
|
||||||
|
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
||||||
|
|
||||||
|
// Broadcast up rank
|
||||||
|
// CHECK-NEXT: [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2}
|
||||||
|
// CHECK-NEXT: [[ARG_2:%.*]] = s32[2,3,4] parameter(2)
|
||||||
|
// CHECK-NEXT: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]])
|
||||||
|
%1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
|
||||||
|
|
||||||
|
// Broadcast up rank + degenerate broadcast
|
||||||
|
// CHECK-NEXT: [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2}
|
||||||
|
// CHECK-NEXT: [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]])
|
||||||
|
// CHECK-NEXT: [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2}
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]])
|
||||||
|
%2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
|
||||||
|
return %2 : tensor<2x3x4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
|
||||||
|
// CHECK: [[ARG:%.*]] = s32[4] parameter(0)
|
||||||
|
// CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] [[ARG]]), dimensions={3}
|
||||||
|
%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
|
||||||
|
return %0 : tensor<1x2x3x4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
|
||||||
|
%result = "xla_hlo.broadcast_in_dim"(%arg0) {
|
||||||
|
broadcast_dimensions = dense<0> : tensor<1xi64>
|
||||||
|
} : (tensor<1xf32>) -> tensor<1x10xf32>
|
||||||
|
return %result : tensor<1x10xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: [[ARG:%.*]] = f32[1] parameter(0)
|
||||||
|
// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] [[ARG]]), dimensions={0}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
%1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
return %1 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||||
|
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
return %0 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s32[4], [[ARG_2:.*]]: s32[4]) -> s32[4] {
|
||||||
|
// CHECK: %[[ARG_1]] = s32[4] parameter(0)
|
||||||
|
// CHECK: %[[ARG_2]] = s32[4] parameter(1)
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]])
|
||||||
|
|
||||||
|
// CHECK: [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] {
|
||||||
|
// CHECK: %[[ARG_3]] = s32[4] parameter(0)
|
||||||
|
// CHECK: %[[ARG_4]] = s32[4] parameter(1)
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]])
|
||||||
|
|
||||||
|
// CHECK: ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] {
|
||||||
|
// CHECK: %[[ARG]] = s32[4] parameter(0)
|
||||||
|
// CHECK: [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]]
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]]
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
|
||||||
|
%0:2 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>)
|
||||||
|
return %0#0, %0#1 : tensor<4xi32>, tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
|
||||||
|
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
%1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
return %0, %1 : tensor<4xi32>, tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get name of callee computation
|
||||||
|
// CHECK: [[CALLEE:%.*]] ({{.*}}) -> ({{.*}}) {
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK-SAME: [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) {
|
||||||
|
// CHECK: %[[ARG]] = s32[4] parameter(0)
|
||||||
|
// CHECK: [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]]
|
||||||
|
// CHECK: [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0
|
||||||
|
// CHECK: [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]])
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0 : tensor<5x2xf32>,
|
||||||
|
%arg1 : tensor<5x5xf32>,
|
||||||
|
%arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
|
||||||
|
%result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) {
|
||||||
|
dimension = 1 : i64
|
||||||
|
} : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
|
||||||
|
return %result : tensor<5x14xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: %[[ARG0:.*]] = f32[5,2] parameter(0)
|
||||||
|
// CHECK: %[[ARG1:.*]] = f32[5,5] parameter(1)
|
||||||
|
// CHECK: %[[ARG2:.*]] = f32[5,7] parameter(2)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main() -> tensor<2x2x1x1xf32> {
|
||||||
|
// CHECK: constant.{{.*}} = s64[] constant(1)
|
||||||
|
%cst = constant dense<1> : tensor<i64>
|
||||||
|
// CHECK: constant.{{.*}} = f32[2,2,1,1]
|
||||||
|
// CHECK-SAME: { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } }
|
||||||
|
%cst_0 = constant dense<
|
||||||
|
[[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]]
|
||||||
|
> : tensor<2x2x1x1xf32>
|
||||||
|
|
||||||
|
// CHECK: s32[1] constant({1})
|
||||||
|
%cst_1 = constant dense<1> : tensor<1xi32>
|
||||||
|
|
||||||
|
// CHECK: %[[C:.*]] = s32[] constant(1)
|
||||||
|
// CHECK: s32[10] broadcast(s32[] %[[C]])
|
||||||
|
%cst_2 = constant dense<1> : tensor<10xi32>
|
||||||
|
|
||||||
|
// CHECK: s32[4] constant({1, 2, 3, 4})
|
||||||
|
%cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
|
||||||
|
|
||||||
|
// CHECK: s32[2,2] constant({ { 1, 2 }, { 3, 4 } })
|
||||||
|
%cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
|
||||||
|
|
||||||
|
// CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
|
||||||
|
%cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
|
||||||
|
|
||||||
|
return %cst_0 : tensor<2x2x1x1xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
|
||||||
|
%result = "xla_hlo.conv"(%arg0, %arg1) {
|
||||||
|
batch_group_count = 1 : i64,
|
||||||
|
dimension_numbers = {
|
||||||
|
input_batch_dimension = 0 : i64,
|
||||||
|
input_feature_dimension = 3 : i64,
|
||||||
|
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
kernel_input_feature_dimension = 3 : i64,
|
||||||
|
kernel_output_feature_dimension = 2 : i64,
|
||||||
|
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
output_batch_dimension = 0 : i64,
|
||||||
|
output_feature_dimension = 3 : i64,
|
||||||
|
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||||
|
},
|
||||||
|
feature_group_count = 1 : i64,
|
||||||
|
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
|
padding = dense<2> : tensor<2x2xi64>,
|
||||||
|
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
|
window_strides = dense<1> : tensor<2xi64>
|
||||||
|
} : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32>
|
||||||
|
return %result : tensor<100x28x28x1xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: %[[ARG0:.*]] = f32[100,26,26,32] parameter(0)
|
||||||
|
// CHECK: %[[ARG1:.*]] = f32[3,3,1,32] parameter(1)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]),
|
||||||
|
// CHECK-SAME: window={size=3x3 pad=2_2x2_2},
|
||||||
|
// CHECK-SAME: dim_labels=b01f_01oi->b01f
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
|
||||||
|
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
|
||||||
|
return %0 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: %[[ARG:.*]] = s32[2] parameter(0)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]])
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||||
|
%0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
|
||||||
|
return %0 : tensor<2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: [[ARG:%.*]] = s32[2] parameter(0)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]])
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
|
||||||
|
%0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
|
||||||
|
%1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
|
||||||
|
return %1 : tensor<10xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[]
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
|
||||||
|
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
|
||||||
|
// CHECK-SAME: to_apply=%[[SUM_COMPUTATION]]
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
|
||||||
|
// Simple einsum is lowered to HLO dot op.
|
||||||
|
// CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||||
|
%0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
|
||||||
|
return %0 : tensor<3x5xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
|
||||||
|
%0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
|
||||||
|
// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
|
||||||
|
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: %[[ARG0:.*]] = (f32[], s32[]) parameter(0)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main() -> tensor<1x10xf32> {
|
||||||
|
%result = "xla_hlo.iota"() {
|
||||||
|
iota_dimension = 1 : i64
|
||||||
|
} : () -> tensor<1x10xf32>
|
||||||
|
return %result : tensor<1x10xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[1,10] iota(), iota_dimension=1
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
|
||||||
|
%0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
|
||||||
|
return %0 : tensor<13x19xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: [[ARG:%.*]] = f32[4,6] parameter(0)
|
||||||
|
// CHECK: [[PADDING_VAL:%.*]] = f32[] parameter(1)
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
|
||||||
|
%result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
|
||||||
|
^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
|
||||||
|
%fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
%imax = "xla_hlo.max"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||||
|
"xla_hlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
|
||||||
|
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
|
||||||
|
return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[REGION:region_[0-9]+]]
|
||||||
|
// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[])
|
||||||
|
// CHECK: %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]])
|
||||||
|
// CHECK: %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]])
|
||||||
|
// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]])
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG1:.*]]: s32[1,10], [[ARG2:.*]]: f32[], [[ARG3:.*]]: s32[]) -> (f32[1], s32[1])
|
||||||
|
// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %[[ARG0]], s32[1,10] %[[ARG1]], f32[] %[[ARG2]], s32[] %[[ARG3]]), dimensions={1}, to_apply=%[[REGION]]
|
||||||
|
// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0
|
||||||
|
// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]])
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> {
|
||||||
|
%0 = xla_hlo.constant dense<-2147483648> : tensor<i32>
|
||||||
|
%1 = "xla_hlo.reduce_window"(%arg0, %0) ( {
|
||||||
|
^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
|
||||||
|
%2 = xla_hlo.max %arg1, %arg2 : tensor<i32>
|
||||||
|
"xla_hlo.return"(%2) : (tensor<i32>) -> ()
|
||||||
|
}) {
|
||||||
|
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
|
||||||
|
window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>,
|
||||||
|
padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>,
|
||||||
|
base_dilations = dense<[1, 1, 1, 1]> : tensor<4xi64>,
|
||||||
|
window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>
|
||||||
|
} : (tensor<2x17x31x7xi32>, tensor<i32>) -> tensor<2x3x5x7xi32>
|
||||||
|
return %1 : tensor<2x3x5x7xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[MAX_COMPUTATION:.*]] ([[ARG0:.*]]: s32[], [[ARG1:.*]]: s32[]) -> s32[]
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = s32[] maximum(s32[] %[[ARG0]], s32[] %[[ARG1]])
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK-DAG: %[[ARG0:.*]] = s32[2,17,31,7] parameter(0)
|
||||||
|
// CHECK-DAG: %[[INIT:.*]] = s32[] constant(-2147483648)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = s32[2,5,8,7] reduce-window(s32[2,17,31,7] %[[ARG0]], s32[] %constant.2),
|
||||||
|
// CHECK-SAME: window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1},
|
||||||
|
// CHECK-SAME: to_apply=%[[MAX_COMPUTATION]]
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> {
|
||||||
|
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32>
|
||||||
|
return %0 : tensor<1x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: %[[ARG0:.*]] = f32[2] parameter(0)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[1,2] reshape(f32[2] %[[ARG0]])
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
|
||||||
|
%result = "xla_hlo.reverse"(%arg0) {
|
||||||
|
dimensions = dense<[1,2]> : tensor<2xi64>
|
||||||
|
} : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32>
|
||||||
|
return %result : tensor<10x11x12x13xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: %[[ARG0:.*]] = f32[10,11,12,13] parameter(0)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[10,11,12,13] reverse(f32[10,11,12,13] %[[ARG0]]), dimensions={1,2}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main() -> tensor<2x3x5xf32> {
|
||||||
|
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
%1 = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
|
||||||
|
%2 = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
|
||||||
|
%3 = "xla_hlo.rng_uniform"(%0, %1, %2) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
|
||||||
|
return %3 : tensor<2x3x5xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK-DAG: %[[A:.*]] = f32[] constant(0)
|
||||||
|
// CHECK-DAG: %[[B:.*]] = f32[] constant(1)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[A]], f32[] %[[B]]), distribution=rng_uniform
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
|
||||||
|
%0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
|
||||||
|
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
|
||||||
|
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
|
||||||
|
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
|
||||||
|
}) {
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
update_window_dims = dense<[1]> : tensor<1xi64>,
|
||||||
|
inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
index_vector_dim = 1 : i64
|
||||||
|
},
|
||||||
|
indices_are_sorted = true,
|
||||||
|
unique_indices = true
|
||||||
|
} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32>
|
||||||
|
return %0 : tensor<200x100x300xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: [[COMPUTATION:%.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: [[VAL_1:%.*]] = f32[200,100,300] parameter(0)
|
||||||
|
// CHECK: [[VAL_2:%.*]] = s32[10,2] parameter(1)
|
||||||
|
// CHECK: [[VAL_3:%.*]] = f32[10,300] parameter(2)
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[10,2] [[VAL_2]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]]
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||||
|
// CHECK: %[[ARG0:.*]] = pred[] parameter(0)
|
||||||
|
// CHECK: %[[COND:.*]] = pred[2,3] broadcast(pred[] %[[ARG0]]), dimensions={}
|
||||||
|
// CHECK: %[[ARG1:.*]] = s32[2,3] parameter(1)
|
||||||
|
// CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2)
|
||||||
|
|
||||||
|
// CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]])
|
||||||
|
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||||
|
return %0 : tensor<2x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
|
||||||
|
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
%1 = "xla_hlo.select_and_scatter"(%arg0, %arg1, %0) ( {
|
||||||
|
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
|
||||||
|
%2 = "xla_hlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"xla_hlo.return"(%2) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
|
||||||
|
%2 = xla_hlo.add %arg3, %arg4 : tensor<f32>
|
||||||
|
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
|
||||||
|
}) {
|
||||||
|
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
|
||||||
|
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
|
||||||
|
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
|
||||||
|
return %1 : tensor<10x24x24x64xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[SELECT_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] {
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE
|
||||||
|
|
||||||
|
// CHECK: %[[SCATTER_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] {
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: %[[ARG0:.*]] = f32[10,24,24,64] parameter(0)
|
||||||
|
// CHECK: %[[ARG1:.*]] = f32[10,12,12,64] parameter(1)
|
||||||
|
// CHECK: %[[INIT:.*]] = f32[] constant(0)
|
||||||
|
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = f32[10,24,24,64]
|
||||||
|
// CHECK-SAME: select-and-scatter(f32[10,24,24,64] %[[ARG0]], f32[10,12,12,64] %[[ARG1]], f32[] %[[INIT]]),
|
||||||
|
// CHECK-SAME: window={size=1x2x2x1 stride=1x2x2x1},
|
||||||
|
// CHECK-SAME: select=%[[SELECT_COMPUTATION]], scatter=%[[SCATTER_COMPUTATION]]
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||||
|
%0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
|
||||||
|
return %0 : tensor<1x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0)
|
||||||
|
// CHECK-LABEL: ROOT
|
||||||
|
// CHECK-SAME: s32[1,2] slice(s32[3,4] [[ARG]]), slice={[1:2:1], [0:4:2]}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||||
|
// CHECK: [[ARG:%.*]] = s32[1,2,3,4] parameter(0)
|
||||||
|
|
||||||
|
// CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] [[ARG]]), dimensions={1,0,3,2}
|
||||||
|
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||||
|
return %0 : tensor<2x1x4x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<f32>, %arg1 : tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> {
|
||||||
|
%result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
|
||||||
|
return %result : tuple<tensor<f32>, tensor<i32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ENTRY
|
||||||
|
// CHECK: %[[ARG0:.*]] = f32[] parameter(0)
|
||||||
|
// CHECK: %[[ARG1:.*]] = s32[] parameter(1)
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = (f32[], s32[]) tuple(f32[] %[[ARG0]], s32[] %[[ARG1]])
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) {
|
||||||
|
// CHECK: [[ARG_F32:%.*]] = f32[4] parameter(0)
|
||||||
|
// CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]])
|
||||||
|
%expm1 = "xla_hlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
|
||||||
|
// CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]])
|
||||||
|
%log1p = "xla_hlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
|
||||||
|
// CHECK: [[ARG_I32:%.*]] = s32[4] parameter(1)
|
||||||
|
// CHECK: [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]])
|
||||||
|
%not = "xla_hlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
|
||||||
|
// CHECK: [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]])
|
||||||
|
%popcnt = "xla_hlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
|
||||||
|
return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
||||||
|
// CHECK: [[VAL_1:%.*]] = pred[4] parameter(0)
|
||||||
|
// CHECK: [[VAL_2:%.*]] = pred[4] parameter(1)
|
||||||
|
%0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1>
|
||||||
|
// CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
|
||||||
|
return %0 : tensor<4xi1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: HloModule
|
||||||
|
func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||||
|
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||||
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
|
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] {
|
||||||
|
// CHECK: ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
|
||||||
|
|
||||||
|
// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) {
|
||||||
|
// CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
|
@ -1,10 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
|
|
||||||
%0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
|
|
||||||
return %0 : tensor<i32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: ENTRY
|
|
||||||
// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
|
|
||||||
// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}
|
|
@ -1,10 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
|
|
||||||
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
|
|
||||||
return %0 : tensor<f32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: main
|
|
||||||
// CHECK: %[[ARG0:.*]] = (f32[], s32[]) parameter(0)
|
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0
|
|
@ -1,11 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main() -> tensor<1x10xf32> {
|
|
||||||
%result = "xla_hlo.iota"() {
|
|
||||||
iota_dimension = 1 : i64
|
|
||||||
} : () -> tensor<1x10xf32>
|
|
||||||
return %result : tensor<1x10xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL:main
|
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = f32[1,10] iota(), iota_dimension=1
|
|
@ -1,12 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
|
|
||||||
%0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
|
|
||||||
return %0 : tensor<13x19xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: ENTRY
|
|
||||||
// CHECK: [[ARG:%.*]] = f32[4,6] parameter(0)
|
|
||||||
// CHECK: [[PADDING_VAL:%.*]] = f32[] parameter(1)
|
|
||||||
// CHECK-LABEL: ROOT
|
|
||||||
// CHECK-SAME: f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1
|
|
@ -1,24 +0,0 @@
|
|||||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
|
||||||
|
|
||||||
func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
|
|
||||||
%result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
|
|
||||||
^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
|
|
||||||
%fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
|
||||||
%imax = "xla_hlo.max"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
|
||||||
"xla_hlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
|
|
||||||
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
|
|
||||||
return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK: %[[REGION:region_[0-9]+]]
|
|
||||||
// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[])
|
|
||||||
// CHECK: %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]])
|
|
||||||
// CHECK: %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]])
|
|
||||||
// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]])
|
|
||||||
|
|
||||||
// CHECK: ENTRY %main
|
|
||||||
// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG0:.*]]: s32[1,10], [[ARG0:.*]]: f32[], [[ARG0:.*]]: s32[]) -> (f32[1], s32[1])
|
|
||||||
// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %Arg_0.1, s32[1,10] %Arg_1.2, f32[] %Arg_2.3, s32[] %Arg_3.4), dimensions={1}, to_apply=%[[REGION]]
|
|
||||||
// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0
|
|
||||||
// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1
|
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]])
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user