Merge branch 'master' into fix_minimum_maximum

Elena Zhelezina 2019-12-04 10:11:32 +00:00 committed by GitHub
commit 811c2a08ff
732 changed files with 16571 additions and 8354 deletions

View File

@@ -100,9 +100,9 @@ build --apple_platform_type=macos
# iOS configs for each architecture and the fat binary builds.
build:ios --apple_platform_type=ios
build:ios --apple_bitcode=embedded --copt=-fembed-bitcode
build:ios --copt=-Wno-c++11-narrowing
build:ios_armv7 --config=ios
build:ios_armv7 --cpu=ios_armv7
build:ios_armv7 --copt -Wno-c++11-narrowing
build:ios_arm64 --config=ios
build:ios_arm64 --cpu=ios_arm64
build:ios_i386 --config=ios
@@ -111,7 +111,6 @@ build:ios_x86_64 --config=ios
build:ios_x86_64 --cpu=ios_x86_64
build:ios_fat --config=ios
build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64
build:ios_fat --copt -Wno-c++11-narrowing
# Config to use a mostly-static build and disable modular op registration
# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
@@ -202,18 +201,25 @@ build --define=allow_oversize_protos=true
build --spawn_strategy=standalone
build -c opt
# By default, build TF in C++ 14 mode.
build --cxxopt=-std=c++14
build --host_cxxopt=-std=c++14
# Make Bazel print out all options from rc files.
build --announce_rc
# Other build flags.
build --define=grpc_no_ares=true
# Prevent regression of https://github.com/bazelbuild/bazel/issues/7362
build --incompatible_remove_legacy_whole_archive
# See https://github.com/bazelbuild/bazel/issues/7362 for information on what
# --incompatible_remove_legacy_whole_archive flag does.
# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate
# TensorFlow to the default; however, test coverage wasn't enough to catch the
# errors.
# There is ongoing work on Bazel team's side to provide support for transitive
# shared libraries. As part of migrating to transitive shared libraries, we
# hope to provide a better mechanism for control over symbol exporting, and
# then tackle this issue again.
#
# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library
# archives in -whole_archive -no_whole_archive.
build --noincompatible_remove_legacy_whole_archive
# Modular TF build options
build:dynamic_kernels --define=dynamic_loaded_kernels=true
@@ -224,13 +230,55 @@ build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
# Default paths for TF_SYSTEM_LIBS
build --define=PREFIX=/usr
build --define=LIBDIR=$(PREFIX)/lib
build --define=INCLUDEDIR=$(PREFIX)/include
# Enable using platform specific build settings
build --enable_platform_specific_config
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
build --copt=-w
build:linux --copt=-w
build:macos --copt=-w
build:windows --copt=/w
# Default paths for TF_SYSTEM_LIBS
build:linux --define=PREFIX=/usr
build:linux --define=LIBDIR=$(PREFIX)/lib
build:linux --define=INCLUDEDIR=$(PREFIX)/include
build:macos --define=PREFIX=/usr
build:macos --define=LIBDIR=$(PREFIX)/lib
build:macos --define=INCLUDEDIR=$(PREFIX)/include
# TF_SYSTEM_LIBS do not work on windows.
# By default, build TF in C++ 14 mode.
build:linux --cxxopt=-std=c++14
build:linux --host_cxxopt=-std=c++14
build:macos --cxxopt=-std=c++14
build:macos --host_cxxopt=-std=c++14
build:windows --cxxopt=/std:c++14
build:windows --host_cxxopt=/std:c++14
# On windows, we still link everything into a single DLL.
build:windows --config=monolithic
# Make sure to include as little of windows.h as possible
build:windows --copt=-DWIN32_LEAN_AND_MEAN
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
build:windows --copt=-DNOGDI
build:windows --host_copt=-DNOGDI
# Misc build options we need for windows.
build:windows --linkopt=/DEBUG
build:windows --host_linkopt=/DEBUG
build:windows --linkopt=/OPT:REF
build:windows --host_linkopt=/OPT:REF
build:windows --linkopt=/OPT:ICF
build:windows --host_linkopt=/OPT:ICF
build:windows --experimental_strict_action_env=true
build:windows --incompatible_windows_native_test_wrapper
# Verbose failure logs when something goes wrong
build:windows --verbose_failures
# On windows, we never cross compile
build:windows --distinct_host_configuration=false
# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING
@@ -335,20 +383,6 @@ build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
# Misc build options we need for windows
build:rbe_win --copt=-DWIN32_LEAN_AND_MEAN
build:rbe_win --host_copt=-DWIN32_LEAN_AND_MEAN
build:rbe_win --copt=-DNOGDI
build:rbe_win --host_copt=-DNOGDI
build:rbe_win --linkopt=/DEBUG
build:rbe_win --host_linkopt=/DEBUG
build:rbe_win --linkopt=/OPT:REF
build:rbe_win --host_linkopt=/OPT:REF
build:rbe_win --linkopt=/OPT:ICF
build:rbe_win --host_linkopt=/OPT:ICF
build:rbe_win --config=monolithic
build:rbe_win --experimental_strict_action_env=true
build:rbe_win --incompatible_windows_native_test_wrapper
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
build:rbe_win --define=override_eigen_strong_inline=true

View File

@@ -89,7 +89,7 @@ swift_rules_dependencies()
# files, in case the parsing of those build files depends on the bazel
# version we require here.
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
check_bazel_version_at_least("0.19.0")
check_bazel_version_at_least("1.0.0")
load("//third_party/android:android_configure.bzl", "android_configure")
android_configure(name="local_config_android")

View File

@@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '0.27.1'
_TF_MIN_BAZEL_VERSION = '1.0.0'
_TF_MAX_BAZEL_VERSION = '1.1.0'
NCCL_LIB_PATHS = [
@@ -1232,20 +1232,6 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
def set_windows_build_flags(environ_cp):
"""Set Windows specific build options."""
# The non-monolithic build is not supported yet
write_to_bazelrc('build --config monolithic')
# Suppress warning messages
write_to_bazelrc('build --copt=-w --host_copt=-w')
# Fix winsock2.h conflicts
write_to_bazelrc(
'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN '
'--copt=-DNOGDI --host_copt=-DNOGDI')
# Output more verbose information when something goes wrong
write_to_bazelrc('build --verbose_failures')
# The host and target platforms are the same in Windows build. So we don't
# have to distinct them. This avoids building the same targets twice.
write_to_bazelrc('build --distinct_host_configuration=false')
if is_reduced_optimize_huge_functions_available(environ_cp):
write_to_bazelrc(
'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions'

View File

@@ -2,10 +2,7 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.
load("//tensorflow:tensorflow.bzl", "VERSION")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl")
load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_additional_binary_deps",
@@ -450,6 +447,7 @@ config_setting(
package_group(
name = "internal",
packages = [
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",

View File

@@ -119,11 +119,11 @@ def _running_from_pip_package():
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
if _running_from_pip_package():
for s in _site_packages_dirs:
for _s in _site_packages_dirs:
# TODO(gunan): Add sanity checks to loaded modules here.
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
if _fi.file_exists(plugin_dir):
_ll.load_library(plugin_dir)
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _fi.file_exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# Add module aliases
if hasattr(_current_module, 'keras'):
@@ -136,3 +136,5 @@ if hasattr(_current_module, 'keras'):
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)
# pylint: enable=undefined-variable
# __all__ PLACEHOLDER

View File

@@ -132,9 +132,10 @@ def _running_from_pip_package():
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
if _running_from_pip_package():
for s in _site_packages_dirs:
for _s in _site_packages_dirs:
# TODO(gunan): Add sanity checks to loaded modules here.
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
if _fi.file_exists(plugin_dir):
_ll.load_library(plugin_dir)
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _fi.file_exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# __all__ PLACEHOLDER

View File

@@ -196,6 +196,12 @@ cc_library(
}),
)
cc_library(
name = "tf_status_headers",
hdrs = ["tf_status.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "tf_file_statistics",
hdrs = ["tf_file_statistics.h"],

View File

@@ -231,7 +231,14 @@ cc_library(
srcs = ["shape_inference_helpers.cc"],
hdrs = ["shape_inference_helpers.h"],
visibility = [":friends"],
deps = ["//tensorflow/core:graph"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:graph",
],
}),
)
# Internal targets below this point.

View File

@@ -880,11 +880,8 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
// Parses output_arrays_order from command line option.
absl::flat_hash_set<std::string> output_set;
std::vector<std::string> output_arrays_order;
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &output_set,
&output_arrays_order)
.ok()) {
std::vector<std::string> outputs;
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) {
return emitError(loc, "parsing output array info failed ")
<< output_arrays_string,
nullptr;
@@ -892,7 +889,7 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
return tflite::FlatBufferToMlir(
absl::string_view(input->getBufferStart(), input->getBufferSize()),
context, loc, output_arrays_order);
context, loc, outputs);
}
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(

View File

@@ -384,7 +384,7 @@ class Translator {
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
// Returns opcode index for op identified by the op_name, if already
// available. Otherwise, creates a new OperactorCode using the given `builtin`
// available. Otherwise, creates a new OperatorCode using the given `builtin`
// operator and associates it with `op_name`.
uint32_t GetOpcodeIndex(const std::string& op_name,
tflite::BuiltinOperator builtin);
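
The comment above describes a simple lookup-or-create cache keyed by op name. As a hedged sketch only (the map and opcode vector below are stand-ins, not the Translator's actual members), the pattern is:

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative sketch of GetOpcodeIndex's lookup-or-create behavior; the
// containers are assumptions, not the Translator's real fields.
uint32_t GetOpcodeIndexSketch(
    const std::string& op_name,
    std::unordered_map<std::string, uint32_t>* index_map,
    std::vector<std::string>* opcodes) {
  auto it = index_map->find(op_name);
  if (it != index_map->end()) return it->second;  // Already available.
  uint32_t index = static_cast<uint32_t>(opcodes->size());
  opcodes->push_back(op_name);         // Create the new OperatorCode entry.
  index_map->emplace(op_name, index);  // Associate it with `op_name`.
  return index;
}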

View File

@@ -720,7 +720,8 @@ static LogicalResult Verify(PackOp op) {
for (Value *operand : op.getOperands()) {
auto other_type = operand->getType().cast<ShapedType>();
if (input_type != other_type)
return op.emitOpError("operands should be of the same type");
return op.emitOpError("operands should be of the same type. got ")
<< input_type << ", " << other_type;
}
return success();
@@ -857,8 +858,8 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//
// => Value [5, 8, 9]
// TODO(b/133341698): Move to tablegen when variadic is supported.
struct RemoveRedunantUnpackPack : public RewritePattern {
explicit RemoveRedunantUnpackPack(MLIRContext *context)
struct RemoveRedundantUnpackPack : public RewritePattern {
explicit RemoveRedundantUnpackPack(MLIRContext *context)
: RewritePattern(PackOp::getOperationName(), 2, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
@@ -896,7 +897,7 @@ struct RemoveRedunantUnpackPack : public RewritePattern {
void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RemoveRedunantUnpackPack>(context);
results.insert<RemoveRedundantUnpackPack>(context);
}
//===----------------------------------------------------------------------===//
@@ -1041,7 +1042,7 @@ struct DropFakeQuant : public RewritePattern {
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
// Replace the matched FakeQuantOp by its primiary operand.
// Replace the matched FakeQuantOp by its primary operand.
rewriter.replaceOp(op, op->getOperand(0));
}
};

View File

@@ -32,7 +32,7 @@ def TFL_Dialect : Dialect {
Invariants:
* All values are of Tensor type (in particular, scalars are
represented using zero-dimentional tensors);
represented using zero-dimensional tensors);
}];
let cppNamespace = "TFL";
@@ -603,7 +603,7 @@ def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
def TFL_FCWO_Shuffled4x16i8 : StrEnumAttrCase<"SHUFFLED4x16INT8">;
def TFL_FullyConnectedOptionsWeightFormatAttr :
StrEnumAttr<"FullyConectedOptionsWeightsFormat",
StrEnumAttr<"FullyConnectedOptionsWeightsFormat",
"fully connected options weights format", [
TFL_FCWO_Default, TFL_FCWO_Shuffled4x16i8
]>;
@@ -1873,9 +1873,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
x -> max(0, x)
}];
let arguments = (ins AnyTensor:$x);
let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
let results = (outs AnyTensor:$y);
let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
}
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
@@ -1888,9 +1888,24 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
x -> max(0, min(6, x))
}];
let arguments = (ins AnyTensor:$x);
let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
let results = (outs AnyTensor:$y);
let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
}
def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale]> {
let summary = "Relu1 operator";
let description = [{
Element-wise Relu1 operator
x -> max(-1, min(1, x))
}];
let arguments = (ins TensorOf<[F32, QUI8, I8]>:$x);
let results = (outs TensorOf<[F32, QUI8, I8]>:$y);
}
def TFL_ReshapeOp: TFL_Op<"reshape", [

View File

@@ -237,8 +237,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// Parse output arrays.
std::vector<string> output_arrays(model_flags.output_arrays().begin(),
model_flags.output_arrays().end());
TF_RETURN_IF_ERROR(tensorflow::ParseOutputArrayInfo(
output_arrays, &specs.output_arrays, &specs.output_arrays_order));
TF_RETURN_IF_ERROR(
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
// Other flags.
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();

View File

@@ -82,7 +82,7 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
}
// A method to retrive the name for the given op.
// A method to retrieve the name for the given op.
OperationToName op_to_name_;
// We split the normal names and regex names, since the former can use hash

View File

@@ -102,7 +102,7 @@ class AffineOpCoefficient<int dim, int index> : NativeOpTrait<
StrJoinInt<[dim, index]>.result,
">::Impl")>;
// Specify this trait if the op doesn't have quantizable ouput. We shouldn't
// Specify this trait if the op doesn't have quantizable output. We shouldn't
// apply quantization on this op.
def NoQuantizableResult : NativeOpTrait<"quant::NoQuantizableResult">;

View File

@@ -54,7 +54,7 @@ class SameOperandsAndResultsScale
// OpTrait::quant::FixedResultUniformScale<
// 8, -128, 390625, -8, 0, 255, false>::Impl> {
//
// TODO(fengliuai): create a better way to epxress floating point scale in the
// TODO(fengliuai): create a better way to express floating point scale in the
// template argument list.
template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
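
Assuming the scale is decoded as ScaleMantissa * 10^ScaleExp (inferred from the comment above, not verified against the trait implementation), the example arguments 390625 and -8 encode 390625e-8 = 0.00390625 = 1/256, a typical 8-bit quantization step. A self-contained sanity check:

#include <cassert>
#include <cmath>

int main() {
  const double kScaleMantissa = 390625;  // From the example comment above.
  const int kScaleExp = -8;
  // Hedged: assumes scale = ScaleMantissa * 10^ScaleExp.
  const double scale = kScaleMantissa * std::pow(10.0, kScaleExp);
  assert(std::fabs(scale - 1.0 / 256.0) < 1e-12);  // 0.00390625
  return 0;
}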

View File

@@ -133,10 +133,10 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
// quantization parameters are annotated by the Q/DQ op pairs. Each
// matched pattern are rewritten by its quantized alternatives.
//
// The concret pattern, extends from this base pattern, can specify whether it
// The concrete pattern, extends from this base pattern, can specify whether it
// allows "hybrid" operands or results. These "hybrid" operands and results
// don't have quantization parameters propagated to, so will be in float in the
// quantized results. The concret pattern should define the following two
// quantized results. The concrete pattern should define the following two
// functions:
//
// bool AllowHybridOperand() const

View File

@@ -114,8 +114,8 @@ func @fakequant_notdropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
// -----
// CHECK-LABEL: @RemoveRedunantUnpackPack
func @RemoveRedunantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
// CHECK-LABEL: @RemoveRedundantUnpackPack
func @RemoveRedundantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
%0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
%1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>)
return %1: tensor<2x5xf32>
@@ -125,8 +125,8 @@ func @RemoveRedunantUnpackPack(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
// -----
// CHECK-LABEL: @RemoveRedunantPack
func @RemoveRedunantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) {
// CHECK-LABEL: @RemoveRedundantPack
func @RemoveRedundantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) {
%0:2 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 2 : i32} : (tensor<2x5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
%1 = "tfl.pack"(%0#0, %0#1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5xf32>, tensor<5xf32>) -> (tensor<2x5xf32>)
return %1, %0#0: tensor<2x5xf32>, tensor<5xf32>

View File

@@ -106,8 +106,8 @@ func @extractStackInputOutputOphint() {
// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32>
// CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>)
// CHECK-DAG: %[[OUPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK-DAG: %[[OUPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK-DAG: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
// CHECK-DAG: %[[OUTPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32>
%1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32>

View File

@@ -1,22 +1,22 @@
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s --dump-input-on-failure
func @addRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%1 = "tf.Add"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
%3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
%4 = "tf.Add"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
%6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%7 = "tf.Relu6"(%6) : (tensor<1xi32>) -> tensor<1xi32>
return %7: tensor<1xi32>
func @addRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "tf.Add"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
%3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
%4 = "tf.Add"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
%6 = "tfl.add"(%5, %3) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%7 = "tf.Relu6"(%6) : (tensor<1xf32>) -> tensor<1xf32>
return %7: tensor<1xf32>
// CHECK-LABEL: addRelu
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xi32>
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
// CHECK: %1 = tfl.add %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: %3 = tfl.add %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
// CHECK: %4 = tfl.add %3, %2 {fused_activation_function = "RELU6"} : tensor<1xf32>
// CHECK: return
}
@@ -244,32 +244,32 @@ func @zeros_like(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK: "tfl.zeros_like"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
}
func @divRelu(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%1 = "tf.Div"(%arg0, %0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%2 = "tf.Relu"(%1) : (tensor<1xi32>) -> tensor<1xi32>
%3 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
%4 = "tf.Div"(%3, %2) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%5 = "tf.Relu6"(%4) : (tensor<1xi32>) -> tensor<1xi32>
return %5: tensor<1xi32>
func @divRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "tf.Div"(%arg0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%2 = "tf.Relu"(%1) : (tensor<1xf32>) -> tensor<1xf32>
%3 = "tf.Relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
%4 = "tf.Div"(%3, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%5 = "tf.Relu6"(%4) : (tensor<1xf32>) -> tensor<1xf32>
return %5: tensor<1xf32>
// CHECK-LABEL: divRelu
// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32>
// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xi32>
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xi32>
// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xf32>
// CHECK: %1 = tfl.div %arg0, %0 {fused_activation_function = "RELU"} : tensor<1xf32>
// CHECK: %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: %3 = tfl.div %2, %1 {fused_activation_function = "RELU6"} : tensor<1xf32>
// CHECK: return
}
func @squaredDifferenceRelu(tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> {
^bb0(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>):
%0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%1 = "tf.Relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
return %1: tensor<1xi32>
func @squaredDifferenceRelu(tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> {
^bb0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>):
%0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "tf.Relu6"(%0) : (tensor<1xf32>) -> tensor<1xf32>
return %1: tensor<1xf32>
// CHECK-LABEL: squaredDifferenceRelu
// CHECK: tfl.squared_difference %arg0, %arg1 : tensor<1xi32>
// CHECK: %1 = "tfl.relu6"(%0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: tfl.squared_difference %arg0, %arg1 : tensor<1xf32>
// CHECK: %1 = "tfl.relu6"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: return
}

View File

@@ -434,7 +434,7 @@ func @testEluI32(%arg0: tensor<? x i32>) -> tensor<? x i32> {
// -----
func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
// CHECK: "NONE"
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4xi32>
// CHECK: "RELU"
@@ -452,7 +452,7 @@ func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -
// -----
func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
func @testFusedActivationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// expected-error @+1 {{attribute 'fused_activation_function' failed to satisfy constraint: fused activation enum}}
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "Relu6"} : tensor<4xi32>
return %0: tensor<4xi32>
@@ -1047,7 +1047,7 @@ func @testConcatInvalidOperandDimSize(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3x
// -----
func @testConcatInvalidOperandDimSizeComaredToPrevInput(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3xi32>) -> tensor<?x?xi32> {
func @testConcatInvalidOperandDimSizeComparedToPrevInput(%arg0: tensor<1x2xi32>, %arg1: tensor<1x3xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{'tfl.concatenation' op dimension size of dimension #1 of operand #1 must be equal to dimension size of dimension #1 of operand #0, expected 2, got 3}}
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x3xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>

View File

@@ -278,7 +278,7 @@ func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x
%cst2 = constant dense<3.0> : tensor<112x2xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// We cannot fuse this tfl.mul into the preceding conv op becuase %cst2 is not broadcast-compatible to %cst0.
// We cannot fuse this tfl.mul into the preceding conv op because %cst2 is not broadcast-compatible to %cst0.
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
return %1 : tensor<1x112x112x2xf32>
@@ -600,3 +600,25 @@ func @squeezeToReshape(%arg0: tensor<1x1x2xf32>) -> tensor<2xf32> {
// CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor<1x1x2xf32>, tensor<1xi32>) -> tensor<2xf32>
// CHECK: return %[[RESULT]]
}
// CHECK-LABEL: Relu1
func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%cst = constant dense<-1.0> : tensor<f32>
%cst1 = constant dense<1.0> : tensor<f32>
%0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
%1 = "tfl.minimum"(%0, %cst1) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
}
// CHECK-LABEL: Relu1_2
func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%cst = constant dense<-1.0> : tensor<f32>
%cst1 = constant dense<1.0> : tensor<f32>
%0 = "tfl.minimum"(%arg0, %cst1) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
%1 = "tfl.maximum"(%0, %cst) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
}

View File

@@ -30,7 +30,7 @@ opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
opt<bool> use_splatted_constant(
"use-splatted-constant",
llvm::cl::desc(
"Replace constants with randonmly generated splatted tensors"),
"Replace constants with randomly generated splatted tensors"),
llvm::cl::init(false), llvm::cl::Hidden);
// NOLINTNEXTLINE
opt<bool> input_mlir(

View File

@@ -282,7 +282,7 @@ struct OphintCompositeOp {
// Since we have different aggregation strategies, e.g., "first", "last",
// "stack". We don't somehow aggregated to get the outputs for the funcOp.
// This function is simply compute the RankedTensorType (shape & element type)
std::map<int, Type> GetAggregatedOuputTypes(OpBuilder* builder) {
std::map<int, Type> GetAggregatedOutputTypes(OpBuilder* builder) {
std::map<int, Type> aggregated_output_types;
for (const auto& kv : outputs) {
const AggregatedOperand& operand = kv.second;
@@ -387,11 +387,12 @@ struct OphintCompositeOp {
// inputs/outputs indicate edges) Assume the graph is acyclic. The preprocess
// does the following:
// Compute each operation's in-degree (i.e., how many inputs it takes)
// Get all consumer operations for every operations. (operation_to_ouputs)
// Get all consumer operations for every operations. (operation_to_outputs)
// Get the init_queue (those operations will be processed first).
void PreprocessTopoSortGraph(
Block* block, std::queue<Operation*>* init_queue,
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>>* operation_to_ouputs,
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>>*
operation_to_outputs,
llvm::DenseMap<Operation*, int>* operation_to_in_degrees) {
for (auto& op : *block) {
if (&op == block->getTerminator()) continue;
@@ -412,9 +413,9 @@ void PreprocessTopoSortGraph(
}
operation_to_in_degrees->try_emplace(&op, input_ops.size());
for (auto* input_op : input_ops) {
auto preceeding_op_it = operation_to_ouputs->find(input_op);
if (preceeding_op_it == operation_to_ouputs->end()) {
auto result = operation_to_ouputs->try_emplace(
auto preceeding_op_it = operation_to_outputs->find(input_op);
if (preceeding_op_it == operation_to_outputs->end()) {
auto result = operation_to_outputs->try_emplace(
input_op, llvm::DenseSet<Operation*>());
preceeding_op_it = result.first;
}
@@ -442,19 +443,19 @@ bool IsSideEffectOp(Operation* op) {
// Also assume the block has no arguments.
LogicalResult TopoSortOperations(OpBuilder* builder) {
std::queue<Operation*> init_queue;
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>> operation_to_ouputs;
llvm::DenseMap<Operation*, llvm::DenseSet<Operation*>> operation_to_outputs;
llvm::DenseMap<Operation*, int> operation_to_in_degrees;
std::vector<Operation*> sorted_ops;
PreprocessTopoSortGraph(builder->getBlock(), &init_queue,
&operation_to_ouputs, &operation_to_in_degrees);
&operation_to_outputs, &operation_to_in_degrees);
while (!init_queue.empty()) {
Operation* current_op = init_queue.front();
init_queue.pop();
sorted_ops.push_back(current_op);
auto current_op_to_output_it = operation_to_ouputs.find(current_op);
if (current_op_to_output_it == operation_to_ouputs.end()) {
auto current_op_to_output_it = operation_to_outputs.find(current_op);
if (current_op_to_output_it == operation_to_outputs.end()) {
continue;
}
for (Operation* output_op : current_op_to_output_it->second) {
@@ -467,7 +468,7 @@ LogicalResult TopoSortOperations(OpBuilder* builder) {
operation_to_in_degrees.erase(output_op_it);
}
}
operation_to_ouputs.erase(current_op_to_output_it);
operation_to_outputs.erase(current_op_to_output_it);
}
// Before we perform the sort, we need to make sure we didn't mess the
@@ -629,11 +630,11 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
// Step 4, get aggregated output types.
const std::map<int, Type>& aggregated_output_types =
ophint_composite_op.GetAggregatedOuputTypes(builder);
ophint_composite_op.GetAggregatedOutputTypes(builder);
// Step 5, create & place the fused op and rewire the inputs.
// Here we use a funcOp to represent the fused op. This "funcOp" will be
// coonverted to other ops (like UnidirectionalSequenceRNNOp) in the
// converted to other ops (like UnidirectionalSequenceRNNOp) in the
// legalization phase.
Operation* inserted_before_op = ophint_composite_op.GetFirstOutputOp();
Operation* fused_op = BuildFusedFuncOp(

View File

@@ -191,10 +191,10 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
return success();
}
LogicalResult ConvertTfLiteFusedOpIfAvaiable(StringRef func_name,
FuncOp composite_func_op,
CallOp call_op,
OpBuilder* builder) {
LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name,
FuncOp composite_func_op,
CallOp call_op,
OpBuilder* builder) {
Operation* fused_op = nullptr;
if (func_name == kUnidirectionalSequenceRnn) {
// TODO(renjieliu): Validate the func op inputs.
@@ -243,8 +243,8 @@ LogicalResult ConvertCallOps(llvm::StringMap<FuncOp>* composite_func_ops,
StringRef func_name = composite_func_op.getAttr(kTfLiteFunctionName)
.cast<StringAttr>()
.getValue();
if (failed(ConvertTfLiteFusedOpIfAvaiable(func_name, composite_func_op,
call_op, &builder)))
if (failed(ConvertTfLiteFusedOpIfAvailable(func_name, composite_func_op,
call_op, &builder)))
return failure();
composite_func_ops->erase(it);

View File

@@ -140,6 +140,8 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
return ExpandTo4DForConvImpl(a, true);
}
// Returns shape of a ranked tensor.
// Precondition: output_val is a ranked tensor.
DenseElementsAttr GetShape(Value *output_val) {
auto output_type = output_val->getType().cast<RankedTensorType>();
auto shape_vector = output_type.getShape();

View File

@@ -267,9 +267,27 @@ multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
in defm : FuseTileBroadcastIntoFollowingBinary<BroadcastingOp>;
// Returns shape of a ranked tensor.
// If called without a ranked tensor, it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;
def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
(TFL_ReshapeOp $input,
(ConstantOp (GetShape $squeeze_op))),
[(AnyStaticShapeTensor $squeeze_op)]>;
class ValueEquals<string val> : Constraint<CPred<
"$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
(ConstantOp $NegOne)),
(ConstantOp $One)),
(TFL_Relu1Op $input),
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
(ConstantOp $One)),
(ConstantOp $NegOne)),
(TFL_Relu1Op $input),
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
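
Two patterns are emitted because the clamp can be written with either op outermost; for the bounds -1 and 1 the two compositions agree pointwise, which is what makes both rewrites to TFL_Relu1Op safe. A standalone C++ check of that identity (illustrative only, not part of this commit):

#include <algorithm>
#include <cassert>

// relu_n1_to_1(x) = max(-1, min(1, x)); the two functions mirror the two
// TableGen patterns above.
float MinOfMax(float x) { return std::min(std::max(x, -1.0f), 1.0f); }
float MaxOfMin(float x) { return std::max(std::min(x, 1.0f), -1.0f); }

int main() {
  for (float x : {-2.0f, -1.0f, -0.25f, 0.0f, 0.25f, 1.0f, 2.0f}) {
    assert(MinOfMax(x) == MaxOfMin(x));  // Holds because -1 <= 1.
  }
  return 0;
}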

View File

@@ -68,11 +68,11 @@ class ConvertEmbeddedLookupFunc {
if (func_.getNumArguments() != 2) {
return func_.emitError()
<< "Invalid number of arguments in the embedding "
"matmal composite function";
"matmul composite function";
}
if (func_.getType().getNumResults() != 1) {
return func_.emitError() << "Invalid number of results in the embedding "
"matmal composite function";
"matmul composite function";
}
return success();
}

View File

@@ -34,7 +34,7 @@ limitations under the License.
// NOLINTNEXTLINE
static llvm::cl::list<std::string> quantize_whitelist(
"tfl-test-quantize-whitelist", llvm::cl::value_desc("list"),
llvm::cl::desc("comma seprarated list of whitelisted functions to be "
llvm::cl::desc("comma separated list of whitelisted functions to be "
"quantized. Only used in tests"),
llvm::cl::CommaSeparated);

View File

@@ -400,7 +400,7 @@ class ConvertTFDepthwiseConv2dNative
}
};
// StridedSlice can have complicated atributes like begin_axis_mask,
// StridedSlice can have complicated attributes like begin_axis_mask,
// end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
// masks will complicate the strided_slice computation logic, we can simplify
// the logic by inserting a reshape op to pad the inputs so strided_slice can

View File

@@ -247,7 +247,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
}
if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
// Input dimensions must be compatible for multipication.
// Input dimensions must be compatible for multiplication.
return this->matchFailure();
}

View File

@@ -126,7 +126,7 @@ class ConvertLSTMCellSimpleToFusedLSTM {
Value* input2cell_;
Value* input2output_;
// reccurrent -> cifg
// recurrent -> cifg
Value* rec2input_;
Value* rec2forget_;
Value* rec2cell_;

View File

@@ -28,7 +28,7 @@ config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm')
config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR'])
config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'],
'local_config_mlir')
# TODO(jpienaar): Replace with sufffices in build rule.
# TODO(jpienaar): Replace with suffices in build rule.
config.suffixes = ['.td', '.mlir', '.pbtxt']
mlir_tf_tools_dirs = [

View File

@@ -217,6 +217,7 @@ cc_library(
"transforms/shape_inference.cc",
"transforms/shape_inference_pass.cc",
"transforms/sink_constant.cc",
"transforms/test_side_effect_analysis.cc",
"transforms/tpu_cluster_formation.cc",
"transforms/tpu_merge_variables_with_execute.cc",
"transforms/tpu_rewrite_pass.cc",
@@ -239,6 +240,7 @@ cc_library(
":error_util",
":export_tf_dialect_op",
":mangling_util",
":side_effect_analysis",
":tensorflow",
":tensorflow_optimize_inc_gen",
":tpu_rewrite_device_util",
@@ -467,6 +469,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:types",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
@@ -669,6 +672,7 @@ cc_library(
":import_utils",
":mangling_util",
":mlir_roundtrip_flags",
"//tensorflow/core:graph",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
@@ -981,3 +985,19 @@ cc_library(
"@local_config_mlir//:Pass",
],
)
cc_library(
name = "side_effect_analysis",
srcs = ["analysis/side_effect_analysis.cc"],
hdrs = ["analysis/side_effect_analysis.h"],
deps = [
":tensorflow",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/core:framework",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:Support",
],
)

View File

@@ -0,0 +1,374 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include <cstdint>
#include <initializer_list>
#include "absl/strings/str_cat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/core/framework/resource_mgr.h"
namespace mlir {
namespace TF {
namespace {
constexpr int64_t kUnknownResourceId = -1;
// Returns if a VarHandleOp is anonymous, which means it always creates a new
// variable.
bool IsResourceHandleAnonymous(TF::VarHandleOp handle) {
return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME;
}
// Returns a string unique identifier for a non-anonymous VarHandleOp.
std::string GetVarHandleStringId(TF::VarHandleOp handle) {
auto device = handle.getAttrOfType<StringAttr>("device");
return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(),
"/", device ? device.getValue().str() : std::string(""));
}
// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always
// creates a new ID; otherwise, tries to reuse the existing ID for the
// referenced variable if it exists, or creates a new one if not.
int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id,
llvm::StringMap<int64_t>* name_id_map) {
// Always create a new ID for anonymous handle.
if (IsResourceHandleAnonymous(handle)) return (*next_id)++;
auto name = GetVarHandleStringId(handle);
auto emplace_res = name_id_map->try_emplace(name, *next_id);
// New ID created, increment next_id.
if (emplace_res.second) ++(*next_id);
return emplace_res.first->second;
}
} // namespace
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
auto func_op = llvm::dyn_cast<FuncOp>(op);
if (!func_op) return;
AnalyzeFunction(func_op);
}
void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
// This function populates resource_value_to_ids_.
//
// TODO(yuanzx): Pass variable aliasing information to functions so we can
// properly resolve aliasing arguments.
//
// Before having that, we assume function arguments do not alias each other.
int64_t next_unique_id = 0;
for (auto arg : func_op.getArguments()) {
if (!mlir::getElementTypeOrSelf(arg->getType()).isa<TF::ResourceType>())
continue;
resource_value_to_ids_[arg].insert(next_unique_id++);
}
llvm::StringMap<int64_t> var_handle_name_id_map;
auto forward_input_to_output = [&](Value* operand, Value* result) {
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
return;
auto operand_it = resource_value_to_ids_.find(operand);
assert(operand_it != resource_value_to_ids_.end() &&
"A resource-type output does not have the corresponding "
"resource-type input.");
resource_value_to_ids_[result].insert(operand_it->getSecond().begin(),
operand_it->getSecond().end());
};
// TODO(yuanzx): Consider control-flow ops.
func_op.walk([&](Operation* op) {
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
resource_value_to_ids_[var_handle.resource()].insert(
GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
&var_handle_name_id_map));
} else if (llvm::isa<TF::IdentityNOp>(op) ||
llvm::isa<TF::IdentityOp>(op)) {
for (auto operand_and_result :
llvm::zip(op->getOperands(), op->getResults())) {
forward_input_to_output(std::get<0>(operand_and_result),
std::get<1>(operand_and_result));
}
} else {
for (auto result : op->getResults()) {
if (!mlir::getElementTypeOrSelf(result->getType())
.isa<TF::ResourceType>())
continue;
resource_value_to_ids_[result].insert(kUnknownResourceId);
}
}
});
}
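// Summary of the walk above as a worked example (hypothetical IR): two
// non-anonymous VarHandleOps sharing container/shared_name/device get the
// same unique ID; tf.Identity/tf.IdentityN forward their operands' ID sets
// to the matching results; any other op producing a resource value maps to
// kUnknownResourceId; and each resource-type function argument gets a fresh
// ID of its own.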
bool ResourceAliasAnalysis::IsUnknownResource(const Value* resource) const {
auto it = resource_value_to_ids_.find(resource);
assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
// The set is sorted so we only need to check the first element since
// kUnknownResourceId < 0.
static_assert(kUnknownResourceId < 0,
"kUnknownResourceId should be negative");
return *it->getSecond().begin() == kUnknownResourceId;
}
const llvm::SmallSet<int64_t, 8>& ResourceAliasAnalysis::GetResourceUniqueIds(
const Value* resource) const {
auto it = resource_value_to_ids_.find(resource);
assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
return it->getSecond();
}
namespace {
// Returns a set that contains only kUnknownResourceId.
llvm::SmallDenseSet<int64_t, 8> UnknownResourceSet() {
llvm::SmallDenseSet<int64_t, 8> unknown_set;
unknown_set.insert(kUnknownResourceId);
return unknown_set;
}
// Returns all resources that could be accessed by op, or UnknownResourceSet()
// if we cannot find all of them.
llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
Operation* op, const ResourceAliasAnalysis& alias_analysis) {
llvm::SmallDenseSet<int64_t, 8> resources;
for (auto operand : op->getOperands()) {
if (!mlir::getElementTypeOrSelf(operand->getType()).isa<TF::ResourceType>())
continue;
if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet();
const auto& ids = alias_analysis.GetResourceUniqueIds(operand);
resources.insert(ids.begin(), ids.end());
}
for (auto result : op->getResults()) {
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
continue;
if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet();
const auto& ids = alias_analysis.GetResourceUniqueIds(result);
resources.insert(ids.begin(), ids.end());
}
return resources;
}
// Returns an XlaResourceOpInfo (or nullptr if it does not exist) that specifies
// the resource access type of the op. It tells whether the op is read only,
// etc.
//
// TODO(yuanzx): Define this information in a different place. Currently we use
// tensorflow/compiler/tf2xla/resource_operation_table.h.
const tensorflow::XlaResourceOpInfo* GetResourceInfoForOp(Operation* op) {
auto op_name = op->getName().getStringRef().str();
if (op->getName().getDialect() !=
TF::TensorFlowDialect::getDialectNamespace()) {
return nullptr;
}
return tensorflow::GetResourceOpInfoForOp(
op->getName().getStringRef().split('.').second.str());
}
// Returns whether `op` accesses resources and it is known to be read-only.
bool OpIsReadOnly(Operation* op) {
auto resource_op_info = GetResourceInfoForOp(op);
return resource_op_info &&
resource_op_info->kind() == tensorflow::XlaResourceOpKind::kRead;
}
// Returns if `op` is a resource declaration.
bool OpIsDeclaration(Operation* op,
const ResourceAliasAnalysis& alias_analysis) {
// TODO(yuanzx): Add other types of resources.
return llvm::isa<TF::VarHandleOp>(op) ||
((llvm::isa<TF::IdentityNOp>(op) || llvm::isa<TF::IdentityOp>(op)) &&
!FindAccessedResources(op, alias_analysis).empty());
}
} // namespace
void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op,
bool read_only) {
if (resource_id == kUnknownResourceId) {
if (read_only) {
// New unknown read is not tracked by any known resource access.
for (auto& entry : per_resource_access_info_) {
entry.getSecond().tracked_last_unknown_read = false;
}
} else {
// Unknown write can clear all other tracked information, since it acts
// like a barrier.
per_resource_access_info_.clear();
}
}
auto& info = per_resource_access_info_[resource_id];
if (read_only) {
info.reads_since_last_write.push_back(op);
// Resource read must have carried control dependencies of unknown write.
info.tracked_last_unknown_write = true;
} else {
// Resource write must have carried control dependencies of unknown access.
info.tracked_last_unknown_write = true;
info.tracked_last_unknown_read = true;
info.last_write = op;
info.reads_since_last_write.clear();
}
}
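// Worked example for a single known resource r (hypothetical sequence):
//   w1 = write(r)  ->  last_write = w1
//   r1 = read(r)   ->  r1 depends on w1; r1 joins reads_since_last_write
//   r2 = read(r)   ->  r2 also depends only on w1; reads do not interfere
//                      with each other
//   w2 = write(r)  ->  w2 depends on r1 and r2, which already carry the
//                      w1 dependency, so no direct w1 -> w2 edge is needed
// The actual edge wiring happens in AddPredecessorsForAccess below.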
void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id,
Operation* op,
bool read_only) {
auto it = per_resource_access_info_.find(resource_id);
if (it == per_resource_access_info_.end()) return;
const auto& access_info = it->getSecond();
auto& control_predecessors = control_predecessors_[op];
bool read_tracked = false;
if (!read_only) {
control_predecessors.insert(access_info.reads_since_last_write.begin(),
access_info.reads_since_last_write.end());
read_tracked = !access_info.reads_since_last_write.empty();
}
if (access_info.last_write && !read_tracked) {
control_predecessors.insert(access_info.last_write);
}
}
void SideEffectAnalysis::AnalyzeFunction(
FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) {
// This function populates control_predecessors_ and control_successors_ by
// walking through func_op's body, and tracking resource accesses in
// per_resource_access_info_.
// Returns whether an access to `resource` can skip control edges from
// previous accesses to unknown resources, because earlier accesses to
// `resource` already indirectly tracked previous accesses to unknown
// resources. `read_only` specifies the type of access of the current op being
// considered.
auto unknown_access_indirectly_tracked_by_resource = [&](int64_t resource,
bool read_only) {
auto it = per_resource_access_info_.find(resource);
if (it == per_resource_access_info_.end()) return false;
auto unknown_it = per_resource_access_info_.find(kUnknownResourceId);
const bool no_unknown_read =
unknown_it == per_resource_access_info_.end() ||
unknown_it->getSecond().reads_since_last_write.empty();
return read_only
? it->second.tracked_last_unknown_write
: it->second.tracked_last_unknown_write &&
(it->second.tracked_last_unknown_read || no_unknown_read);
};
func_op.walk([&](Operation* op) {
// We do not need explicit control edges for declaration ops.
if (OpIsDeclaration(op, alias_analysis)) return;
auto resource_op_info = GetResourceInfoForOp(op);
if (!resource_op_info && op->hasNoSideEffect()) return;
llvm::SmallDenseSet<int64_t, 8> resources =
resource_op_info ? FindAccessedResources(op, alias_analysis)
: UnknownResourceSet();
assert(!resources.empty());
const bool is_unknown = resources.count(kUnknownResourceId) > 0;
const bool read_only = OpIsReadOnly(op);
bool indirectly_tracked_unknown_access = false;
// First add edges from known resources.
if (is_unknown) {
for (auto& entry : per_resource_access_info_) {
if (entry.getFirst() == kUnknownResourceId) continue;
AddPredecessorsForAccess(entry.getFirst(), op, read_only);
indirectly_tracked_unknown_access |=
unknown_access_indirectly_tracked_by_resource(entry.getFirst(),
read_only);
}
} else {
for (int64_t resource : resources) {
AddPredecessorsForAccess(resource, op, read_only);
indirectly_tracked_unknown_access |=
unknown_access_indirectly_tracked_by_resource(resource, read_only);
// Update access info for known resources.
TrackAccess(resource, op, read_only);
}
}
// If not indirectly tracked, add edges from the unknown resource.
if (!indirectly_tracked_unknown_access) {
AddPredecessorsForAccess(kUnknownResourceId, op, read_only);
}
if (is_unknown) {
// Update access info for unknown resource.
TrackAccess(kUnknownResourceId, op, read_only);
}
});
// Populate control_successors_ based on control_predecessors_.
for (auto& entry : control_predecessors_) {
auto op = entry.getFirst();
for (auto predecessor : entry.getSecond()) {
control_successors_[predecessor].insert(op);
}
}
}
llvm::SmallVector<Operation*, 8> SideEffectAnalysis::DirectControlPredecessors(
Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
llvm::SmallVector<Operation*, 8> result;
auto it = control_predecessors_.find(op);
if (it == control_predecessors_.end()) return result;
result.reserve(it->getSecond().size());
for (auto predecessor : it->getSecond()) {
if (!filter || filter(predecessor)) result.push_back(predecessor);
}
llvm::sort(result,
[](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
return result;
}
llvm::SmallVector<Operation*, 8> SideEffectAnalysis::DirectControlSuccessors(
Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
llvm::SmallVector<Operation*, 8> result;
auto it = control_successors_.find(op);
if (it == control_successors_.end()) return result;
result.reserve(it->getSecond().size());
for (auto successor : it->getSecond()) {
if (!filter || filter(successor)) result.push_back(successor);
}
llvm::sort(result,
[](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
return result;
}
SideEffectAnalysis::SideEffectAnalysis(Operation* op) {
auto func_op = llvm::dyn_cast<FuncOp>(op);
if (!func_op) return;
ResourceAliasAnalysis alias_analysis(op);
AnalyzeFunction(func_op, alias_analysis);
}
} // namespace TF
} // namespace mlir

View File

@@ -0,0 +1,125 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
#include <cstdint>
#include <memory>
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/Region.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
namespace mlir {
namespace TF {
// An analysis that runs on a function and maps each resource-type value to a
// set of unique int64_t IDs representing the possible resources it could alias.
class ResourceAliasAnalysis {
public:
explicit ResourceAliasAnalysis(Operation* op);
~ResourceAliasAnalysis() = default;
ResourceAliasAnalysis(ResourceAliasAnalysis&&) = default;
// Returns if the analysis fails to resolve a resource-type value.
bool IsUnknownResource(const Value* resource) const;
// Returns the set of unique IDs which `resource` could alias. Requires that
// IsUnknownResource(resource) == false.
const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(
const Value* resource) const;
private:
ResourceAliasAnalysis() = default;
// Runs the analysis on `func_op` and populates resource_value_to_ids_.
void AnalyzeFunction(FuncOp func_op);
// Maps each resource-type value to a set of unique IDs that it could alias.
llvm::SmallDenseMap<const Value*, llvm::SmallSet<int64_t, 8>, 8>
resource_value_to_ids_;
};
// An analysis that runs on a function and infers the control predecessors and
// successors for each op, based on side-effects on known and unknown resources.
// Side-effecting ops on unknown resources are conservatively treated as
// interfering with all known resource op accesses. It distinguishes accesses
// based on whether they are read-only, and read-only ops do not interfere with
// each other.
class SideEffectAnalysis {
public:
explicit SideEffectAnalysis(Operation* op);
SideEffectAnalysis(SideEffectAnalysis&& other) = default;
~SideEffectAnalysis() = default;
// Returns a vector of ops that are direct control predecessors of `op`,
// sorted in program order. If `filter` is provided, only predecessors that
// pass the filter (returning true) will be included.
llvm::SmallVector<Operation*, 8> DirectControlPredecessors(
Operation* op,
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
// Returns a vector of ops that are direct control successors of `op`, sorted
// in program order. If `filter` is provided, only successors that pass the
// filter (returning true) will be included.
llvm::SmallVector<Operation*, 8> DirectControlSuccessors(
Operation* op,
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
private:
// Runs the analysis on `func_op` and populates control_predecessors_ and
// control_successors_.
void AnalyzeFunction(FuncOp func_op,
const ResourceAliasAnalysis& alias_analysis);
// Updates control_predecessors_ for `op` that is being visited, on the given
// `resource_id`.
void AddPredecessorsForAccess(int64_t resource_id, Operation* op,
bool read_only);
// Adds op's access to per_resource_access_info_.
void TrackAccess(int64_t resource_id, Operation* op, bool read_only);
// Maps from an op to its control predecessors.
llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>
control_predecessors_;
// Maps from an op to its control successors.
llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>
control_successors_;
// Internal per-resource data structure used while building the dependencies.
struct PerResourceAccessInfo {
// Last op that writes the resource before the current op being analyzed.
Operation* last_write = nullptr;
// Read ops since last_write before the current op being analyzed.
llvm::SmallVector<Operation*, 8> reads_since_last_write;
// Whether previous accesses of this resource already tracked last unknown
// read/write.
bool tracked_last_unknown_read = false;
bool tracked_last_unknown_write = false;
};
llvm::SmallDenseMap<int64_t, PerResourceAccessInfo, 8>
per_resource_access_info_;
};
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
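
For orientation, a minimal usage sketch of the API declared above (not part of this change; `DumpAssignPredecessors` is a hypothetical helper and the filter is purely illustrative):

```c++
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"

namespace mlir {
namespace TF {

// Prints, for every op in `func`, its control predecessors that are
// tf.AssignVariableOp writes, in program order.
void DumpAssignPredecessors(FuncOp func) {
  // The constructor runs AnalyzeFunction() when given a FuncOp.
  SideEffectAnalysis analysis(func.getOperation());
  func.walk([&](Operation* op) {
    auto preds = analysis.DirectControlPredecessors(
        op, /*filter=*/[](Operation* pred) {
          return pred->getName().getStringRef() == "tf.AssignVariableOp";
        });
    for (Operation* pred : preds) pred->dump();
  });
}

}  // namespace TF
}  // namespace mlir
```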

View File

@ -26,7 +26,7 @@ static DialectRegistration<TFControlFlow::TFControlFlowDialect>
tf_control_flow_ops;
static DialectRegistration<TF::TensorFlowDialect> tf_ops;
static DialectRegistration<tf_executor::TensorFlowExecutorDialect>
tf_excutor_dialect;
tf_executor_dialect;
static DialectRegistration<tf_device::TensorFlowDeviceDialect>
tf_device_dialect;
static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect>

View File

@ -26,11 +26,13 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir

View File

@ -576,6 +576,44 @@ endian orderings will give different results.
let hasCanonicalizer = 1;
}
def TF_BitwiseOrOp : TF_Op<"BitwiseOr", [Broadcastable, Commutative, NoSideEffect]>,
WithBroadcastableBinOpBuilder {
let summary = "Elementwise computes the bitwise OR of `x` and `y`.";
let description = [{
The result will have those bits set, that are set in `x`, `y` or both. The
computation is performed on the underlying representations of `x` and `y`.
For example:
```python
import tensorflow as tf
from tensorflow.python.ops import bitwise_ops
dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64,
tf.uint8, tf.uint16, tf.uint32, tf.uint64]
for dtype in dtype_list:
lhs = tf.constant([0, 5, 3, 14], dtype=dtype)
rhs = tf.constant([5, 0, 7, 11], dtype=dtype)
exp = tf.constant([5, 5, 7, 15], dtype=tf.float32)
res = bitwise_ops.bitwise_or(lhs, rhs)
tf.assert_equal(tf.cast(res, tf.float32), exp) # TRUE
```
}];
let arguments = (ins
TF_IntTensor:$x,
TF_IntTensor:$y
);
let results = (outs
TF_IntTensor:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect]> {
let summary = [{
Return the reduction indices for computing gradients of s0 op s1 with broadcast.
@ -1320,6 +1358,10 @@ Comparison with `numpy.einsum`:
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
let verifier = [{
return Verify(*this);
}];
}
def TF_EluOp : TF_Op<"Elu", [NoSideEffect, SameOperandsAndResultType]> {
@ -5726,6 +5768,8 @@ If two elements are equal, the lower-index element appears first.
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let verifier = [{ return Verify(*this); }];
}
def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {

View File

@ -39,7 +39,7 @@ This dialect maps to TensorFlow operations.
Invariants:
* All values are of Tensor type (in particular, scalars are
represented using zero-dimentional tensors);
represented using zero-dimensional tensors);
TODO: Make invariants more structured so that we can reference them in ops.
}];

View File

@ -513,7 +513,7 @@ void ConstOp::build(Builder *builder, OperationState &result, Attribute value) {
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
value.isa<IntegerAttr>()) {
// All TensorFlow types must be tensor types. In the build() method,
// we want to provide more flexiblity by allowing attributes of scalar
// we want to provide more flexibility by allowing attributes of scalar
// types. But we need to wrap it up with ElementsAttr to construct
// valid TensorFlow constants.
type = RankedTensorType::get(/*shape=*/{}, value.getType());
@ -674,6 +674,21 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<DivWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
// EinsumOp
//===----------------------------------------------------------------------===//
// Verifies that:
// * Arity of the op is at most two.
//
// TODO(hinsu): Verify einsum equation attribute.
static LogicalResult Verify(EinsumOp op) {
if (op.N() > 2) {
return op.emitOpError("supports at most two operands");
}
return success();
}
//===----------------------------------------------------------------------===//
// EmptyTensorListOp
//===----------------------------------------------------------------------===//
@ -1683,6 +1698,21 @@ static LogicalResult Verify(TensorListStackOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// TopKV2Op
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TopKV2Op op) {
if (!HasRankAtLeast(op.input(), 1))
return op.emitOpError(
"requires input operand to have at least 1 dimension");
if (!IsOfRankOrUnranked(op.k(), 0))
return op.emitOpError("requires k operand to be 0D tensor");
return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

View File

@ -41,7 +41,7 @@ class TensorFlowDialect : public Dialect {
static StringRef getDialectNamespace() { return "tf"; }
// Gradient attribute ("tf.gradient") in the list of NamedAttibutes in a
// Gradient attribute ("tf.gradient") in the list of NamedAttributes in a
// function references to its gradient function. This attribute in TensorFlow
// Dialect is used to model TF GradientDef. GetGradientAttrName() returns the
// string description of gradient attribute.

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir

View File

@ -101,6 +101,17 @@ func @testDifferentCastType(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64
// CHECK: return %0, %1
}
// CHECK-LABEL: testCompatibleCastType
func @testCompatibleCastType(%arg0: tensor<?xf32>) -> (tensor<10xf32>, tensor<10xf32>) {
%0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<?xf32>) -> tensor<10xf32>
%1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<?xf32>) -> tensor<10xf32>
return %0, %1: tensor<10xf32>, tensor<10xf32>
// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<?xf32>) -> tensor<10xf32>
// CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<?xf32>) -> tensor<10xf32>
// CHECK: return %0, %1
}
// CHECK-LABEL: testSameCastTypeAcrossBasicBlocks
func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> {
^bb0(%arg0: tensor<8x16x32x64xf32>):

View File

@ -232,12 +232,12 @@ module {
// -----
// Single device with non-continous instructions in original block.
// Single device with non-continuous instructions in original block.
module {
// CHECK-LABEL: func @noncontinoussinglecluster
// CHECK-LABEL: func @noncontinuoussinglecluster
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
func @noncontinoussinglecluster(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func @noncontinuoussinglecluster(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = tf_executor.graph {
%1:2 = tf_executor.island {

View File

@ -0,0 +1,237 @@
// RUN: tf-opt -split-input-file -tf-test-side-effect-analysis -verify-diagnostics %s | FileCheck %s --dump-input=fail
// Tests that the pass tracks control dependencies for reads/writes on the same
// resource.
// CHECK-LABEL: func @non_aliasing_reads_writes
func @non_aliasing_reads_writes(
// expected-remark@above {{ID: 13}}
// expected-remark@above {{Predecessors: {12}}}
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg2: tensor<32xf32>) -> (tensor<32xf32>) {
%graph = tf_executor.graph {
// expected-remark@above {{ID: 11}}
// expected-remark@above {{Predecessors: {10}}}
// expected-remark@above {{Successors: {12}}}
// CHECK: tf_executor.island
%island:2 = tf_executor.island {
// expected-remark@above {{ID: 9}}
// expected-remark@above {{Predecessors: {8}}}
// expected-remark@above {{Successors: {10}}}
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 0}}
// expected-remark@above {{Successors: {1}}}
"tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 1}}
// expected-remark@above {{Predecessors: {0}}}
// expected-remark@above {{Successors: {6}}}
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 2}}
// expected-remark@above {{Successors: {5}}}
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
// expected-remark@above {{ID: 3}}
%read2 = "tf.ReadVariableOp"(%var_handle) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 4}}
// expected-remark@above {{Successors: {8}}}
"tf.AssignVariableOp"(%arg1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 5}}
// expected-remark@above {{Predecessors: {2}}}
// expected-remark@above {{Successors: {8}}}
"tf.AssignVariableOp"(%arg0, %read2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 6}}
// expected-remark@above {{Predecessors: {1}}}
// expected-remark@above {{Successors: {7}}}
%read3 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 7}}
// expected-remark@above {{Predecessors: {6}}}
// expected-remark@above {{Successors: {8}}}
tf_executor.yield %read3 : tensor<32xf32>
// expected-remark@above {{ID: 8}}
// expected-remark@above {{Predecessors: {4,5,7}}}
// expected-remark@above {{Successors: {9}}}
}
tf_executor.fetch %island#0 : tensor<32xf32>
// expected-remark@above {{ID: 10}}
// expected-remark@above {{Predecessors: {9}}}
// expected-remark@above {{Successors: {11}}}
}
return %graph : tensor<32xf32>
// expected-remark@above {{ID: 12}}
// expected-remark@above {{Predecessors: {11}}}
// expected-remark@above {{Successors: {13}}}
}
// -----
// Tests that the pass tracks control dependencies for reads/writes on the two
// resource handles that refer to the same variable.
// CHECK-LABEL: func @aliasing_reads_writes
func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () {
// expected-remark@above {{ID: 14}}
// expected-remark@above {{Predecessors: {13}}}
tf_executor.graph {
// expected-remark@above {{ID: 12}}
// expected-remark@above {{Predecessors: {11}}}
// expected-remark@above {{Successors: {13}}}
// CHECK: tf_executor.island
%island = tf_executor.island {
// expected-remark@above {{ID: 10}}
// expected-remark@above {{Predecessors: {9}}}
// expected-remark@above {{Successors: {11}}}
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
// expected-remark@above {{ID: 0}}
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
// expected-remark@above {{ID: 1}}
%vh1_id:2 = "tf.IdentityN"(%vh1, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>)
// expected-remark@above {{ID: 2}}
%read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 3}}
// expected-remark@above {{Successors: {4}}}
"tf.AssignVariableOp"(%vh1_id#0, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 4}}
// expected-remark@above {{Predecessors: {3}}}
// expected-remark@above {{Successors: {5,6}}}
%read1 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 5}}
// expected-remark@above {{Predecessors: {4}}}
// expected-remark@above {{Successors: {7}}}
%read2 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 6}}
// expected-remark@above {{Predecessors: {4}}}
// expected-remark@above {{Successors: {7}}}
"tf.AssignVariableOp"(%vh0, %read2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 7}}
// expected-remark@above {{Predecessors: {5,6}}}
// expected-remark@above {{Successors: {8}}}
"tf.AssignVariableOp"(%vh1_id#0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 8}}
// expected-remark@above {{Predecessors: {7}}}
// expected-remark@above {{Successors: {9}}}
tf_executor.yield
// expected-remark@above {{ID: 9}}
// expected-remark@above {{Predecessors: {8}}}
// expected-remark@above {{Successors: {10}}}
}
tf_executor.fetch %island : !tf_executor.control
// expected-remark@above {{ID: 11}}
// expected-remark@above {{Predecessors: {10}}}
// expected-remark@above {{Successors: {12}}}
}
return
// expected-remark@above {{ID: 13}}
// expected-remark@above {{Predecessors: {12}}}
// expected-remark@above {{Successors: {14}}}
}
// -----
// Tests that the pass tracks control dependencies for side-effecting ops on
// unknown resources.
// CHECK-LABEL: func @unknown_side_effecting_op
func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
// expected-remark@above {{ID: 13}}
// expected-remark@above {{Predecessors: {12}}}
tf_executor.graph {
// expected-remark@above {{ID: 11}}
// expected-remark@above {{Predecessors: {10}}}
// expected-remark@above {{Successors: {12}}}
// CHECK: tf_executor.island
%island = tf_executor.island {
// expected-remark@above {{ID: 9}}
// expected-remark@above {{Predecessors: {8}}}
// expected-remark@above {{Successors: {10}}}
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
// expected-remark@above {{ID: 0}}
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
// expected-remark@above {{ID: 1}}
%read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 2}}
// expected-remark@above {{Successors: {4}}}
"tf.AssignVariableOp"(%vh1, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 3}}
// expected-remark@above {{Successors: {4}}}
"tf._UnknownSideEffectingOp_"() : () -> ()
// expected-remark@above {{ID: 4}}
// expected-remark@above {{Predecessors: {2,3}}}
// expected-remark@above {{Successors: {5,6}}}
%read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 5}}
// expected-remark@above {{Predecessors: {4}}}
// expected-remark@above {{Successors: {7}}}
"tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 6}}
// expected-remark@above {{Predecessors: {4}}}
// expected-remark@above {{Successors: {8}}}
"tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 7}}
// expected-remark@above {{Predecessors: {5}}}
// expected-remark@above {{Successors: {8}}}
tf_executor.yield
// expected-remark@above {{ID: 8}}
// expected-remark@above {{Predecessors: {6,7}}}
// expected-remark@above {{Successors: {9}}}
}
tf_executor.fetch %island : !tf_executor.control
// expected-remark@above {{ID: 10}}
// expected-remark@above {{Predecessors: {9}}}
// expected-remark@above {{Successors: {11}}}
}
return
// expected-remark@above {{ID: 12}}
// expected-remark@above {{Predecessors: {11}}}
// expected-remark@above {{Successors: {13}}}
}
// -----
// Tests that the pass tracks control dependencies for read-only ops on unknown
// resources.
// CHECK-LABEL: func @read_only_unknown_resource
func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () {
// expected-remark@above {{ID: 10}}
// expected-remark@above {{Predecessors: {9}}}
tf_executor.graph {
// expected-remark@above {{ID: 8}}
// expected-remark@above {{Predecessors: {7}}}
// expected-remark@above {{Successors: {9}}}
// CHECK: tf_executor.island
%island = tf_executor.island {
// expected-remark@above {{ID: 6}}
// expected-remark@above {{Predecessors: {5}}}
// expected-remark@above {{Successors: {7}}}
%vh0 = "tf._UnknownSideEffectingOp_"() : () -> tensor<*x!tf.resource<tensor<32xf32>>>
// expected-remark@above {{ID: 0}}
// expected-remark@above {{Successors: {2,3}}}
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
// expected-remark@above {{ID: 1}}
%read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 2}}
// expected-remark@above {{Predecessors: {0}}}
// expected-remark@above {{Successors: {4}}}
%read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 3}}
// expected-remark@above {{Predecessors: {0}}}
// expected-remark@above {{Successors: {4}}}
"tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 4}}
// expected-remark@above {{Predecessors: {2,3}}}
// expected-remark@above {{Successors: {5}}}
tf_executor.yield
// expected-remark@above {{ID: 5}}
// expected-remark@above {{Predecessors: {4}}}
// expected-remark@above {{Successors: {6}}}
}
tf_executor.fetch %island : !tf_executor.control
// expected-remark@above {{ID: 7}}
// expected-remark@above {{Predecessors: {6}}}
// expected-remark@above {{Successors: {8}}}
}
return
// expected-remark@above {{ID: 9}}
// expected-remark@above {{Predecessors: {8}}}
// expected-remark@above {{Successors: {10}}}
}

View File

@ -1650,3 +1650,27 @@ func @testSplitSmallSplitDim(%input: tensor<4x8xf32>) {
%0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x8xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
return
}
// -----
func @testTernaryEinsum(%arg0: tensor<2x3xf32>){
// expected-error @+1 {{supports at most two operands}}
%0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>)
return
}
// -----
func @testTopKV2WrongInputRank(%input: tensor<f32>, %k: tensor<i32>) {
// expected-error @+1 {{op requires input operand to have at least 1 dimension}}
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<f32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
return
}
// -----
func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) {
// expected-error @+1 {{op requires k operand to be 0D tensor}}
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>)
return
}

View File

@ -523,7 +523,7 @@ func @invalid_merge(%arg0: tensor<*x!tf.resource>, %arg1: tensor<4x!tf.resource>
// -----
// Check that if result is a ref type, all operands need to be ref too.
func @inavlid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> {
func @invalid_merge(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4xf32>) -> tensor<4x!tf.f32ref> {
%result = tf_executor.graph {
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<4x!tf.f32ref>, tensor<4xf32>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}

View File

@ -68,14 +68,17 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
namespace TF {
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
bool enable_logging) {
bool enable_logging,
bool enable_inliner) {
PassManager bridge(module.getContext());
// Add logger to bridge passmanager.
if (enable_logging)
bridge.addInstrumentation(std::make_unique<tensorflow::BridgeLogger>());
CreateTFStandardPipeline(bridge);
StandardPipelineOptions pipeline_options;
pipeline_options.enable_inliner.setValue(enable_inliner);
CreateTFStandardPipeline(bridge, pipeline_options);
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
LogicalResult result = bridge.run(module);
(void)result;

View File

@ -31,11 +31,13 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging);
namespace TF {
// Run all passes involved in transforming or optimizing an MLIR graph without
// Runs all passes involved in transforming or optimizing an MLIR graph without
// any target specialization. When enable_logging is true, enables
// tensorflow::BridgeLogger.
// tensorflow::BridgeLogger. When enable_inliner is true, enables the inliner
// pass.
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
bool enable_logging);
bool enable_logging,
bool enable_inliner);
} // namespace TF

View File

@ -29,10 +29,4 @@ mlir::PassPipelineRegistration<> tpu_pipeline(
"that it is suitable for targeting TPUs.",
mlir::TFTPU::CreateTPUBridge);
mlir::PassPipelineRegistration<> standard_pipeline(
"tf-standard-bridge",
"Run all passes involved in transforming or optimizing an MLIR graph"
"without any target specialization.",
mlir::TF::CreateTFStandardPipeline);
} // anonymous namespace

View File

@ -22,6 +22,9 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def SingleResultAndOperandHaveSameElementType : Constraint<
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
def SingleResultAndOperandHaveSameType : Constraint<
CPred<"$0->getType() == $1->getType()">>;
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;
//===----------------------------------------------------------------------===//
@ -75,8 +78,7 @@ def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
def CastSameType : Pat<(TF_CastOp:$res $arg, $truncate),
(replaceWithValue $arg),
[(SingleResultAndOperandHaveSameElementType $res,
$arg)]>;
[(SingleResultAndOperandHaveSameType $res, $arg)]>;
//===----------------------------------------------------------------------===//
// Conj op patterns.

View File

@ -45,7 +45,8 @@ struct TFOptimizePass : public FunctionPass<TFOptimizePass> {
} // namespace
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
void CreateTFStandardPipeline(OpPassManager &pm) {
void CreateTFStandardPipeline(OpPassManager &pm,
const StandardPipelineOptions &options) {
OpPassManager &func_pm = pm.nest<FuncOp>();
// First operates on the executor dialect:
@ -59,8 +60,12 @@ void CreateTFStandardPipeline(OpPassManager &pm) {
// Hopefully there is a single island left, or there wasn't any to begin with.
// We now run the optimizer which operates mostly inside islands.
func_pm.addPass(createCanonicalizerPass());
func_pm.addPass(CreateTFOptimizePass());
func_pm.addPass(createCSEPass());
if (options.enable_inliner) {
pm.addPass(createInlinerPass());
}
pm.addNestedPass<FuncOp>(CreateTFShapeInferencePass());
pm.addNestedPass<FuncOp>(CreateTFOptimizePass());
pm.addNestedPass<FuncOp>(createCSEPass());
}
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass() {
@ -70,7 +75,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass() {
static PassRegistration<TFOptimizePass> pass("tf-optimize", "Optimizes TF.");
// Registers a pipeline builder function for the default canonicalize/optimizer.
static mlir::PassPipelineRegistration<> pipeline(
static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
"tf-standard-pipeline",
"Run all the passes involved in transforming/optimizing the graph after "
"importing into MLIR, without any target specialization.",

View File

@ -46,10 +46,17 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateTFShapeInferencePass();
// Optimizes Tensorflow graph.
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass();
struct StandardPipelineOptions : public PassOptions<StandardPipelineOptions> {
Option<bool> enable_inliner{*this, "enable-inliner",
llvm::cl::desc("Enable inliner."),
llvm::cl::init(false)};
};
// Populates the pass manager with the passes involved in transforming or
// optimizing an MLIR graph without any target specialization.
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
void CreateTFStandardPipeline(OpPassManager& pm);
void CreateTFStandardPipeline(OpPassManager& pm,
const StandardPipelineOptions& options);
} // namespace TF
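
A hedged caller sketch for the new pipeline options (not part of this change; include path for the header above is assumed, and the shape mirrors RunBridgeWithStandardPipeline earlier in this change):

```c++
#include "mlir/IR/Module.h"         // TF:local_config_mlir
#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

// Sketch: populate a fresh pass manager with the standard pipeline, inliner
// enabled, and run it over `module`.
mlir::LogicalResult RunStandardPipelineWithInliner(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  mlir::TF::StandardPipelineOptions options;
  options.enable_inliner.setValue(true);
  mlir::TF::CreateTFStandardPipeline(pm, options);
  return pm.run(module);
}
```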
namespace TFControlFlow {

View File

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
namespace mlir {
namespace tf_executor {
namespace {
// A pass that adds "Predecessors" and "Successors" remarks for each op based on
// SideEffectAnalysis result. For testing purpose only.
struct TestSideEffectAnalysis
: public mlir::FunctionPass<TestSideEffectAnalysis> {
void runOnFunction() override {
int64_t next_id = 0;
llvm::SmallDenseMap<Operation*, int64_t, 8> ids;
getFunction().walk([&](Operation* op) {
ids[op] = next_id++;
op->emitRemark("ID: ") << ids[op];
});
auto join_ids = [&](const llvm::ArrayRef<Operation*> ops) {
llvm::SmallVector<std::string, 8> id_vec;
id_vec.reserve(ops.size());
for (auto op : ops) id_vec.push_back(std::to_string(ids[op]));
return llvm::join(id_vec, ",");
};
auto& analysis = getAnalysis<TF::SideEffectAnalysis>();
getFunction().walk([&](Operation* op) {
if (!analysis.DirectControlPredecessors(op).empty()) {
op->emitRemark("Predecessors: ")
<< "{" << join_ids(analysis.DirectControlPredecessors(op)) << "}";
}
if (!analysis.DirectControlSuccessors(op).empty()) {
op->emitRemark("Successors: ")
<< "{" << join_ids(analysis.DirectControlSuccessors(op)) << "}";
}
});
}
};
static mlir::PassRegistration<TestSideEffectAnalysis> pass(
"tf-test-side-effect-analysis",
"Add remarks based on side-effect analysis result, for testing purpose.");
} // anonymous namespace
} // namespace tf_executor
} // namespace mlir

View File

@ -141,7 +141,7 @@ static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
// NOLINTNEXTLINE
static llvm::cl::list<std::string> cl_pass_list(
"graph-passes", llvm::cl::value_desc("list"),
llvm::cl::desc("comma seprarated list of GraphOptimizationPass to run."),
llvm::cl::desc("comma separated list of GraphOptimizationPass to run."),
llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory));
class GraphOptByNamePass : public GraphOptPass {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
@ -75,6 +76,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
@ -500,7 +502,8 @@ Status ImporterBase::GetInputOutputNodes(
TF_RETURN_IF_ERROR(add_node(input.first));
}
for (const auto& output_node_name : specs_.output_arrays) {
for (const auto& output : specs_.outputs) {
auto output_node_name = std::string(ParseTensorName(output).first);
TF_RETURN_IF_ERROR(add_node(output_node_name));
}
@ -535,7 +538,7 @@ Status ImporterBase::AddNodesToShapeRefiner() {
auto node_name = node->op_def().name();
if (node_name != "Placeholder" && node_name != "LegacyFedInput" &&
node_name != FunctionLibraryDefinition::kArgOp) {
// We do not handle the case where the input node has multple outputs
// We do not handle the case where the input node has multiple outputs
if (node->num_outputs() > 1) {
return errors::FailedPrecondition(absl::StrCat(
"Input arrays can only have op with single output. Node op:",
@ -1588,7 +1591,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
if (specs.graph_as_function) {
if (specs.prune_unused_nodes || !specs.inputs.empty() ||
!specs.output_arrays.empty() || !specs.output_arrays_order.empty())
!specs.outputs.empty())
return errors::InvalidArgument(
"Pruning of graph is currently unsupported when the main graph is "
"converted to a function.");
@ -1622,7 +1625,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
// TODO(prakalps): Refactor to keep attribute strings (tf.entry_function,
// tf.versions) shared by importer and exporter in a centralized place.
// Record the input and output mapping.
if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
if (!specs.inputs.empty() || !specs.outputs.empty()) {
mlir::Builder b(context);
std::string s;
llvm::raw_string_ostream ss(s);
@ -1632,7 +1635,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
",");
auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str()));
s.clear();
mlir::interleave(specs.output_arrays_order, ss, ",");
mlir::interleave(specs.outputs, ss, ",");
auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str()));
attrs.push_back(b.getNamedAttr("tf.entry_function",
@ -1665,9 +1668,13 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
// Finds out all the input nodes and output nodes.
if (!specs.inputs.empty() || !specs.output_arrays.empty()) {
absl::flat_hash_set<absl::string_view> output_node_names;
for (const auto& output_tensor : specs.outputs) {
output_node_names.insert(ParseTensorName(output_tensor).node());
}
if (!specs.inputs.empty() || !specs.outputs.empty()) {
arg_nodes->resize(specs.inputs.size());
ret_nodes->resize(specs.output_arrays_order.size());
ret_nodes->resize(specs.outputs.size());
for (Node* n : GetOrderedNodes()) {
// Handle inputs/arguments.
@ -1677,17 +1684,17 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
}
// Handle outputs/returns.
if (specs.output_arrays.find(n->name()) != specs.output_arrays.end()) {
for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
if (output_node_names.contains(n->name())) {
for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
std::pair<std::string, std::string> name_and_port =
absl::StrSplit(specs.output_arrays_order[i], ':');
absl::StrSplit(specs.outputs[i], ':');
auto name = name_and_port.first;
if (name != n->name()) continue;
int port = 0;
if (!name_and_port.second.empty() &&
!absl::SimpleAtoi(name_and_port.second, &port)) {
return errors::InvalidArgument("Invalid port specification: ",
specs.output_arrays_order[i]);
specs.outputs[i]);
}
(*ret_nodes)[i] = {n, port};
}
@ -1726,10 +1733,10 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
}
llvm::SmallVector<mlir::Type, 4> ret_types;
ret_types.reserve(specs.output_arrays.size());
for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
ret_types.reserve(specs.outputs.size());
for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
if (ret_nodes->at(i).node == nullptr) {
return errors::InvalidArgument("Output ", specs.output_arrays_order[i],
return errors::InvalidArgument("Output ", specs.outputs[i],
" was not found in graph");
}
}

View File

@ -33,19 +33,16 @@ limitations under the License.
namespace tensorflow {
Status ParseOutputArrayInfo(absl::string_view array_names,
absl::flat_hash_set<string>* array,
std::vector<string>* order) {
std::vector<string>* outputs) {
std::vector<string> output_names = absl::StrSplit(array_names, ',');
return ParseOutputArrayInfo(output_names, array, order);
return ParseOutputArrayInfo(output_names, outputs);
}
Status ParseOutputArrayInfo(const std::vector<string>& output_names,
absl::flat_hash_set<string>* array,
std::vector<string>* order) {
std::vector<string>* outputs) {
for (auto& output_name : output_names) {
if (output_name.empty()) continue;
array->insert(string(*absl::StrSplit(output_name, ':').begin()));
order->push_back(output_name);
outputs->push_back(output_name);
}
return Status::OK();
}
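
A small hedged example of the new single-vector form (fragment, assuming the surrounding `tensorflow` namespace; the node names are made up):

```c++
// "name" and "name:index" entries are both kept verbatim; empty entries are
// skipped.
std::vector<string> outputs;
Status s = ParseOutputArrayInfo("logits:0,probabilities", &outputs);
// On success, outputs == {"logits:0", "probabilities"}.
```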

View File

@ -40,11 +40,9 @@ struct GraphImportConfig {
llvm::MapVector<string, ArrayInfo, llvm::StringMap<unsigned>>;
// Maps input node names to node data types and shapes.
InputArrays inputs;
// Output node names.
absl::flat_hash_set<string> output_arrays;
// nodes:index strings for the output as specified on the command line.
std::vector<string> output_arrays_order;
// setting prune_unused_nodes to true, would prune unreachable nodes if
// name:index strings for the output as specified on the command line.
std::vector<string> outputs;
// Setting prune_unused_nodes to true would prune unreachable nodes if
// outputs are specified.
bool prune_unused_nodes = false;
// If true, inputs of type LegacyFedInput are replaced with Placeholder ops.
@ -73,12 +71,10 @@ struct GraphExportConfig {
// Parses the command line flag strings to the specification of nodes in
// the Graph.
Status ParseOutputArrayInfo(absl::string_view array_names,
absl::flat_hash_set<string>* array,
std::vector<string>* order);
std::vector<string>* outputs);
Status ParseOutputArrayInfo(const std::vector<string>& output_names,
absl::flat_hash_set<string>* array,
std::vector<string>* order);
std::vector<string>* outputs);
// Parses the command line flag strings to the specification of nodes in
// the Graph. `data_types` input string can be empty since the flag is optional.

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
@ -63,16 +64,18 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
specs.upgrade_legacy = upgrade_legacy;
TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
input_shapes, &specs.inputs));
TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.output_arrays,
&specs.output_arrays_order));
TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs));
// TODO(b/142828368): Pruning should not be needed when TF import
// supports importing graphs w/ unregistered ops natively.
GraphDef pruned_graph_def;
if (specs.prune_unused_nodes) {
std::vector<string> terminal_nodes(specs.output_arrays.begin(),
specs.output_arrays.end());
for (const auto entry : specs.inputs) {
terminal_nodes.push_back(entry.first);
std::vector<std::string> terminal_nodes;
terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size());
for (const auto& output : specs.outputs) {
terminal_nodes.push_back(std::string(ParseTensorName(output).node()));
}
for (const auto& input : specs.inputs) {
terminal_nodes.push_back(input.first);
}
TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
graphdef, &pruned_graph_def, terminal_nodes));

View File

@ -36,7 +36,7 @@ xla::StatusOr<xla::Shape> TestShapeRepresentation(const TensorShape& shape,
return xla_shape;
}
TEST(CompileSerializedMlirToXlaHloTest, InvalidSerliazedMlirModule) {
TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
string invalid_mlir_module = "totally @invalid MLIR module {here} <-";
std::vector<TensorShape> arg_shapes;
XlaCompiler::CompilationResult compilation_result;
@ -101,7 +101,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
xla::ShapeUtil::MakeTupleShape({output_shape});
EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);
// Expect exactly 1 OutputDescrpition.
// Expect exactly 1 OutputDescription.
EXPECT_EQ(compilation_result.outputs.size(), 1);
const XlaCompiler::OutputDescription& output_desc =
compilation_result.outputs.front();

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Support/ToolUtilities.h" // TF:local_config_mlir
#include "mlir/Support/TranslateClParser.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
@ -40,6 +41,13 @@ static llvm::cl::opt<std::string> output_filename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each chunk "
"independently"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> import_saved_model(
"savedmodel-to-mlir",
@ -85,13 +93,12 @@ int main(int argc, char** argv) {
return 1;
}
mlir::MLIRContext context;
if (import_saved_model) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
mlir::MLIRContext context;
auto module = tensorflow::SavedModelToMlirImport(
input_filename, tags, absl::Span<std::string>(exported_names),
@ -107,12 +114,23 @@ int main(int argc, char** argv) {
return 1;
}
llvm::SourceMgr source_mgr;
source_mgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
mlir::SourceMgrDiagnosticHandler diagnostic_handler(source_mgr, &context);
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
llvm::raw_ostream& os) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
mlir::MLIRContext context;
mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context);
return (*requested_translation)(sourceMgr, os, &context);
};
if (failed((*requested_translation)(source_mgr, output->os(), &context)))
return 1;
if (splitInputFile) {
if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer,
output->os())))
return 1;
} else {
if (failed(processBuffer(std::move(input), output->os()))) return 1;
}
}
output->keep();

View File

@ -404,6 +404,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/service:hlo",
"@llvm//:support",
"@local_config_mlir//:Analysis",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
namespace mlir {
namespace xla {
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
@ -82,3 +83,4 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
}
} // namespace xla
} // namespace mlir

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
namespace mlir {
namespace xla {
// Converts the given elements attr to the specified elements type.
@ -27,5 +28,6 @@ namespace xla {
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type);
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_

View File

@ -47,13 +47,11 @@ limitations under the License.
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
namespace mlir {
#include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::xla_hlo;
namespace xla_hlo {
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
Attribute value, Type type,
@ -160,7 +158,7 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) {
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
value.isa<IntegerAttr>()) {
// All XLA types must be tensor types. In the build() method, we want to
// provide more flexiblity by allowing attributes of scalar types. But we
// provide more flexibility by allowing attributes of scalar types. But we
// need to wrap it up with ElementsAttr to construct valid XLA constants.
type = RankedTensorType::get(/*shape=*/{}, value.getType());
value = DenseElementsAttr::get(type.cast<TensorType>(), value);
@ -212,9 +210,9 @@ void AbsOp::build(Builder* builder, OperationState& result, Value* operand) {
new_type = operand->getType();
} else if (shaped_type.hasRank()) {
new_type =
mlir::RankedTensorType::get(shaped_type.getShape(), operand->getType());
RankedTensorType::get(shaped_type.getShape(), operand->getType());
} else {
new_type = mlir::UnrankedTensorType::get(operand->getType());
new_type = UnrankedTensorType::get(operand->getType());
}
return AbsOp::build(builder, result, new_type, operand);
@ -241,8 +239,8 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
// If the operand is constant, we can do the conversion now.
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
return ::xla::ConvertElementsAttr(elementsAttr,
getElementTypeOrSelf(getResult()));
return xla::ConvertElementsAttr(elementsAttr,
getElementTypeOrSelf(getResult()));
}
return {};
@ -436,7 +434,7 @@ static LogicalResult Verify(ClampOp op) {
void ComplexOp::build(Builder* builder, OperationState& state, Value* lhs,
Value* rhs) {
auto type = lhs->getType();
auto element_ty = mlir::ComplexType::get(getElementTypeOrSelf(type));
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
Type result_ty;
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
@ -843,6 +841,70 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
return RankedTensorType::get(shape, ranked_ty.getElementType());
}
//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
void SortOp::build(Builder* builder, OperationState& state,
ArrayRef<Value*> operands, int64_t dimension,
bool is_stable) {
state.addOperands(operands);
state.addAttribute("dimension", builder->getI64IntegerAttr(dimension));
state.addAttribute("is_stable", builder->getBoolAttr(dimension));
SmallVector<Type, 2> element_types;
element_types.reserve(operands.size());
for (Value* operand : operands) element_types.push_back(operand->getType());
state.addTypes(builder->getTupleType(element_types));
state.addRegion();
}
static LogicalResult Verify(SortOp op) {
Operation::operand_range operands = op.operands();
if (operands.empty()) return op.emitOpError("requires at least one input");
// TODO(antiagainst): verify partially dynamic shapes
if (llvm::all_of(operands, [](Value* operand) {
return operand->getType().cast<ShapedType>().hasRank();
})) {
ArrayRef<int64_t> input_shape =
(*operands.begin())->getType().cast<ShapedType>().getShape();
if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value* operand) {
return operand->getType().cast<ShapedType>().getShape() !=
input_shape;
}))
return op.emitOpError("requires all inputs to have the same dimensions");
if (op.dimension().getSExtValue() >= input_shape.size())
return op.emitOpError(
"dimension attribute value must be less than input rank");
}
Block& block = op.comparator().front();
size_t num_operands = op.getOperation()->getNumOperands();
if (block.getNumArguments() != 2 * num_operands)
return op.emitOpError("comparator block should have ")
<< 2 * num_operands << " arguments";
for (auto indexed_operand : llvm::enumerate(operands)) {
int index = indexed_operand.index();
Type element_type =
indexed_operand.value()->getType().cast<ShapedType>().getElementType();
Type tensor_type = RankedTensorType::get({}, element_type);
for (int i : {2 * index, 2 * index + 1}) {
Type arg_type = block.getArgument(i)->getType();
if (arg_type != tensor_type)
return op.emitOpError("comparator block argument #")
<< i << " should be of type " << tensor_type << " but got "
<< arg_type;
}
}
return success();
}
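
SortOp::build above leaves the comparator region empty, so a caller has to populate it to satisfy this verifier. A hedged sketch, assuming it lives inside namespace mlir::xla_hlo and that `keys` is a ranked tensor value (`BuildStableSort` is a hypothetical helper):

```c++
// Sketch: create a stable 1-D sort and give its comparator the 2 * N rank-0
// tensor block arguments the verifier requires (N == 1 operand here).
SortOp BuildStableSort(OpBuilder& b, Location loc, Value* keys) {
  SmallVector<Value*, 1> operands{keys};
  auto sort = b.create<SortOp>(loc, operands, /*dimension=*/0,
                               /*is_stable=*/true);
  Block* block = new Block();
  sort.comparator().push_back(block);
  Type scalar = RankedTensorType::get(
      {}, keys->getType().cast<ShapedType>().getElementType());
  block->addArgument(scalar);  // lhs element
  block->addArgument(scalar);  // rhs element
  // A real comparator would now emit a scalar compare and a terminator.
  return sort;
}
```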
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
@ -938,6 +1000,15 @@ void TupleOp::build(Builder* builder, OperationState& result,
build(builder, result, builder->getTupleType(types), values);
}
//===----------------------------------------------------------------------===//
// UnaryEinsumOp
//===----------------------------------------------------------------------===//
void UnaryEinsumOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<UnaryEinsumToEinsum>(context);
}
//===----------------------------------------------------------------------===//
// CompareOp
//===----------------------------------------------------------------------===//
@ -990,3 +1061,6 @@ XlaHloDialect::XlaHloDialect(MLIRContext* context)
// Support unknown operations because not all XLA operations are registered.
// allowUnknownOperations();
}
} // namespace xla_hlo
} // namespace mlir

View File

@ -437,9 +437,6 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO
let builders = [OpBuilder<
"Builder *builder, OperationState &results, "
"Value* value, int32_t index">];
// GetTupleElementOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}
def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
@ -730,6 +727,43 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral
let results = (outs HLO_Tensor);
}
def BASE_EinsumOp {
string summary = "Einsum operator";
string description = [{
Returns a tensor whose elements are defined by equation, which is written
in a shorthand form inspired by the Einstein summation convention.
}];
}
def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]> {
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
StrAttr:$einsum_config
);
let results = (outs HLO_Tensor);
// TODO(hinsu): Canonicalize to lower this client side HLO op to server
// side HLO ops.
}
def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]> {
let arguments = (ins
HLO_Tensor:$operand,
StrAttr:$einsum_config
);
let results = (outs HLO_Tensor);
let hasCanonicalizer = 1;
// UnaryEinsumOp is unconditionally canonicalized to the binary EinsumOp so
// the HLO converter shouldn't be invoked.
let hasCustomHLOConverter = 1;
}
def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
let arguments = (ins
HLO_Tensor:$operand,
@ -834,6 +868,26 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
let hasCustomHLOConverter = 1;
}
def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp {
let arguments = (ins
Variadic<HLO_Tensor>:$operands,
DefaultValuedAttr<I64Attr, "-1">:$dimension,
DefaultValuedAttr<BoolAttr, "false">:$is_stable
);
let results = (outs HLO_TensorOrTuple);
let regions = (region SizedRegion<1>:$comparator);
let builders = [OpBuilder<
"Builder *builder, OperationState &state, ArrayRef<Value *> operands, "
"int64_t dimension, bool is_stable"
>];
// TODO(b/129422361): SortOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}
def HLO_ReverseOp: HLO_Op<"reverse",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
let arguments = (ins

View File

@ -708,7 +708,7 @@ class BASE_HLO_ClampOp {
}
class BASE_HLO_ConcatenateOp {
string summary = "XLA's concantenate op";
string summary = "XLA's concatenate op";
string description = [{
Concatenates a set of tensors along the specified dimension.
@ -832,6 +832,17 @@ class BASE_HLO_SelectAndScatterOp {
}];
}
class BASE_HLO_SortOp {
string summary = "Sort operator";
string description = [{
Sorts the given `operands` at the given `dimension` with the given
`comparator`.
See https://www.tensorflow.org/xla/operation_semantics#sort.
}];
}
class BASE_HLO_ReverseOp {
string summary = "Reverse operator";

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <numeric>
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
namespace mlir {
namespace xla {
@ -51,5 +53,18 @@ DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x,
return DenseIntElementsAttr::get(type, broadcastDimensions);
}
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
DenseElementsAttr attr;
if (auto float_ty = ty.dyn_cast<FloatType>()) {
APFloat value(float_ty.getFloatSemantics(), raw_value);
return DenseElementsAttr::get(scalar_ty, value);
}
auto int_ty = ty.cast<IntegerType>();
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
return DenseElementsAttr::get(scalar_ty, value);
}
} // namespace xla
} // namespace mlir

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
@ -48,6 +49,12 @@ static ElementsAttr getSplat(Builder* b, Value* val, T constant) {
return DenseElementsAttr::get(valType, elementAttr);
}
// Returns a DenseElementsAttr of rank zero with the given element type and
// value.
// Requires `ty` to be either FloatType or IntegerType.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
} // namespace xla
} // namespace mlir
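
A hedged illustration of the helper's contract (fragment; assumes a `Builder b` is in scope):

```c++
// Both calls produce rank-0 splats wrapped in tensor types.
DenseElementsAttr one = xla::GetScalarOfType(b.getF32Type(), 1);
// -> dense<1.000000e+00> : tensor<f32>
DenseElementsAttr neg = xla::GetScalarOfType(b.getIntegerType(32), -1);
// -> dense<-1> : tensor<i32>
```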

View File

@ -18,20 +18,23 @@ limitations under the License.
#ifndef HLO_UTILS
#define HLO_UTILS
#ifndef OP_BASE
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
class ConstantSplat<string value> : NativeCodeCall<
"getSplat(&$_builder, $0, " # value # ")">;
"xla::getSplat(&$_builder, $0, " # value # ")">;
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
def BinBroadcastDimensions : NativeCodeCall<
"getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
// Here, the element type can be any integer or float type, but note that only
// 32-bit integer values are supported.
class GetScalarOfType<int value> : NativeCodeCall<
"xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
#endif // HLO_UTILS

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
@ -77,6 +78,10 @@ static double ConvertAPFloat(llvm::APFloat value) {
return value.convertToDouble();
}
static absl::string_view ConvertStringRef(mlir::StringRef value) {
return {value.data(), value.size()};
}
static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
auto values = attr.getValues<int64>();
return {values.begin(), values.end()};
@ -494,13 +499,6 @@ LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) {
return failure();
}
LogicalResult ExportXlaOp(GetTupleElementOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
value_map[op] = xla::GetTupleElement(value_map[op.getOperand()],
op.index().getSExtValue());
return success();
}
LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
@ -626,12 +624,30 @@ LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
return failure();
}
LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
xla::XlaComputation comparator;
if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
&comparator)))
return failure();
auto& value_map = *ctx.values;
value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator,
op.dimension().getSExtValue(), op.is_stable());
return success();
}
LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx));
return success();
}
LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
// This failure is intentional: UnaryEinsumOp is always lowered to the binary
// EinsumOp before it reaches the exporter.
return failure();
}
LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
xla::XlaComputation condition;
xla::XlaComputation body;
@ -773,7 +789,7 @@ LogicalResult ConvertToHloModule::LowerFunctionCall(
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
if (lowered_computation_.count(f)) return success();
if (f.getBlocks().size() != 1) {
return f.emitError("only single block Function suppored");
return f.emitError("only single block Function supported");
}
// Create a sub-builder if this is not the main function.

View File

@ -32,17 +32,20 @@ using llvm::raw_ostream;
using llvm::RecordKeeper;
using llvm::StringRef;
using mlir::interleaveComma;
using mlir::tblgen::Attribute;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::Operator;
static std::string GetDefaultAttrExport(
const mlir::tblgen::NamedAttribute& named_attr) {
auto storage_type = named_attr.attr.getStorageType();
Attribute attr = named_attr.attr;
StringRef storage_type = attr.getStorageType();
// For some attribute types we have a general conversion, so use that.
if (storage_type.endswith("IntegerAttr") ||
storage_type.endswith("FloatAttr")) {
return "Convert" + named_attr.attr.getReturnType().str();
if (!attr.isEnumAttr() && (storage_type.endswith("IntegerAttr") ||
storage_type.endswith("FloatAttr") ||
storage_type.endswith("StringAttr"))) {
return "Convert" + attr.getReturnType().str();
}
return "Convert_" + named_attr.name.str();
}
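To make the naming convention concrete: for a FloatAttr-backed attribute whose return type is APFloat this yields ConvertAPFloat, and for a StringAttr-backed one it yields ConvertStringRef, both defined in the exporter above; anything else, including enum attributes, falls back to a hand-written Convert_<attr_name> helper. Illustrative mapping (the attribute kinds are assumptions):

// F32Attr   : storage FloatAttr,  return APFloat   -> ConvertAPFloat
// StrAttr   : storage StringAttr, return StringRef -> ConvertStringRef
// enum attr : no general conversion                -> Convert_<attr_name>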

View File

@ -48,3 +48,11 @@ func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f
// CHECK: return %arg0
return %2 : tensor<4xcomplex<f32>>
}
// CHECK-LABEL: @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}

View File

@ -180,6 +180,27 @@ func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
return %0: tensor<?xi1>
}
// CHECK-LABEL: func @bitwise_or
func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: xla_hlo.or
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}
// CHECK-LABEL: func @bitwise_or_broadcast
func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
// CHECK-NEXT: xla_hlo.or
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
return %0: tensor<1x4xi8>
}
// CHECK-LABEL: func @bitwise_or_dynamic
func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
// CHECK-NEXT: xla_hlo.or
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0: tensor<?xi32>
}
// CHECK-LABEL: func @pow
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-NEXT: xla_hlo.pow
@ -194,6 +215,20 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
return %0: tensor<?xf32>
}
// CHECK-LABEL: func @einsum
func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
// CHECK: xla_hlo.einsum
%0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
return %0: tensor<2x4xf32>
}
// CHECK-LABEL: func @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK: xla_hlo.unary_einsum
%0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
return %0: tensor<2x2xf32>
}
// CHECK-LABEL: func @floordiv_broadcast_i32
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
// CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0>
@ -589,6 +624,17 @@ func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
return %1 : tensor<6x9xf32>
}
// CHECK-LABEL: func @padv2_i32_paddings
func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
%padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32>
// CHECK: "xla_hlo.pad"(%arg0, %arg1) {
// CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
// CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
// CHECK-SAME: interior_padding = dense<0> : tensor<2xi64>
%1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<6x9xf32>
return %1 : tensor<6x9xf32>
}
//===----------------------------------------------------------------------===//
// Identity op legalizations.
//===----------------------------------------------------------------------===//
@ -1888,3 +1934,42 @@ func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf
// CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
}
//===----------------------------------------------------------------------===//
// tf.TopKV2 legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: topk_v2_non_const_k
func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
// CHECK: tf.TopKV2
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
}
// CHECK-LABEL: topk_v2_unknown_input_last_dim
func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
// CHECK: tf.TopKV2
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
}
// CHECK-LABEL: topk_v2
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64}
// CHECK-NEXT: %[[SORT:.*]] = "xla_hlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
// CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
// CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"}
// CHECK-NEXT: "xla_hlo.return"(%[[CMP]])
// CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
// CHECK-NEXT: %[[TUPL0:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32}
// CHECK-NEXT: %[[TUPL1:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32}
// CHECK-NEXT: %[[VAL:.*]] = "xla_hlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-NEXT: %[[IDX:.*]] = "xla_hlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-NEXT: return %[[VAL]], %[[IDX]]
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
}
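The CHECK lines above spell out the lowering: pair the input with an iota of indices, stable-sort both in descending order along the last dimension, then slice off the first k values and indices. A standalone C++ sketch of the same computation for a single row, illustrative only and not tied to the TensorFlow sources:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // One row of the input; the lowering handles all rows at once.
  std::vector<float> row = {0.3f, 0.9f, 0.1f, 0.9f};
  // The xla_hlo.iota operand: indices 0..n-1.
  std::vector<int> idx(row.size());
  std::iota(idx.begin(), idx.end(), 0);
  // xla_hlo.sort with a "GT" comparator and is_stable = true: descending
  // order, with ties keeping the lower index first.
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int a, int b) { return row[a] > row[b]; });
  // The two xla_hlo.slice ops: keep the first k values and indices.
  const int k = 2;
  for (int i = 0; i < k; ++i)
    std::printf("value=%.1f index=%d\n", row[idx[i]], idx[i]);
  return 0;
}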

View File

@ -1,7 +1,7 @@
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s
#map0 = (d0, d1) -> (d0, d1)
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], n_loop_types = [2, 0, 0], n_views = [2, 1]}
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"], n_views = [2, 1]}
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
%temp_result = alloc() {temp = true} : memref<2x2xf32>

View File

@ -416,3 +416,98 @@ func @constants() -> () {
%3 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor<i32>} : () -> (tensor<*xi32>)
return
}
// -----
func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// CHECK: xla_hlo.sort
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_no_operands() {
// expected-error @+1 {{op requires at least one input}}
%0 = "xla_hlo.sort"() ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
%7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : () -> tuple<>
return
}
// -----
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}}
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op requires all inputs to have the same dimensions}}
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op dimension attribute value must be less than input rank}}
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op comparator block should have 4 arguments}}
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}}
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
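Taken together, these negative cases pin down the sort verifier's contract. A plain-C++ model of the checks, written here only to summarize the rules (shapes as vectors, the comparator reduced to its argument count; this is not the MLIR implementation):

#include <cstddef>
#include <cstdint>
#include <vector>

// Models the xla_hlo.sort verifier rules exercised above. Unranked inputs
// (tensor<*x...>) would skip the shape comparison, as the passing
// @sort_unknown_rank case shows.
bool VerifySort(const std::vector<std::vector<int64_t>>& input_shapes,
                std::size_t comparator_args, int64_t dimension) {
  // "op requires at least one input"
  if (input_shapes.empty()) return false;
  // "op requires all inputs to have the same dimensions"
  for (const auto& shape : input_shapes)
    if (shape != input_shapes.front()) return false;
  // "op dimension attribute value must be less than input rank"
  if (dimension < 0 ||
      dimension >= static_cast<int64_t>(input_shapes.front().size()))
    return false;
  // "op comparator block should have 4 arguments" (two per input)
  return comparator_args == 2 * input_shapes.size();
}

int main() {
  // Mirrors @sort_dim_out_of_range: two 16x16 inputs, dimension = 10.
  bool ok = VerifySort({{16, 16}, {16, 16}}, /*comparator_args=*/4,
                       /*dimension=*/10);
  return ok ? 1 : 0;  // The call is expected to reject, so this returns 0.
}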

View File

@ -1,26 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = "xla_hlo.all_reduce"(%arg0) ({
// Perform max reduction inside the region
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.max %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
})
{
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
channel_id = {
handle = 5 : i64,
type = 2 : i64
}
} : (tensor<10xf32>) -> tensor<10xf32>
return %0 : tensor<10xf32>
}
// CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
// CHECK-SAME: channel_id=5
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
// CHECK-SAME: to_apply=%[[COMPUTATION]]

View File

@ -1,15 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
%0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
}
// CHECK-LABEL: ENTRY
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
// CHECK: [[VAL_4:%.*]] = f32[2] parameter(3)
// CHECK: [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4)
// CHECK-LABEL: ROOT
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0

View File

@ -1,13 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
%0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
}
// CHECK-LABEL: ENTRY
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
// CHECK-LABEL: ROOT
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3

View File

@ -1,23 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
module {
func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
// CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
// CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
// CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
// CHECK-LABEL: ROOT
// CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
}
}

View File

@ -1,26 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
// CHECK-LABEL: ENTRY %main.13 (Arg_0.1: s32[1,4], Arg_1.2: s32[2,4], Arg_2.3: s32[2,3,4]) -> s32[2,3,4] {
func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
// Same rank degenerate broadcast
// CHECK-NEXT: %Arg_0.1 = s32[1,4] parameter(0)
// CHECK-NEXT: %reshape.4 = s32[4] reshape(s32[1,4] %Arg_0.1)
// CHECK-NEXT: %broadcast.5 = s32[2,4] broadcast(s32[4] %reshape.4)
// CHECK-NEXT: %Arg_1.2 = s32[2,4] parameter(1)
// CHECK-NEXT: %add.6 = s32[2,4] add(s32[2,4] %broadcast.5, s32[2,4] %Arg_1.2)
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
// Broadcast up rank
// CHECK-NEXT: %broadcast.7 = s32[2,3,4] broadcast(s32[2,4] %Arg_1.2), dimensions={0,2}
// CHECK-NEXT: %Arg_2.3 = s32[2,3,4] parameter(2)
// CHECK-NEXT: %add.8 = s32[2,3,4] add(s32[2,3,4] %broadcast.7, s32[2,3,4] %Arg_2.3)
%1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
// Broadcast up rank + degenerate broadcast
// CHECK-NEXT: %broadcast.9 = s32[2,1,4] broadcast(s32[1,4] %Arg_0.1), dimensions={1,2}
// CHECK-NEXT: %reshape.10 = s32[2,4] reshape(s32[2,1,4] %broadcast.9)
// CHECK-NEXT: %broadcast.11 = s32[2,3,4] broadcast(s32[2,4] %reshape.10), dimensions={0,2}
// CHECK-NEXT: ROOT %add.12 = s32[2,3,4] add(s32[2,3,4] %broadcast.11, s32[2,3,4] %Arg_2.3)
%2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
return %2 : tensor<2x3x4xi32>
}

View File

@ -1,9 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
// CHECK-LABEL: ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[1,2,3,4] {
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
// CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
// CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
return %0 : tensor<1x2x3x4xi32>
}

View File

@ -1,12 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
%result = "xla_hlo.broadcast_in_dim"(%arg0) {
broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<1xf32>) -> tensor<1x10xf32>
return %result : tensor<1x10xf32>
}
// CHECK: ENTRY %main.3 ([[ARG0:.*]]: f32[1]) -> f32[1,10] {
// CHECK: %[[ARG0]] = f32[1] parameter(0)
// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] %[[ARG0]]), dimensions={0}

View File

@ -1,30 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
%0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %1 : tensor<4xi32>
}
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK: [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s32[4], [[ARG_2:.*]]: s32[4]) -> s32[4] {
// CHECK: %[[ARG_1]] = s32[4] parameter(0)
// CHECK: %[[ARG_2]] = s32[4] parameter(1)
// CHECK-LABEL: ROOT
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]])
// CHECK: [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] {
// CHECK: %[[ARG_3]] = s32[4] parameter(0)
// CHECK: %[[ARG_4]] = s32[4] parameter(1)
// CHECK-LABEL: ROOT
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]])
// CHECK: ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] {
// CHECK: %[[ARG]] = s32[4] parameter(0)
// CHECK: [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]]
// CHECK-LABEL: ROOT
// CHECK-SAME: s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]]

View File

@ -1,24 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
%0:2 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>)
return %0#0, %0#1 : tensor<4xi32>, tensor<4xi32>
}
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0, %1 : tensor<4xi32>, tensor<4xi32>
}
// Get name of callee computation
// CHECK: [[CALLEE:%.*]] ({{.*}}) -> ({{.*}}) {
// CHECK-LABEL: ENTRY
// CHECK-SAME: [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) {
// CHECK: %[[ARG]] = s32[4] parameter(0)
// CHECK: [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]]
// CHECK: [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0
// CHECK: [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1
// CHECK-LABEL: ROOT
// CHECK-SAME: (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]])

View File

@ -1,17 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0 : tensor<5x2xf32>,
%arg1 : tensor<5x5xf32>,
%arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
%result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) {
dimension = 1 : i64
} : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
return %result : tensor<5x14xf32>
}
// CHECK-LABEL: main
// CHECK: %[[ARG0:.*]] = f32[5,2] parameter(0)
// CHECK: %[[ARG1:.*]] = f32[5,5] parameter(1)
// CHECK: %[[ARG2:.*]] = f32[5,7] parameter(2)
// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1}

View File

@ -42,13 +42,13 @@ func @main(%arg0: tensor<f32>) -> tuple<tensor<f32>> {
// CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]]
%2 = "xla_hlo.conditional"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<f32>>): // no predecessors
^bb0(%arg1: tuple<tensor<f32>>):
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "xla_hlo.log"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<f32>>): // no predecessors
^bb0(%arg1: tuple<tensor<f32>>):
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "xla_hlo.exp"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>

View File

@ -1,30 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: ENTRY %main
func @main() -> tensor<2x2x1x1xf32> {
// CHECK: constant.{{.*}} = s64[] constant(1)
%cst = constant dense<1> : tensor<i64>
// CHECK: constant.{{.*}} = f32[2,2,1,1]
// CHECK-SAME: { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } }
%cst_0 = constant dense<
[[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]]
> : tensor<2x2x1x1xf32>
// CHECK: s32[1] constant({1})
%cst_1 = constant dense<1> : tensor<1xi32>
// CHECK: %[[C:.*]] = s32[] constant(1)
// CHECK: s32[10] broadcast(s32[] %[[C]])
%cst_2 = constant dense<1> : tensor<10xi32>
// CHECK: s32[4] constant({1, 2, 3, 4})
%cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: s32[2,2] constant({ { 1, 2 }, { 3, 4 } })
%cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
// CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
%cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
return %cst_0 : tensor<2x2x1x1xf32>
}

View File

@ -1,31 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
%result = "xla_hlo.conv"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 3 : i64,
kernel_output_feature_dimension = 2 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
lhs_dilation = dense<1> : tensor<2xi64>,
padding = dense<2> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>
} : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32>
return %result : tensor<100x28x28x1xf32>
}
// CHECK-LABEL: main
// CHECK: %[[ARG0:.*]] = f32[100,26,26,32] parameter(0)
// CHECK: %[[ARG1:.*]] = f32[3,3,1,32] parameter(1)
// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]),
// CHECK-SAME: window={size=3x3 pad=2_2x2_2},
// CHECK-SAME: dim_labels=b01f_01oi->b01f

View File

@ -1,10 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// CHECK: ENTRY %main
// CHECK: %[[ARG:.*]] = s32[2] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]])

View File

@ -1,10 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
// CHECK: ENTRY %main
// CHECK: [[ARG:%.*]] = s32[2] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]])

View File

@ -1,16 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
%1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
return %1 : tensor<10xf32>
}
// CHECK: %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[]
// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
// CHECK: ENTRY %main
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
// CHECK-SAME: to_apply=%[[SUM_COMPUTATION]]

View File

@ -0,0 +1,640 @@
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = "xla_hlo.all_reduce"(%arg0) ({
// Perform max reduction inside the region
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.max %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
})
{
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
channel_id = {
handle = 5 : i64,
type = 2 : i64
}
} : (tensor<10xf32>) -> tensor<10xf32>
return %0 : tensor<10xf32>
}
// CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
// CHECK-SAME: channel_id=5
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
// CHECK-SAME: to_apply=%[[COMPUTATION]]
// -----
// CHECK-LABEL: HloModule
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
%0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
}
// CHECK-LABEL: ENTRY
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
// CHECK: [[VAL_4:%.*]] = f32[2] parameter(3)
// CHECK: [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4)
// CHECK-LABEL: ROOT
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0
// -----
// CHECK-LABEL: HloModule
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
%0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
}
// CHECK-LABEL: ENTRY
// CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0)
// CHECK: [[VAL_2:%.*]] = f32[2] parameter(1)
// CHECK: [[VAL_3:%.*]] = f32[2] parameter(2)
// CHECK-LABEL: ROOT
// CHECK-SAME: (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
// CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
// CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
// CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
// CHECK-LABEL: ROOT
// CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
}
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
// Same rank degenerate broadcast
// CHECK: [[ARG_0:%.*]] = s32[1,4] parameter(0)
// CHECK-NEXT: [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]])
// CHECK-NEXT: [[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]])
// CHECK-NEXT: [[ARG_1:%.*]] = s32[2,4] parameter(1)
// CHECK-NEXT: s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]])
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
// Broadcast up rank
// CHECK-NEXT: [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2}
// CHECK-NEXT: [[ARG_2:%.*]] = s32[2,3,4] parameter(2)
// CHECK-NEXT: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]])
%1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
// Broadcast up rank + degenerate broadcast
// CHECK-NEXT: [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2}
// CHECK-NEXT: [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]])
// CHECK-NEXT: [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2}
// CHECK-LABEL: ROOT
// CHECK-SAME: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]])
%2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
return %2 : tensor<2x3x4xi32>
}
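The three adds above cover the broadcast forms the exporter emits: a reshape that drops a size-1 dimension, a broadcast that adds dimensions, and the two combined. A standalone C++ sketch of the first, degenerate case (the s32[1,4] + s32[2,4] add), illustrative only:

#include <array>
#include <cstdio>

int main() {
  // The s32[1,4] operand after the reshape to s32[4]: the size-1 dim is gone.
  std::array<int, 4> lhs = {1, 2, 3, 4};
  // The s32[2,4] operand.
  std::array<std::array<int, 4>, 2> rhs = {{{10, 20, 30, 40},
                                            {50, 60, 70, 80}}};
  // The broadcast back to s32[2,4] repeats the row; the add is elementwise.
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 4; ++j) std::printf("%d ", lhs[j] + rhs[i][j]);
    std::printf("\n");
  }
  return 0;
}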
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
// CHECK: [[ARG:%.*]] = s32[4] parameter(0)
// CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] [[ARG]]), dimensions={3}
%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
return %0 : tensor<1x2x3x4xi32>
}
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
%result = "xla_hlo.broadcast_in_dim"(%arg0) {
broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<1xf32>) -> tensor<1x10xf32>
return %result : tensor<1x10xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: [[ARG:%.*]] = f32[1] parameter(0)
// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] [[ARG]]), dimensions={0}
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
%0 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = call @callee(%0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %1 : tensor<4xi32>
}
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK: [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s32[4], [[ARG_2:.*]]: s32[4]) -> s32[4] {
// CHECK: %[[ARG_1]] = s32[4] parameter(0)
// CHECK: %[[ARG_2]] = s32[4] parameter(1)
// CHECK-LABEL: ROOT
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]])
// CHECK: [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] {
// CHECK: %[[ARG_3]] = s32[4] parameter(0)
// CHECK: %[[ARG_4]] = s32[4] parameter(1)
// CHECK-LABEL: ROOT
// CHECK-SAME: s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]])
// CHECK: ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] {
// CHECK: %[[ARG]] = s32[4] parameter(0)
// CHECK: [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]]
// CHECK-LABEL: ROOT
// CHECK-SAME: s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]]
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
%0:2 = call @callee(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>)
return %0#0, %0#1 : tensor<4xi32>, tensor<4xi32>
}
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0, %1 : tensor<4xi32>, tensor<4xi32>
}
// Get name of callee computation
// CHECK: [[CALLEE:%.*]] ({{.*}}) -> ({{.*}}) {
// CHECK-LABEL: ENTRY
// CHECK-SAME: [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) {
// CHECK: %[[ARG]] = s32[4] parameter(0)
// CHECK: [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]]
// CHECK: [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0
// CHECK: [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1
// CHECK-LABEL: ROOT
// CHECK-SAME: (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]])
// -----
// CHECK-LABEL: HloModule
func @main(%arg0 : tensor<5x2xf32>,
%arg1 : tensor<5x5xf32>,
%arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
%result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) {
dimension = 1 : i64
} : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
return %result : tensor<5x14xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG0:.*]] = f32[5,2] parameter(0)
// CHECK: %[[ARG1:.*]] = f32[5,5] parameter(1)
// CHECK: %[[ARG2:.*]] = f32[5,7] parameter(2)
// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1}
// -----
// CHECK-LABEL: HloModule
func @main() -> tensor<2x2x1x1xf32> {
// CHECK: constant.{{.*}} = s64[] constant(1)
%cst = constant dense<1> : tensor<i64>
// CHECK: constant.{{.*}} = f32[2,2,1,1]
// CHECK-SAME: { { /*i0=0*/ { /*i1=0*/ {1} }, { /*i1=1*/ {2} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {4} } } }
%cst_0 = constant dense<
[[[[1.000000e+00]], [[2.000000e+00]]], [[[3.000000e+00]], [[4.000000e+00]]]]
> : tensor<2x2x1x1xf32>
// CHECK: s32[1] constant({1})
%cst_1 = constant dense<1> : tensor<1xi32>
// CHECK: %[[C:.*]] = s32[] constant(1)
// CHECK: s32[10] broadcast(s32[] %[[C]])
%cst_2 = constant dense<1> : tensor<10xi32>
// CHECK: s32[4] constant({1, 2, 3, 4})
%cst_3 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: s32[2,2] constant({ { 1, 2 }, { 3, 4 } })
%cst_4 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
// CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
%cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
return %cst_0 : tensor<2x2x1x1xf32>
}
// -----
// CHECK-LABEL: HloModule
func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
%result = "xla_hlo.conv"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 3 : i64,
kernel_output_feature_dimension = 2 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
lhs_dilation = dense<1> : tensor<2xi64>,
padding = dense<2> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>
} : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32>
return %result : tensor<100x28x28x1xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG0:.*]] = f32[100,26,26,32] parameter(0)
// CHECK: %[[ARG1:.*]] = f32[3,3,1,32] parameter(1)
// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]),
// CHECK-SAME: window={size=3x3 pad=2_2x2_2},
// CHECK-SAME: dim_labels=b01f_01oi->b01f
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG:.*]] = s32[2] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]])
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
// CHECK-LABEL: ENTRY
// CHECK: [[ARG:%.*]] = s32[2] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]])
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
%1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
return %1 : tensor<10xf32>
}
// CHECK: %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[]
// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
// CHECK-SAME: replica_groups={{[{][{]}}0,2,4,6},{1,3,5,7{{[}][}]}}
// CHECK-SAME: to_apply=%[[SUM_COMPUTATION]]
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
// Simple einsum is lowered to HLO dot op.
// CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
%0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
return %0 : tensor<3x5xi32>
}
// -----
// CHECK-LABEL: HloModule
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
%0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
return %0 : tensor<i32>
}
// CHECK-LABEL: ENTRY
// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG0:.*]] = (f32[], s32[]) parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0
// -----
// CHECK-LABEL: HloModule
func @main() -> tensor<1x10xf32> {
%result = "xla_hlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<1x10xf32>
return %result : tensor<1x10xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: ROOT %[[RESULT:.*]] = f32[1,10] iota(), iota_dimension=1
// -----
// CHECK-LABEL: HloModule
func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
%0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
return %0 : tensor<13x19xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: [[ARG:%.*]] = f32[4,6] parameter(0)
// CHECK: [[PADDING_VAL:%.*]] = f32[] parameter(1)
// CHECK-LABEL: ROOT
// CHECK-SAME: f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1
// -----
// CHECK-LABEL: HloModule
func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
%result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
%fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%imax = "xla_hlo.max"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"xla_hlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
}
// CHECK: %[[REGION:region_[0-9]+]]
// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[])
// CHECK: %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]])
// CHECK: %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]])
// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]])
// CHECK-LABEL: ENTRY
// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG1:.*]]: s32[1,10], [[ARG2:.*]]: f32[], [[ARG3:.*]]: s32[]) -> (f32[1], s32[1])
// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %[[ARG0]], s32[1,10] %[[ARG1]], f32[] %[[ARG2]], s32[] %[[ARG3]]), dimensions={1}, to_apply=%[[REGION]]
// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0
// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1
// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]])
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> {
%0 = xla_hlo.constant dense<-2147483648> : tensor<i32>
%1 = "xla_hlo.reduce_window"(%arg0, %0) ( {
^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
%2 = xla_hlo.max %arg1, %arg2 : tensor<i32>
"xla_hlo.return"(%2) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>,
padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>,
base_dilations = dense<[1, 1, 1, 1]> : tensor<4xi64>,
window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<2x17x31x7xi32>, tensor<i32>) -> tensor<2x3x5x7xi32>
return %1 : tensor<2x3x5x7xi32>
}
// CHECK: %[[MAX_COMPUTATION:.*]] ([[ARG0:.*]]: s32[], [[ARG1:.*]]: s32[]) -> s32[]
// CHECK: ROOT %[[RESULT:.*]] = s32[] maximum(s32[] %[[ARG0]], s32[] %[[ARG1]])
// CHECK-LABEL: ENTRY
// CHECK-DAG: %[[ARG0:.*]] = s32[2,17,31,7] parameter(0)
// CHECK-DAG: %[[INIT:.*]] = s32[] constant(-2147483648)
// CHECK: ROOT %[[RESULT:.*]] = s32[2,5,8,7] reduce-window(s32[2,17,31,7] %[[ARG0]], s32[] %constant.2),
// CHECK-SAME: window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1},
// CHECK-SAME: to_apply=%[[MAX_COMPUTATION]]
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG0:.*]] = f32[2] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[1,2] reshape(f32[2] %[[ARG0]])
// -----
// CHECK-LABEL: HloModule
func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
%result = "xla_hlo.reverse"(%arg0) {
dimensions = dense<[1,2]> : tensor<2xi64>
} : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32>
return %result : tensor<10x11x12x13xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG0:.*]] = f32[10,11,12,13] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[10,11,12,13] reverse(f32[10,11,12,13] %[[ARG0]]), dimensions={1,2}
// -----
// CHECK-LABEL: HloModule
func @main() -> tensor<2x3x5xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
%2 = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%3 = "xla_hlo.rng_uniform"(%0, %1, %2) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
return %3 : tensor<2x3x5xf32>
}
// CHECK-LABEL: ENTRY
// CHECK-DAG: %[[A:.*]] = f32[] constant(0)
// CHECK-DAG: %[[B:.*]] = f32[] constant(1)
// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[A]], f32[] %[[B]]), distribution=rng_uniform
// -----
// CHECK-LABEL: HloModule
func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
%0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
}) {
scatter_dimension_numbers = {
update_window_dims = dense<[1]> : tensor<1xi64>,
inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
index_vector_dim = 1 : i64
},
indices_are_sorted = true,
unique_indices = true
} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32>
return %0 : tensor<200x100x300xf32>
}
// CHECK: [[COMPUTATION:%.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
// CHECK-LABEL: ENTRY
// CHECK: [[VAL_1:%.*]] = f32[200,100,300] parameter(0)
// CHECK: [[VAL_2:%.*]] = s32[10,2] parameter(1)
// CHECK: [[VAL_3:%.*]] = f32[10,300] parameter(2)
// CHECK-LABEL: ROOT
// CHECK-SAME: f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[10,2] [[VAL_2]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]]
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK: %[[ARG0:.*]] = pred[] parameter(0)
// CHECK: %[[COND:.*]] = pred[2,3] broadcast(pred[] %[[ARG0]]), dimensions={}
// CHECK: %[[ARG1:.*]] = s32[2,3] parameter(1)
// CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2)
// CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]])
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_hlo.select_and_scatter"(%arg0, %arg1, %0) ( {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
%2 = "xla_hlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
%2 = xla_hlo.add %arg3, %arg4 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
return %1 : tensor<10x24x24x64xf32>
}
// CHECK: %[[SELECT_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] {
// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE
// CHECK: %[[SCATTER_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] {
// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG0:.*]] = f32[10,24,24,64] parameter(0)
// CHECK: %[[ARG1:.*]] = f32[10,12,12,64] parameter(1)
// CHECK: %[[INIT:.*]] = f32[] constant(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[10,24,24,64]
// CHECK-SAME: select-and-scatter(f32[10,24,24,64] %[[ARG0]], f32[10,12,12,64] %[[ARG1]], f32[] %[[INIT]]),
// CHECK-SAME: window={size=1x2x2x1 stride=1x2x2x1},
// CHECK-SAME: select=%[[SELECT_COMPUTATION]], scatter=%[[SCATTER_COMPUTATION]]
// -----
// CHECK-LABEL: HloModule
func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
// CHECK-LABEL: ENTRY
// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0)
// CHECK-LABEL: ROOT
// CHECK-SAME: s32[1,2] slice(s32[3,4] [[ARG]]), slice={[1:2:1], [0:4:2]}
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
// CHECK: [[ARG:%.*]] = s32[1,2,3,4] parameter(0)
// CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] [[ARG]]), dimensions={1,0,3,2}
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
return %0 : tensor<2x1x4x3xi32>
}
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<f32>, %arg1 : tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> {
%result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
return %result : tuple<tensor<f32>, tensor<i32>>
}
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG0:.*]] = f32[] parameter(0)
// CHECK: %[[ARG1:.*]] = s32[] parameter(1)
// CHECK: ROOT %[[RESULT:.*]] = (f32[], s32[]) tuple(f32[] %[[ARG0]], s32[] %[[ARG1]])
// -----
// CHECK-LABEL: HloModule
func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) {
// CHECK: [[ARG_F32:%.*]] = f32[4] parameter(0)
// CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]])
%expm1 = "xla_hlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]])
%log1p = "xla_hlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: [[ARG_I32:%.*]] = s32[4] parameter(1)
// CHECK: [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]])
%not = "xla_hlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
// CHECK: [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]])
%popcnt = "xla_hlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>
}
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: [[VAL_1:%.*]] = pred[4] parameter(0)
// CHECK: [[VAL_2:%.*]] = pred[4] parameter(1)
%0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1>
// CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
return %0 : tensor<4xi1>
}
// -----
// CHECK-LABEL: HloModule
func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] {
// CHECK: ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) {
// CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]

View File

@ -1,10 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
%0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
return %0 : tensor<i32>
}
// CHECK-LABEL: ENTRY
// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}

View File

@ -1,10 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK-LABEL: main
// CHECK: %[[ARG0:.*]] = (f32[], s32[]) parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0

View File

@ -1,11 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main() -> tensor<1x10xf32> {
%result = "xla_hlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<1x10xf32>
return %result : tensor<1x10xf32>
}
// CHECK-LABEL:main
// CHECK: ROOT %[[RESULT:.*]] = f32[1,10] iota(), iota_dimension=1

View File

@ -1,12 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
%0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
return %0 : tensor<13x19xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: [[ARG:%.*]] = f32[4,6] parameter(0)
// CHECK: [[PADDING_VAL:%.*]] = f32[] parameter(1)
// CHECK-LABEL: ROOT
// CHECK-SAME: f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1

View File

@ -1,24 +0,0 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
%result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
%fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%imax = "xla_hlo.max"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"xla_hlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
}
// CHECK: %[[REGION:region_[0-9]+]]
// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[])
// CHECK: %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]])
// CHECK: %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]])
// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]])
// CHECK: ENTRY %main
// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG0:.*]]: s32[1,10], [[ARG0:.*]]: f32[], [[ARG0:.*]]: s32[]) -> (f32[1], s32[1])
// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %Arg_0.1, s32[1,10] %Arg_1.2, f32[] %Arg_2.3, s32[] %Arg_3.4), dimensions={1}, to_apply=%[[REGION]]
// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0
// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1
// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]])

Some files were not shown because too many files have changed in this diff.