Stephan 2021-03-03 13:20:33 +01:00
commit e2bd33bc62
579 changed files with 13628 additions and 6710 deletions

View File

@ -50,7 +50,6 @@
# Feature and Third party library support options:
# xla: Build TF with XLA
# tpu: Build TF with TPU support
# using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm).
# mkl: Enable full mkl support.
@ -92,6 +91,7 @@
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds.
# release_gpu_linux_cuda_11_2: Toolchain and CUDA options for CUDA 11.2 Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.
@ -190,23 +190,22 @@ build:mkl_aarch64 -c opt
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
build:using_cuda --action_env TF_NEED_CUDA=1
build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda_base --@local_config_cuda//:enable_cuda
build:cuda_base --action_env TF_NEED_CUDA=1
build:cuda_base --crosstool_top=@local_config_cuda//crosstool:toolchain
# Enable the mlir generated GPU kernels only for cuda builds.
build --define=tensorflow_enable_mlir_generated_gpu_kernels=0
# This is a more specific option, so it takes precedence over the line above for cuda builds.
build:using_cuda --define=tensorflow_enable_mlir_generated_gpu_kernels=1
build:cuda_base --define=tensorflow_enable_mlir_generated_gpu_kernels=1
# This config refers to building CUDA op kernels with nvcc.
build:cuda --config=using_cuda
build:cuda --define=using_cuda_nvcc=true
build:cuda --config=cuda_base
build:cuda --@local_config_cuda//:cuda_compiler=nvcc
# This config refers to building CUDA op kernels with clang.
build:cuda_clang --config=using_cuda
build:cuda_clang --define=using_cuda_clang=true
build:cuda_clang --define=using_clang=true
build:cuda_clang --config=cuda_base
build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
build:cuda_clang --action_env TF_CUDA_CLANG=1
# dbg config, as a shorthand for '--config=opt -c dbg'
@ -430,6 +429,9 @@ build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platf
build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --@local_config_cuda//:enable_cuda
build:rbe_linux_cuda_base --@local_config_cuda//:cuda_compiler=nvcc
# TODO(csigg): those probably don't do anything because cuda_config is remote.
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7
@ -438,7 +440,6 @@ build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
@ -455,7 +456,6 @@ build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo
build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
@ -623,6 +623,13 @@ build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cud
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10"
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7"
build:release_gpu_linux_cuda_11_2 --config=release_gpu_linux
build:release_gpu_linux_cuda_11_2 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2"
build:release_gpu_linux_cuda_11_2 --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain
build:release_gpu_linux_cuda_11_2 --action_env=TF_CUDA_VERSION="11.2"
build:release_gpu_linux_cuda_11_2 --action_env=TF_CUDNN_VERSION="8.1"
# Address sanitizer
# CC=clang bazel build --config asan
build:asan --strip=never

View File

@ -119,6 +119,10 @@
`tf.config.experimental.get_memory_usage` in favor of this new function.
* Extended `tf.config.experimental.enable_tensor_float_32_execution` to
control Tensor-Float-32 evaluation in RNNs.
* Added an `experimental_payloads` field to `tf.errors.OpError` and
its subclasses to support more detailed error reporting (see the
illustrative sketch below). This is inspired by Abseil Status
payloads:
https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h
* `tf.summary`:
* New `tf.summary.graph` allows manual write of TensorFlow graph
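As a hedged illustration of the `experimental_payloads` item above: the sketch below assumes the new field is exposed as a dict-like attribute on a caught `tf.errors.OpError`. The error-raising op and the exact access pattern are assumptions for illustration, not part of this diff.

```python
# Hypothetical sketch (not part of this change): reading payloads from a
# caught OpError. Assumes `experimental_payloads` is a dict-like attribute.
import tensorflow as tf

try:
    # Any op that raises an OpError subclass works here; check_numerics
    # raises InvalidArgumentError when it encounters a NaN.
    tf.debugging.check_numerics(tf.constant(float("nan")), "found NaN")
except tf.errors.OpError as e:
    print(e.message)
    for type_url, payload in e.experimental_payloads.items():
        print(type_url, payload)
```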
@ -140,7 +144,11 @@
`max_batch_size`. Previously, we issued a warning when the value of
`rewriter_config_template` is not None. We issued an error when the
value of `is_dynamic_op` is not True. We didn't use the value for
`max_batch_size` for building TensorRT engines.
`max_batch_size` for building TensorRT engines. Added the parameter
`use_dynamic_shape` to enable dynamic shape support; it is disabled by
default. Added `dynamic_shape_profile_strategy` for selecting a dynamic
shape profile strategy; the default profile strategy is `Range`. A
hedged usage sketch follows the TF-TRT items below.
* Issue a warning when function get_tensorrt_rewriter_config is used.
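To make the new dynamic-shape options concrete, here is a minimal, hedged sketch. Only `use_dynamic_shape`, `dynamic_shape_profile_strategy`, and the `Range` default come from the note above; the `TrtGraphConverterV2` converter class, its import path, and the SavedModel paths are assumptions for illustration.

```python
# Hypothetical sketch (not part of this change): enabling dynamic shapes
# in TF-TRT conversion. Class name, import path, and paths are assumed.
from tensorflow.python.compiler.tensorrt import trt_convert as trt

converter = trt.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/my_saved_model",   # placeholder path
    use_dynamic_shape=True,                        # disabled by default
    dynamic_shape_profile_strategy="Range",        # default strategy
)
converter.convert()
converter.save("/tmp/my_saved_model_trt")          # placeholder path
```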
* TF XLA

View File

@ -197,7 +197,19 @@ config_setting(
config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
# Internal builds query the target OS.
# copybara:uncomment flag_values = {"//tools/cpp:cc_target_os": "windows"},
# OSS builds query the CPU type.
values = {"cpu": "x64_windows"}, # copybara:comment
visibility = ["//visibility:public"],
)
config_setting(
name = "msvc_cl_debug",
values = {
"compiler": "msvc-cl",
"compilation_mode": "dbg",
},
visibility = ["//visibility:public"],
)
@ -402,11 +414,12 @@ config_setting(
# Crosses between platforms and file system libraries not supported on those
# platforms due to limitations in nested select() statements.
config_setting(
selects.config_setting_group(
name = "with_cuda_support_windows_override",
define_values = {"using_cuda_nvcc": "true"},
values = {"cpu": "x64_windows"},
visibility = ["//visibility:public"],
match_all = [
":using_cuda_nvcc",
":windows",
],
)
config_setting(
@ -470,9 +483,46 @@ selects.config_setting_group(
],
)
config_setting(
# Config setting that is satisfied when TensorFlow is being built with CUDA
# support through e.g. `--config=cuda` (or `--config=cuda_clang` in OSS).
alias(
name = "is_cuda_enabled",
actual = if_oss(
"@local_config_cuda//:is_cuda_enabled",
"@local_config_cuda//cuda:using_clang",
),
)
# Config setting that is satisfied when CUDA device code should be compiled
# with clang. It does not imply that CUDA support has been enabled.
alias(
name = "is_cuda_compiler_clang",
actual = if_oss(
"@local_config_cuda//:is_cuda_compiler_clang",
"@local_config_cuda//cuda:TRUE",
),
)
# Config setting that is satisfied when CUDA device code should be compiled
# with nvcc. It does not imply that CUDA support has been enabled.
alias(
name = "is_cuda_compiler_nvcc",
actual = if_oss(
"@local_config_cuda//:is_cuda_compiler_nvcc",
"@local_config_cuda//cuda:FALSE",
),
)
# Config setting for whether TensorFlow is built with CUDA support using clang.
alias(
name = "using_cuda_clang",
define_values = {"using_cuda_clang": "true"},
actual = "@local_config_cuda//cuda:using_clang",
)
# Config setting for whether TensorFlow is built with CUDA support using nvcc.
alias(
name = "using_cuda_nvcc",
actual = "@local_config_cuda//cuda:using_nvcc",
)
# Config setting to use in select()s to distinguish open source build from
@ -493,12 +543,12 @@ bool_setting(
visibility = ["//visibility:private"],
)
config_setting(
selects.config_setting_group(
name = "using_cuda_clang_with_dynamic_build",
define_values = {
"using_cuda_clang": "true",
"framework_shared_object": "true",
},
match_all = [
":using_cuda_clang",
":framework_shared_object",
],
)
selects.config_setting_group(
@ -519,19 +569,6 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "using_cuda_nvcc",
define_values = {"using_cuda_nvcc": "true"},
)
config_setting(
name = "using_cuda_nvcc_with_dynamic_build",
define_values = {
"using_cuda_nvcc": "true",
"framework_shared_object": "true",
},
)
selects.config_setting_group(
name = "build_oss_using_cuda_nvcc",
match_all = [
@ -559,15 +596,6 @@ config_setting(
visibility = ["//visibility:public"],
)
# This flag is defined for select statements that match both
# on 'windows' and 'api_version_2'. In this case, bazel requires
# having a flag which is a superset of these two.
config_setting(
name = "windows_and_api_version_2",
define_values = {"tf_api_version": "2"},
values = {"cpu": "x64_windows"},
)
# This flag enables experimental MLIR support.
config_setting(
name = "with_mlir_support",
@ -934,6 +962,11 @@ tf_cc_shared_object(
"-z defs",
"-Wl,--version-script,$(location //tensorflow:tf_version_script.lds)",
],
}) + select({
"//tensorflow:msvc_cl_debug": [
"/DEBUG:FASTLINK",
],
"//conditions:default": [],
}),
per_os_targets = True,
soversion = VERSION,
@ -952,7 +985,6 @@ tf_cc_shared_object(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/cc:scope",
"//tensorflow/cc/profiler",
"//tensorflow/core:tensorflow",
],
)

View File

@ -283,7 +283,6 @@ tf_cuda_cc_test(
"//tensorflow/c/experimental/gradients:not_differentiable",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -909,7 +908,6 @@ tf_cuda_cc_test(
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
@ -935,7 +933,6 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -974,7 +971,6 @@ tf_cc_test(
":custom_device_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

View File

@ -79,6 +79,7 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status) {
assert(0);
break;
}
tf_status->status.ReplaceAllPayloads(status.GetAllPayloads());
}
Status StatusFromTF_Status(const TF_Status* tf_status) {

View File

@ -24,6 +24,8 @@ namespace {
TEST(StatusHelper, TestStatusHelper) {
TF_Status* s = TF_NewStatus();
Status cc_status(errors::InvalidArgument("some error"));
cc_status.SetPayload("key1", "value1");
cc_status.SetPayload("key2", "value2");
Set_TF_Status_from_Status(s, cc_status);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
ASSERT_EQ(std::string("some error"), TF_Message(s));
@ -32,6 +34,9 @@ TEST(StatusHelper, TestStatusHelper) {
ASSERT_FALSE(another_cc_status.ok());
ASSERT_EQ(std::string("some error"), another_cc_status.error_message());
ASSERT_EQ(error::INVALID_ARGUMENT, another_cc_status.code());
// Ensure the payloads are not lost during conversions
ASSERT_EQ(cc_status.GetPayload("key1"), another_cc_status.GetPayload("key1"));
ASSERT_EQ(cc_status.GetPayload("key2"), another_cc_status.GetPayload("key2"));
TF_DeleteStatus(s);
}

View File

@ -6,7 +6,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"cc_library_with_android_deps",
"tf_cc_binary",
"tf_cc_test",
"tf_copts",
"transitive_hdrs",
@ -748,36 +747,6 @@ tf_gen_op_wrappers_cc(
],
)
tf_cc_binary(
name = "tutorials_example_trainer",
srcs = ["tutorials/example_trainer.cc"],
copts = tf_copts(),
linkopts = select({
"//tensorflow:windows": [],
"//tensorflow:macos": [
"-lm",
"-lpthread",
],
"//tensorflow:ios": [
"-lm",
"-lpthread",
],
"//conditions:default": [
"-lm",
"-lpthread",
"-lrt",
],
}),
deps = [
":cc_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)
cc_library(
name = "queue_runner",
srcs = ["training/queue_runner.cc"],
@ -856,7 +825,6 @@ transitive_hdrs(
":queue_runner",
":remote_fused_graph_ops",
":scope",
"//tensorflow/cc/profiler",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/cc/saved_model:loader",
"//tensorflow/cc/saved_model:reader",

View File

@ -1,40 +0,0 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
tf_cuda_cc_test(
name = "profiler_test",
srcs = ["profiler_test.cc"],
tags = [
"no_gpu", # b/77649654
"no_rocm", # stream level tracing not supported on ROCm
],
deps = [
":profiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "profiler",
srcs = ["profiler.cc"],
hdrs = ["profiler.h"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler:protos_all_cc",
"//tensorflow/core/profiler:tfprof_options",
"//tensorflow/core/profiler/internal:tfprof_stats",
],
)

View File

@ -1,57 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/profiler/profiler.h"
namespace tensorflow {
namespace tfprof {
Profiler::Profiler(const GraphDef& graph) {
std::unique_ptr<GraphDef> graph_ptr(new GraphDef());
*graph_ptr = graph;
stats_.reset(new TFStats(std::move(graph_ptr), nullptr, nullptr, nullptr));
}
void Profiler::AddStep(int64 step, const RunMetadata& run_meta) {
std::unique_ptr<RunMetadata> run_meta_ptr(new RunMetadata());
*run_meta_ptr = run_meta;
stats_->AddRunMeta(step, std::move(run_meta_ptr));
}
GraphNodeProto Profiler::ProfileGraph(const Options& options) {
stats_->BuildView(kCmds[1]);
return stats_->ShowGraphNode(kCmds[1], options);
}
GraphNodeProto Profiler::ProfileNameScope(const Options& options) {
stats_->BuildView(kCmds[0]);
return stats_->ShowGraphNode(kCmds[0], options);
}
MultiGraphNodeProto Profiler::ProfileOperations(const Options& options) {
stats_->BuildView(kCmds[3]);
return stats_->ShowMultiGraphNode(kCmds[3], options);
}
Status Profiler::SerializeToString(string* content) {
if (!content) {
return Status(error::Code::INVALID_ARGUMENT,
"Cannot use null string pointer for SerializeToString.");
}
stats_->SerializeToString(content);
return Status::OK();
}
} // namespace tfprof
} // namespace tensorflow

View File

@ -1,100 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_PROFILER_PROFILER_H_
#define TENSORFLOW_CC_PROFILER_PROFILER_H_
#include <memory>
#include <string>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/profiler/internal/tfprof_stats.h"
#include "tensorflow/core/profiler/tfprof_options.h"
#include "tensorflow/core/profiler/tfprof_output.pb.h"
namespace tensorflow {
namespace tfprof {
/// @addtogroup core
/// @{
/// A `Profiler` object lets the caller profile the execution of a graph.
///
/// Example:
/// // First build a graph and run tracing.
/// Scope root = Scope::NewRootScope();
/// auto a = Placeholder(root, DT_INT32);
/// auto c = Add(root, a, {41});
///
/// ClientSession session(root);
/// std::vector<Tensor> outputs;
/// RunOptions run_options;
/// run_options.set_trace_level(RunOptions::FULL_TRACE);
/// RunMetadata run_meta;
/// Status s = session.Run(run_options, { {a, {1}} }, {c}, &outputs,
/// &run_meta);
/// if (!s.ok()) { ... }
///
/// // Then create profiler to do profiling.
/// GraphDef graph;
/// root.ToGraphDef(&graph);
/// Profiler profiler(graph);
/// profiler.AddStep(0, run_meta);
/// Options opts = ... // TODO(xpan): Support option building API.
/// MultiGraphNodeProto r = profiler.ProfileOperations(opts);
///
class Profiler {
public:
/// `graph` is the model's GraphDef.
explicit Profiler(const GraphDef& graph);
/// Adds tracing information `run_meta` to profiler. A `run_meta` is
/// generated by a TensorFlow session run call. `step` is the key
/// to the `run_meta`. When calling ProfileXXX methods, caller can specify
/// `step` in `options` to selectively profile the corresponding `run_meta`.
/// Multiple different `run_meta` can be keyed by the same `step` in order
/// to group them together.
void AddStep(int64 step, const RunMetadata& run_meta);
/// Profiles the model by organizing nodes in graph structure.
/// Each node is an op and the nodes are connected by the op inputs/outputs.
GraphNodeProto ProfileGraph(const Options& options);
/// Profiles the model by organizing nodes in name scope structure.
/// Each node is an op, and nodes are organized by the ops' name
/// scope, similar to a file system tree.
/// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2.
GraphNodeProto ProfileNameScope(const Options& options);
/// Profiles the model by organizing nodes by operation types.
/// Each node is an operation type (e.g. Conv2D or MatMul), containing all
/// ops belonging to that type in the model.
MultiGraphNodeProto ProfileOperations(const Options& options);
/// Serialize the profile content (ProfileProto) into a binary string,
/// User can write the string to file for offline analysis by
/// tfprof command-line tools or graphical user interface.
Status SerializeToString(string* content);
private:
std::unique_ptr<TFStats> stats_;
};
/// @}
} // namespace tfprof
} // namespace tensorflow
#endif // TENSORFLOW_CC_PROFILER_PROFILER_H_

View File

@ -1,177 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/test.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace tfprof {
class ProfilerTest : public ::testing::Test {
protected:
ProfilerTest() {}
};
GraphDef CreateGraphDef() {
Scope root = Scope::NewRootScope();
auto a = ops::Const<float>(root, {{3, 2}, {-1, 0}});
auto x = ops::Const(root.WithOpName("x"), {{1.f}, {1.f}});
auto y = ops::MatMul(root.WithOpName("y"), a, x);
auto y2 = ops::Square(root, y);
auto y2_sum = ops::Sum(root, y2, 0);
auto y_norm = ops::Sqrt(root, y2_sum);
auto y_div = ops::Div(root.WithOpName("y_normalized"), y, y_norm);
GraphDef def;
TF_CHECK_OK(root.ToGraphDef(&def));
return def;
}
Options Default() {
Options opts(1000, /* max_depth */
0, /* min_bytes */
0, /* min_peak_bytes */
0, /* min_residual_bytes */
0, /* min_output_bytes */
0, /* min_micros */
0, /* min_accelerator_micros */
0, /* min_cpu_micros */
0, /* min_params */
0, /* min_float_ops */
0, /* min_occurrence */
0, /* step */
"name", /* order_by */
{".*"}, /* account_type_regexes */
{".*"}, /* start_name_regexes */
{}, /* trim_name_regexes */
{".*"}, {}, /* hide_name_regexes */
false, /* account_displayed_op_only */
{"micros"}, /* select */
{"none"}, /* output_type */
{});
return opts;
}
template <typename T>
const T* ExtractNode(const T& pb, const string& name) {
if (pb.name() == name) {
return &pb;
}
for (const T& c : pb.children()) {
const T* ret = ExtractNode(c, name);
if (ret) return ret;
}
return nullptr;
}
TEST_F(ProfilerTest, Basics) {
SessionOptions options;
options.config.set_allow_soft_placement(true);
std::unique_ptr<Session> session(NewSession(options));
GraphDef def = CreateGraphDef();
if (options.target.empty()) {
graph::SetDefaultDevice("/gpu:0", &def);
}
TF_CHECK_OK(session->Create(def));
Tensor x(DT_FLOAT, TensorShape({2, 1}));
auto x_flat = x.flat<float>();
x_flat.setRandom();
Eigen::Tensor<float, 0, Eigen::RowMajor> inv_norm =
x_flat.square().sum().sqrt().inverse();
x_flat = x_flat * inv_norm();
std::vector<Tensor> outputs;
RunOptions run_options;
run_options.set_trace_level(RunOptions::FULL_TRACE);
RunMetadata run_metadata;
outputs.clear();
Profiler profiler(def);
for (int i = 0; i < 2; ++i) {
TF_CHECK_OK(session->Run(run_options, {{"x", x}}, {"y:0", "y_normalized:0"},
{}, &outputs, &run_metadata));
profiler.AddStep(i, run_metadata);
CHECK_EQ(size_t{2}, outputs.size());
}
std::vector<DeviceAttributes> resp;
TF_CHECK_OK(session->ListDevices(&resp));
bool has_gpu = false;
for (const auto& dev : resp) {
if (dev.device_type() == "GPU") {
has_gpu = true;
}
}
GraphNodeProto ret = profiler.ProfileNameScope(Default());
const GraphNodeProto* matmul = ExtractNode(ret, "y");
EXPECT_TRUE(matmul);
EXPECT_GT(matmul->exec_micros(), 0);
if (has_gpu) {
EXPECT_GT(matmul->accelerator_exec_micros(), 0);
} else {
EXPECT_EQ(matmul->accelerator_exec_micros(), 0);
}
const GraphNodeProto* square = ExtractNode(ret, "Square");
EXPECT_TRUE(square);
EXPECT_GT(square->exec_micros(), 0);
if (has_gpu) {
EXPECT_GT(square->accelerator_exec_micros(), 0);
} else {
EXPECT_EQ(square->accelerator_exec_micros(), 0);
}
Options opts2 = Default();
opts2.output_type = "timeline";
string timeline_file = io::JoinPath(testing::TmpDir(), "timeline");
opts2.output_options["outfile"] = timeline_file;
GraphNodeProto ret2 = profiler.ProfileGraph(opts2);
string s;
TF_CHECK_OK(ReadFileToString(Env::Default(), timeline_file + "_0", &s));
EXPECT_TRUE(s.find("Square") != s.npos);
MultiGraphNodeProto ret3 = profiler.ProfileOperations(Default());
const MultiGraphNodeProto* matmul2 = ExtractNode(ret3, "MatMul");
EXPECT_TRUE(matmul2);
EXPECT_GT(matmul2->exec_micros(), 0);
if (has_gpu) {
EXPECT_GT(matmul2->accelerator_exec_micros(), 0);
} else {
EXPECT_EQ(matmul2->accelerator_exec_micros(), 0);
}
TF_CHECK_OK(session->Close());
}
} // namespace tfprof
} // namespace tensorflow

View File

@ -1,234 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdio>
#include <functional>
#include <string>
#include <vector>
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
using tensorflow::string;
using tensorflow::int32;
namespace tensorflow {
namespace example {
struct Options {
int num_concurrent_sessions = 1; // The number of concurrent sessions
int num_concurrent_steps = 10; // The number of concurrent steps
int num_iterations = 100; // Each step repeats this many times
bool use_gpu = false; // Whether to use gpu in the training
};
// A = [3 2; -1 0]; x = rand(2, 1);
// We want to compute the largest eigenvalue for A.
// repeat x = y / y.norm(); y = A * x; end
GraphDef CreateGraphDef() {
// TODO(jeff,opensource): This should really be a more interesting
// computation. Maybe turn this into an mnist model instead?
Scope root = Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
// A = [3 2; -1 0]. Using Const<float> means the result will be a
// float tensor even though the initializer has integers.
auto a = Const<float>(root, {{3, 2}, {-1, 0}});
// x = [1.0; 1.0]
auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}});
// y = A * x
auto y = MatMul(root.WithOpName("y"), a, x);
// y2 = y.^2
auto y2 = Square(root, y);
// y2_sum = sum(y2). Note that you can pass constants directly as
// inputs. Sum() will automatically create a Const node to hold the
// 0 value.
auto y2_sum = Sum(root, y2, 0);
// y_norm = sqrt(y2_sum)
auto y_norm = Sqrt(root, y2_sum);
// y_normalized = y ./ y_norm
Div(root.WithOpName("y_normalized"), y, y_norm);
GraphDef def;
TF_CHECK_OK(root.ToGraphDef(&def));
return def;
}
string DebugString(const Tensor& x, const Tensor& y) {
CHECK_EQ(x.NumElements(), 2);
CHECK_EQ(y.NumElements(), 2);
auto x_flat = x.flat<float>();
auto y_flat = y.flat<float>();
// Compute an estimate of the eigenvalue via
// (x' A x) / (x' x) = (x' y) / (x' x)
// and exploit the fact that x' x = 1 by assumption
Eigen::Tensor<float, 0, Eigen::RowMajor> lambda = (x_flat * y_flat).sum();
return strings::Printf("lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]",
lambda(), x_flat(0), x_flat(1), y_flat(0), y_flat(1));
}
void ConcurrentSteps(const Options* opts, int session_index) {
// Creates a session.
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
GraphDef def = CreateGraphDef();
if (options.target.empty()) {
graph::SetDefaultDevice(opts->use_gpu ? "/device:GPU:0" : "/cpu:0", &def);
}
TF_CHECK_OK(session->Create(def));
// Spawn M threads for M concurrent steps.
const int M = opts->num_concurrent_steps;
std::unique_ptr<thread::ThreadPool> step_threads(
new thread::ThreadPool(Env::Default(), "trainer", M));
for (int step = 0; step < M; ++step) {
step_threads->Schedule([&session, opts, session_index, step]() {
// Randomly initialize the input.
Tensor x(DT_FLOAT, TensorShape({2, 1}));
auto x_flat = x.flat<float>();
x_flat.setRandom();
Eigen::Tensor<float, 0, Eigen::RowMajor> inv_norm =
x_flat.square().sum().sqrt().inverse();
x_flat = x_flat * inv_norm();
// Iterations.
std::vector<Tensor> outputs;
for (int iter = 0; iter < opts->num_iterations; ++iter) {
outputs.clear();
TF_CHECK_OK(
session->Run({{"x", x}}, {"y:0", "y_normalized:0"}, {}, &outputs));
CHECK_EQ(size_t{2}, outputs.size());
const Tensor& y = outputs[0];
const Tensor& y_norm = outputs[1];
// Print out lambda, x, and y.
std::printf("%06d/%06d %s\n", session_index, step,
DebugString(x, y).c_str());
// Copies y_normalized to x.
x = y_norm;
}
});
}
// Delete the threadpool, thus waiting for all threads to complete.
step_threads.reset(nullptr);
TF_CHECK_OK(session->Close());
}
void ConcurrentSessions(const Options& opts) {
// Spawn N threads for N concurrent sessions.
const int N = opts.num_concurrent_sessions;
// At the moment our Session implementation only allows
// one concurrently computing Session on GPU.
CHECK_EQ(1, N) << "Currently can only have one concurrent session.";
thread::ThreadPool session_threads(Env::Default(), "trainer", N);
for (int i = 0; i < N; ++i) {
session_threads.Schedule(std::bind(&ConcurrentSteps, &opts, i));
}
}
} // end namespace example
} // end namespace tensorflow
namespace {
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
int32* dst) {
if (absl::ConsumePrefix(&arg, flag) && absl::ConsumePrefix(&arg, "=")) {
char extra;
return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
}
return false;
}
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
bool* dst) {
if (absl::ConsumePrefix(&arg, flag)) {
if (arg.empty()) {
*dst = true;
return true;
}
if (arg == "=true") {
*dst = true;
return true;
} else if (arg == "=false") {
*dst = false;
return true;
}
}
return false;
}
} // namespace
int main(int argc, char* argv[]) {
tensorflow::example::Options opts;
std::vector<char*> unknown_flags;
for (int i = 1; i < argc; ++i) {
if (string(argv[i]) == "--") {
while (i < argc) {
unknown_flags.push_back(argv[i]);
++i;
}
break;
}
if (ParseInt32Flag(argv[i], "--num_concurrent_sessions",
&opts.num_concurrent_sessions) ||
ParseInt32Flag(argv[i], "--num_concurrent_steps",
&opts.num_concurrent_steps) ||
ParseInt32Flag(argv[i], "--num_iterations", &opts.num_iterations) ||
ParseBoolFlag(argv[i], "--use_gpu", &opts.use_gpu)) {
continue;
}
fprintf(stderr, "Unknown flag: %s\n", argv[i]);
return -1;
}
// Passthrough any unknown flags.
int dst = 1; // Skip argv[0]
for (char* f : unknown_flags) {
argv[dst++] = f;
}
argv[dst++] = nullptr;
argc = static_cast<int>(unknown_flags.size() + 1);
tensorflow::port::InitMain(argv[0], &argc, &argv);
tensorflow::example::ConcurrentSessions(opts);
}

View File

@ -297,6 +297,7 @@ cc_library(
# Public visibility is needed for external TF/XLA backends.
visibility = ["//visibility:public"],
deps = XLA_DEVICE_DEPS + [":xla_compilation_cache"],
alwayslink = 1,
)
cc_library(
@ -342,6 +343,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@ -355,7 +357,9 @@ cc_library(
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/protobuf:for_core_protos_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -292,6 +292,37 @@ MlirCommonFlags* GetMlirCommonFlags() {
return mlir_flags;
}
ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
absl::optional<const ConfigProto> config_proto) {
// TF1 graphs that do not override the Session's ConfigProto and TF2 graphs
// can enable/disable the bridge via tf_mlir_enable_mlir_bridge.
auto tf_mlir_enable_mlir_bridge =
GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
if (tf_mlir_enable_mlir_bridge !=
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) {
return tf_mlir_enable_mlir_bridge;
}
// If a ConfigProto was not passed in, we can assume the caller is
// checking whether a TF2 graph should have the bridge enabled or disabled.
// In that case, we have already checked tf_mlir_enable_mlir_bridge, so it
// is safe to return UNSPECIFIED here.
if (!config_proto.has_value()) {
return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
}
// TF1 graphs that do override Session's ConfigProto and set
// ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not
// update tf_mlir_enable_mlir_bridge so check their values.
// ConfigProto's enable_mlir_bridge defaults to false so only respect it
// when it is true.
if (config_proto.value().experimental().enable_mlir_bridge()) {
return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
}
return config_proto.value().experimental().mlir_bridge_rollout();
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
@ -156,6 +157,11 @@ GetIntroduceFloatingPointJitterPassFlags();
MlirCommonFlags* GetMlirCommonFlags();
// Returns the effective MLIR bridge rollout state based on the flags and the
// optional configuration.
ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
absl::optional<const ConfigProto> config_proto);
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//

View File

@ -621,3 +621,6 @@ func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
return %1 : tensor<8xi32>
}
```
### `-tf-verify-for-export`: Verify module is suitable for export back to TF Graph
Verifies that every function in the module consists of a single
tf_executor.graph and that each tf_executor.island in that tf_executor.graph
contains only a single op.

View File

@ -708,4 +708,49 @@ def HLOClient_BroadcastSelectOp : HLOClient_Op<
}];
}
//===----------------------------------------------------------------------===//
// Helper ops
//===----------------------------------------------------------------------===//
def HLOClient_MinimumBroadcastShapesOp :
HLOClient_Op<"minimum_broadcast_shapes", [NoSideEffect]> {
string summary = "Minimizes the rank of two or more shapes to be broadcasted";
string description = [{
Given two or more 1D tensors representing shapes, returns one 1D tensor for
each operand, where operand `i` corresponds to output `i`.
The returned tensors have the property that they specify a shape which is a
reshape of the corresponding input shape, and the broadcasted output shape
(using shape::BroadcastOp) of the returned shapes is a reshape of the
broadcasted output shape of the input shapes. Among all possibilities with
this property, the one that minimizes the rank of each returned shape is
chosen.
The general idea of this op is that ops with broadcasting semantics can use
it to operate on shapes of a possibly smaller rank while preserving
equivalence of the computed values. After computing the result of the op on
the reshaped operands, that result can be reshaped back to the result that
would have been computed on the original operands.
Here is an example with two input shapes:
```mlir
chlo.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1],
[1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3]
```
The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], and the
broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are
reshapes of each other, and each output is also a reshape of the
corresponding input.
}];
let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes);
let results = (outs Variadic<1DTensorOf<[Index]>>:$results);
let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)";
}
#endif // CHLO_OPS

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -888,6 +888,11 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate",
let hasCanonicalizer = 1;
let hasFolder = 1;
let extraClassDeclaration = [{
static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
return succeeded(mlir::verifyCompatibleShapes(l, r));
}
}];
}
def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",

View File

@ -337,6 +337,26 @@ static LogicalResult Verify(ConstantLikeOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// MinimumBroadcastShapesOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(MinimumBroadcastShapesOp op) {
// Check that the number of operands matches the number of outputs.
unsigned result_shapes_count = op.results().size();
unsigned operand_shapes_count = op.shapes().size();
if (operand_shapes_count != result_shapes_count) {
return op.emitOpError()
<< "number of operand shapes (" << operand_shapes_count
<< ") does not match number of result shapes ("
<< result_shapes_count << ")";
}
if (operand_shapes_count < 2) {
return op.emitOpError() << "number of operand shapes ("
<< operand_shapes_count << ") should be >= 2";
}
return success();
}
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,

View File

@ -51,6 +51,7 @@ struct ChloLegalizeToHloPass
conversionTarget.addLegalDialect<
MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
if (broadcast_only_) {
chlo::PopulateChloBroadcastingPatterns(&getContext(),

View File

@ -424,7 +424,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
buffer_args.push_back(InsertAlloc(loc, result, &rewriter));
}
auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
op.getAttrs());
op->getAttrs());
// Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
@ -607,8 +607,9 @@ struct HloLegalizeToLhlo
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateFuncOpTypeConversionPattern(patterns, &context, converter);
populateCallOpTypeConversionPattern(patterns, &context, converter);
populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
patterns, &context, converter);
populateBranchOpInterfaceTypeConversionPattern(patterns, &context,
converter);
populateReturnOpTypeConversionPattern(patterns, &context, converter);
populateEliminateBufferizeMaterializationsPatterns(&context, converter,
patterns);

View File

@ -517,14 +517,16 @@ class HloDynamicBroadcastInDimConverter
auto shape_type = shape.getType().cast<RankedTensorType>();
int64_t result_rank = shape_type.getDimSize(0);
auto result_type = op.getType().dyn_cast<RankedTensorType>();
if (!result_type) return failure();
SmallVector<Value, 2> dyn_dims;
Location loc = op.getLoc();
for (int i = 0; i < result_rank; ++i) {
if (!result_type.isDynamicDim(i)) continue;
Value index = rewriter.create<ConstantIndexOp>(loc, i);
dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
}
auto result_type = op.getType().dyn_cast<RankedTensorType>();
if (!result_type) return failure();
int64_t nloops = result_type.getRank();
Value init = rewriter.create<linalg::InitTensorOp>(
@ -1146,8 +1148,7 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
return dyn_shape;
}
template <typename InputElType, int input_bit_width, typename OutputElType,
int output_bit_width, DotOperationType op_type, typename LinalgOp>
template <DotOperationType op_type, typename LinalgOp>
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
public:
using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
@ -1157,28 +1158,13 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
return failure();
}
if (GetDotOperationType(op) != op_type) return failure();
mhlo::DotOp::Adaptor adaptor(args);
auto lhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
auto rhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
return failure();
}
Location loc = op.getLoc();
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
if (!output_el_type.isa<OutputElType>() ||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
return failure();
}
if (GetDotOperationType(op) != op_type) return failure();
Location loc = op.getLoc();
auto zero_attr = rewriter.getZeroAttr(output_el_type);
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
@ -1205,8 +1191,6 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
return dyn_shape;
}
template <typename InputElType, int input_bit_width, typename OutputElType,
int output_bit_width, typename LinalgOp>
class DotGeneralOpOnTensorsConversion
: public OpConversionPattern<mhlo::DotGeneralOp> {
public:
@ -1245,23 +1229,10 @@ class DotGeneralOpOnTensorsConversion
}
mhlo::DotGeneralOp::Adaptor adaptor(args);
auto lhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
auto rhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
return failure();
}
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
if (!output_el_type.isa<OutputElType>() ||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
return failure();
}
Location loc = op.getLoc();
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
auto zero_attr = rewriter.getZeroAttr(output_el_type);
@ -1269,7 +1240,7 @@ class DotGeneralOpOnTensorsConversion
auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
Operation* linalg_op = rewriter.create<LinalgOp>(
Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
loc, /*resultTensorTypes=*/TypeRange{op.getType()},
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
/*outputBuffers=*/ValueRange{zero_tensor});
@ -1476,9 +1447,14 @@ struct NormalConvOpOnTensorsConversion
// The output shape is N spatial_dims F.
SmallVector<Value, 8> dyn_sizes;
for (int64_t i = 0, e = rank - 1; i < e; ++i) {
if (!result_type.isDynamicDim(i)) continue;
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, i));
if (result_type.isDynamicDim(0)) {
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0));
}
for (int64_t i = 1, e = rank - 1; i < e; ++i) {
if (result_type.isDynamicDim(i)) {
return rewriter.notifyMatchFailure(
op, "expected output spatial dims to be static shapes");
}
}
if (result_type.isDynamicDim(rank - 1)) {
dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
@ -1702,49 +1678,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
ReverseConverter<mhlo::ReverseOp, false>,
SliceConverter<mhlo::SliceOp, false>,
TransposeConverter<mhlo::TransposeOp, false>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI32I32I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI32I32I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI32I32I32Op>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kMatrixMatrix,
DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
linalg::MatmulOp>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kMatrixVector,
DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
linalg::MatvecOp>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kVectorDot, linalg::DotOp>,
DotGeneralOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
linalg::BatchMatmulI8I8I32Op>,
DotGeneralOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
linalg::BatchMatmulI16I16I32Op>,
DotGeneralOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
linalg::BatchMatmulI32I32I32Op>,
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
linalg::BatchMatmulOp>,
DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
DotGeneralOpOnTensorsConversion,
NormalConvOpOnTensorsConversion,
ReduceOnTensorsConversion,
PadOpOnTensorsConversion>(context);

View File

@ -126,7 +126,7 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
Type flatResultTy =
RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
Value flatResult =
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op->getAttrs());
// Restore original shape.
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
@ -192,7 +192,7 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
rhs_is_scalar ? rhs : reshaped};
Value computed = rewriter.create<ChloOpTy>(
loc, TypeRange{RankedTensorType::get({-1}, result_element_type)},
new_operands, op.getAttrs());
new_operands, op->getAttrs());
// Reshape the result back into an unranked tensor.
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
@ -223,9 +223,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
}
static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
Value value, int targeted_rank) {
Value shape, int targeted_rank) {
auto loc = op.getLoc();
Value shape = builder.create<shape::ShapeOfOp>(loc, value);
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
{RankedTensorType::kDynamicSize}, builder.getIndexType());
@ -246,6 +245,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
ChloOpTy op,
ValueRange operands,
ValueRange operand_shapes,
int targeted_rank) {
auto loc = op.getLoc();
SmallVector<Value, 2> reshaped_operands;
@ -253,10 +253,12 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
targeted_rank, RankedTensorType::kDynamicSize);
for (Value operand : operands) {
for (auto it : llvm::zip(operands, operand_shapes)) {
Value operand, shape;
std::tie(operand, shape) = it;
// Handle shape broadcasting and inference.
Value extended_operand_casted =
createBroadcastToKnownRank(if_builder, op, operand, targeted_rank);
createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
// 1. Reshape operands to the given rank (with the same number of
// elements)
@ -278,7 +280,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
auto result_type =
RankedTensorType::get(dynamic_dimensions, result_element_type);
Value result = if_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, reshaped_operands, op.getAttrs());
loc, ArrayRef<Type>{result_type}, reshaped_operands, op->getAttrs());
Value reshaped_result = if_builder.create<tensor::CastOp>(
loc, UnrankedTensorType::get(result_element_type), result);
if_builder.create<scf::YieldOp>(loc, reshaped_result);
@ -290,13 +292,37 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
ValueRange operands) {
auto loc = op.getLoc();
// Find the larger rank of the operands.
// Get the minimum broadcast shapes of the operands.
SmallVector<Value> shapes;
shapes.reserve(operands.size());
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value greater_rank;
for (Value operand : operands) {
Value shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
shapes.push_back(shape);
}
auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
loc, extent_tensor_type, shapes, nullptr);
SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
auto reduced_shapes =
rewriter
.create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
.results();
SmallVector<Value> reshaped_operands;
reshaped_operands.reserve(operands.size());
for (auto it : llvm::zip(operands, reduced_shapes)) {
Value operand;
Value reduced_shape;
std::tie(operand, reduced_shape) = it;
auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
loc, operand.getType(), operand, reduced_shape);
reshaped_operands.push_back(reshaped_operand);
}
// Find the largest rank of the operands.
Value greater_rank;
for (Value shape : reduced_shapes) {
Value rank =
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
if (!greater_rank) {
@ -314,17 +340,19 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
rewriter, op, greater_rank, 1);
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1);
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
reduced_shapes, 1);
// Put each subsequent rank specialization inside the else statement of the
// previous one.
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
constexpr int kMaxRankSpecialization = 6;
constexpr int kMaxRankSpecialization = 5;
for (int i = 2; i < kMaxRankSpecialization; i++) {
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
else_builder, op, greater_rank, i);
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
createRankSpecializedBroadcastAndOp(if_builder, op, operands, i);
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
reduced_shapes, i);
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
}
@ -336,12 +364,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
kMaxRankSpecialization),
"Input for dynamic binary op lowering was of a rank greater than " +
std::to_string(kMaxRankSpecialization));
// Add the rank 6 specialization to the innermost else block.
createRankSpecializedBroadcastAndOp(else_builder, op, operands,
kMaxRankSpecialization);
// Add the rank 5 specialization to the innermost else block.
createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
reduced_shapes, kMaxRankSpecialization);
// Return the result of the outermost if statement.
return if_op.getResult(0);
// Return the reshaped result of the outermost if statement.
auto result = if_op.getResult(0);
auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
loc, result.getType(), result, broadcast_shape);
return reshaped_result;
}
};
@ -386,7 +417,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
op.getAttrs());
op->getAttrs());
Value extended_if_lhs_scalar_result =
extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result,
shape_of_lhs, shape_of_rhs);
@ -409,7 +440,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
op.getAttrs());
op->getAttrs());
Value extended_if_rhs_scalar_result =
extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result,
shape_of_lhs, shape_of_rhs);
@ -497,16 +528,17 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
shape::ShapeDialect>();
}
void runOnFunction() override {
// Setup conversion target.
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
shape::ShapeDialect, scf::SCFDialect,
tensor::TensorDialect>();
target.addLegalDialect<chlo::HloClientDialect, mhlo::MhloDialect,
StandardOpsDialect, shape::ShapeDialect,
scf::SCFDialect, tensor::TensorDialect>();
target.addLegalOp<FuncOp>();
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)

View File

@ -0,0 +1,26 @@
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
// CHECK-LABEL: func @minimum_broadcast_shapes
func @minimum_broadcast_shapes(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>)
-> (tensor<?xindex>, tensor<?xindex>) {
%0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs :
tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
return %0, %1 : tensor<?xindex>, tensor<?xindex>
}
// -----
func @minimum_broadcast_shapes_mismatch_operand_and_result_count(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>) {
// expected-error @+1{{number of operand shapes (2) does not match number of result shapes (1)}}
%0 = chlo.minimum_broadcast_shapes %lhs, %rhs :
tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
return
}
// -----
func @minimum_broadcast_shapes_one_operand(%arg: tensor<?xindex>) {
// expected-error @+1{{number of operand shapes (1) should be >= 2}}
%0 = chlo.minimum_broadcast_shapes %arg : tensor<?xindex> -> tensor<?xindex>
return
}

View File

@ -954,6 +954,28 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> ()>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
// CHECK-SAME: [[SHAPE:%.*]]: tensor<2xindex>
func @dynamic_broadcast_in_dim(%shape: tensor<2xindex>) -> tensor<?x32xf32> {
%cst = mhlo.constant dense<0x7F800000> : tensor<f32>
%result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) {
broadcast_dimensions = dense<> : tensor<0xi64>
} : (tensor<f32>, tensor<2xindex>) -> tensor<?x32xf32>
return %result : tensor<?x32xf32>
}
// CHECK: [[CST:%.*]] = constant
// CHECK: [[INIT:%.*]] = linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?x32xf32>)
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
func @dot_matmul(%arg0: tensor<2x3xf32>,
%arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,
@@ -982,7 +1004,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i8_i8_i32
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
@@ -1000,7 +1022,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i16_i16_i32
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
@@ -1018,7 +1040,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i32_i32_i32
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
@@ -1109,7 +1131,7 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul_i8_i8_i32
// CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
@@ -1138,7 +1160,7 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul_i16_i16_i32
// CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
@@ -1444,8 +1466,8 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf3
// -----
func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
-> tensor<?x?x?xf32> {
func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tensor<2x?x?xf32>)
-> tensor<?x7x?xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
@@ -1463,31 +1485,29 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tenso
padding = dense<[[0], [0]]> : tensor<2x1xi64>,
rhs_dilation = dense<1> : tensor<1xi64>,
window_strides = dense<1> : tensor<1xi64>
} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
} : (tensor<?x8x?xf32>, tensor<2x?x?xf32>) -> tensor<?x7x?xf32>
return %0 : tensor<?x7x?xf32>
}
// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_1d_input_nwc_filter_wcf
// CHECK-SAME: {dilations = dense<1> : tensor<1xi64>
// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x8x?xf32>, tensor<2x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x7x?xf32>) -> tensor<?x7x?xf32>
// -----
func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
-> tensor<?x?x?x?xf32> {
func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3x2x?x?xf32>)
-> tensor<?x2x3x?xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
@@ -1505,33 +1525,29 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?
padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>
} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
} : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x3x?xf32>
return %0 : tensor<?x2x3x?xf32>
}
// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32>
// -----
func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
-> tensor<?x?x?x?x?xf32> {
func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tensor<2x2x2x?x?xf32>)
-> tensor<?x7x7x7x?xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
@@ -1549,30 +1565,24 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tens
padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
rhs_dilation = dense<1> : tensor<3xi64>,
window_strides = dense<1> : tensor<3xi64>
} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
} : (tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>) -> tensor<?x7x7x7x?xf32>
return %0 : tensor<?x7x7x7x?xf32>
}
// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
// CHECK: %[[C4:.+]] = constant 4 : index
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>
// CHECK-SAME: strides = dense<1> : tensor<3xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x7x7x7x?xf32>) -> tensor<?x7x7x7x?xf32>
// -----

View File

@@ -199,20 +199,24 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// Handle rank 1 specialization
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
@@ -222,12 +226,12 @@ func @addUnrankedUnranked(
// Handle rank 2 specialization
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
@@ -237,12 +241,12 @@ func @addUnrankedUnranked(
// Handle rank 3 specialization
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
@@ -252,47 +256,30 @@ func @addUnrankedUnranked(
// Handle rank 4 specialization
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
// Handle rank 5 specialization
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C6]] : index
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_6]]
// Handle rank 6 specialization
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK-NEXT: }
@@ -300,7 +287,8 @@ func @addUnrankedUnranked(
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
// CHECK-NEXT: }
@@ -325,13 +313,18 @@ func @selectUnrankedUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[PRED_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %c1 = constant 1 : index
@@ -339,15 +332,15 @@ func @selectUnrankedUnrankedUnranked(
// Handle rank 1 specialization
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
@@ -357,4 +350,3 @@ func @selectUnrankedUnrankedUnranked(
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>

View File

@@ -368,6 +368,16 @@ func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
// -----
// CHECK-LABEL: @concat_1D
// Verifies that an error is not thrown if the inferred type is compatible with
// the result type.
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> {
// expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}}
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
@@ -384,14 +394,6 @@ func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<
// -----
func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
// expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<3xi32>'}}
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> {
// expected-error@+1 {{op inferred type(s) 'tensor<3xi32>' are incompatible with return type(s) of operation 'tensor<4xi32>'}}
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>

View File

@@ -627,6 +627,7 @@ tf_native_cc_binary(
srcs = [
"converter_gen.cc",
],
compatible_with = get_compatible_with_cloud(),
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",

View File

@@ -1192,11 +1192,11 @@ void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
auto cond = symbol_table.lookup<mlir::FuncOp>(
while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
AddCallOpInWhileOpRegion(while_op.cond(), cond);
while_op.removeAttr("cond");
while_op->removeAttr("cond");
auto body = symbol_table.lookup<mlir::FuncOp>(
while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
AddCallOpInWhileOpRegion(while_op.body(), body);
while_op.removeAttr("body");
while_op->removeAttr("body");
});
}
} // namespace

View File

@@ -1011,9 +1011,7 @@ def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [
TFL_OperandHasAtleastRank<0, 2>,
TFL_OperandHasAtleastRank<1, 2>,
PredOpTrait<"x and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
PredOpTrait<"y and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
let summary = "Batch Matrix Multiply Operator";
@@ -2774,7 +2772,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
def TFL_SelectOp : TFL_Op<"select", [
NoSideEffect,
SameOperandsAndResultsScale,
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
PredOpTrait<"operands have same element type", TFL_TCopVTEtAreSameAt<1, 2>>,
PredOpTrait<"operands and result have same element type",
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
let summary = "Select operator";
@@ -2810,8 +2808,9 @@ def TFL_SelectOp : TFL_Op<"select", [
def TFL_SelectV2Op : TFL_Op<"select_v2", [
ResultsBroadcastableShape,
NoSideEffect,
SameOperandsAndResultsScale,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 4>,
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
PredOpTrait<"operands have same element type", TFL_TCopVTEtAreSameAt<1, 2>>,
PredOpTrait<"operands and result have same element type",
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
let summary = "SelectV2 operator";

View File

@@ -14,15 +14,16 @@ func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x120x120x8xf32>
}
func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32> {
func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x120x120x8xf32>
%cst_1 = constant dense<0> : tensor<2x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x64x64x3xf32>
%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x64x64x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x60x60x8xf32>
%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x60x60x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x120x120x8xf32>
return %2 : tensor<1x120x120x8xf32>
// CHECK-LABEL: testDilatedConvWithNonConstantPadAndCrops
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x120x120x8xf32>
}
@@ -73,37 +74,39 @@ func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<4> : tensor<2x2xi32>
%cst_1 = constant dense<0> : tensor<2x2xi32>
%cst_2 = constant dense<0> : tensor<4x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.Pad"(%1, %cst_2) : (tensor<4x64x64x8xf32>, tensor<4x2xi32>) -> tensor<4x64x64x8xf32>
%3 = "tf.BatchToSpaceND"(%2, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
%4 = "tf.BiasAdd"(%3, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
return %4 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedConvWithPad
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<4> : tensor<2x2xi32>
%cst_1 = constant dense<0> : tensor<2x2xi32>
%cst_2 = constant dense<0> : tensor<4x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
%1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.Pad"(%1, %cst_2) : (tensor<4x64x64x8xf32>, tensor<4x2xi32>) -> tensor<4x64x64x8xf32>
%3 = "tf.BatchToSpaceND"(%2, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
%4 = "tf.BiasAdd"(%3, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
return %4 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedDepthWiseConvWithPad
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
@@ -235,22 +238,23 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%cst_1 = constant dense<4> : tensor<2x2xi32>
%cst_2 = constant dense<0> : tensor<2x2xi32>
%cst_3 = constant dense<0> : tensor<3x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
%4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
%4 = "tf.Pad"(%3, %cst_3) : (tensor<4x64x64xf32>, tensor<3x2xi32>) -> tensor<4x64x64xf32>
%5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
%6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
%6 = "tf.BiasAdd"(%5, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
return %6 : tensor<1x128x128xf32>
// CHECK-LABEL: testDilatedConvWithExpandSqueeze3
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
@@ -259,22 +263,23 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%cst_1 = constant dense<4> : tensor<2x2xi32>
%cst_2 = constant dense<0> : tensor<2x2xi32>
%cst_3 = constant dense<0> : tensor<3x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
%4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
%4 = "tf.Pad"(%3, %cst_3) : (tensor<4x64x64xf32>, tensor<3x2xi32>) -> tensor<4x64x64xf32>
%5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
%6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
%6 = "tf.BiasAdd"(%5, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
return %6 : tensor<1x128x128xf32>
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
@@ -481,3 +486,27 @@ func @testDilatedConv1DWithMixedPostiveAndNegativeAxis(%arg0: tensor<1x128x3xf32
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
}
func @testPaddedDilatedConv(%arg0 : tensor<2x1920x64xf32>) -> tensor<2x1920x128xf32> {
%0 = "tf.Const"() {value = dense<[[0, 0], [2, 0], [0, 0]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
%3 = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%4 = "tf.Const"() {value = dense<0.0> : tensor<3x1x64x128xf32>} : () -> tensor<3x1x64x128xf32>
%5 = "tf.SpaceToBatchND"(%arg0, %1, %3) {device = ""} : (tensor<2x1920x64xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<4x960x64xf32>
%6 = "tf.ExpandDims"(%5, %2) {device = ""} : (tensor<4x960x64xf32>, tensor<i32>) -> tensor<4x960x1x64xf32>
%7 = "tf.Conv2D"(%6, %4) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<4x960x1x64xf32>, tensor<3x1x64x128xf32>) -> tensor<4x958x1x128xf32>
%8 = "tf.Squeeze"(%7) {device = "", squeeze_dims = [2]} : (tensor<4x958x1x128xf32>) -> tensor<4x958x128xf32>
%9 = "tf.Pad"(%8, %0) {device = ""} : (tensor<4x958x128xf32>, tensor<3x2xi32>) -> tensor<4x960x128xf32>
%10 = "tf.BatchToSpaceND"(%9, %1, %3) {device = ""} : (tensor<4x960x128xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x1920x128xf32>
return %10 : tensor<2x1920x128xf32>
// CHECK-LABEL: testPaddedDilatedConv
// CHECK-SAME: ([[INPUT:%.*]]: tensor<2x1920x64xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[FILTER:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x1x64x128xf32>} : () -> tensor<3x1x64x128xf32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) {device = ""} : (tensor<2x1920x64xf32>, tensor<i32>) -> tensor<2x1920x1x64xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {data_format = "NHWC", device = "", dilations = [1, 2, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<2x1920x1x64xf32>, tensor<3x1x64x128xf32>) -> tensor<2x1920x1x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {device = "", squeeze_dims = [2]} : (tensor<2x1920x1x128xf32>) -> tensor<2x1920x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<2x1920x128xf32>
}

View File

@@ -1387,6 +1387,15 @@ func @testBatchMatmulQuant(%arg0 : tensor<1x4x384x32x!quant.uniform<i8:f32, 0.06
%0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32x!quant.uniform<i8:f32, 0.06:-2>>, tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>>
return %0 : tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>>
}
// -----
func @testBatchMatmulHybridQuant(%arg0 : tensor<1x4x384x32xf32>, %arg1 : tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384xf32> {
// CHECK: "tfl.batch_matmul"(%arg0, %arg1)
%0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384xf32>
return %0 : tensor<1x4x384x384xf32>
}
// -----
func @testConcat(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<2x2xi32> {

View File

@@ -564,6 +564,78 @@ func @FuseFullyConnectedReshapeAddConstWithActivation(%arg0: tensor<40x37xf32>,
// FOLD: return %[[fc]]
}
// CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConst
func @FuseFullyConnectedReshapeAdd2DConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<4x10xf32>
%shape = constant dense<[1, 40, 4, 10]> : tensor<4xi32>
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
%1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32>
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x4x10xf32>, tensor<4x10xf32>) -> tensor<1x40x4x10xf32>
return %2 : tensor<1x40x4x10xf32>
// CHECK: %[[cst:.*]] = constant dense<2.000000e+00> : tensor<40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]]
// CHECK: return %[[rs]]
}
// CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConstWithActivation
func @FuseFullyConnectedReshapeAdd2DConstWithActivation(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<4x10xf32>
%shape = constant dense<[1, 40, 4, 10]> : tensor<4xi32>
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
%1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32>
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x40x4x10xf32>, tensor<4x10xf32>) -> tensor<1x40x4x10xf32>
return %2 : tensor<1x40x4x10xf32>
// CHECK: %[[cst:.*]] = constant dense<2.000000e+00> : tensor<40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]]
// CHECK: return %[[rs]]
}
// CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConstWithExistingBias
func @FuseFullyConnectedReshapeAdd2DConstWithExistingBias(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> {
%cst = constant dense<3.0> : tensor<40xf32>
%cst2 = constant dense<2.0> : tensor<4x10xf32>
%shape = constant dense<[1, 40, 4, 10]> : tensor<4xi32>
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40xf32>) -> (tensor<40x40xf32>)
%1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32>
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x4x10xf32>, tensor<4x10xf32>) -> tensor<1x40x4x10xf32>
return %2 : tensor<1x40x4x10xf32>
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]]
// CHECK: return %[[rs]]
}
// CHECK-LABEL: @NotFuseFullyConnectedReshapeAdd2DConstIfLastDimIsNotNumElementsOfRhs
func @NotFuseFullyConnectedReshapeAdd2DConstIfLastDimIsNotNumElementsOfRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<20x37xf32>) -> tensor<1x20x4x10xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<4x10xf32>
%shape = constant dense<[1, 20, 4, 10]> : tensor<4xi32>
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<20x37xf32>, none) -> (tensor<40x20xf32>)
%1 = "tfl.reshape"(%0, %shape) : (tensor<40x20xf32>, tensor<4xi32>) -> tensor<1x20x4x10xf32>
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x20x4x10xf32>, tensor<4x10xf32>) -> tensor<1x20x4x10xf32>
return %2 : tensor<1x20x4x10xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1
// CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]]
// CHECK: %[[add:.*]] = "tfl.add"(%[[rs]]
// CHECK: return %[[add]]
}
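
The three positive cases and the negative case above hinge on one shape condition: the 2-D add constant can be folded into the fully_connected bias only when its element count equals the last dimension of the fully_connected output (4 x 10 = 40 matches tensor<40x40xf32>; it does not match tensor<40x20xf32>). A sketch of that check, written for this note rather than taken from the optimize pass:

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"

// True when a constant of shape `const_shape` has exactly `fc_last_dim`
// elements and can therefore act as a per-output-channel bias once the
// reshape is pushed past the add.
static bool constFitsFullyConnectedBias(int64_t fc_last_dim,
                                        llvm::ArrayRef<int64_t> const_shape) {
  int64_t num_elements = 1;
  for (int64_t d : const_shape) num_elements *= d;
  return num_elements == fc_last_dim;
}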
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastableAfter
func @NotReorderReshapeAddIfNotBroadcastableAfter(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
%cst = constant dense<2.0> : tensor<40xf32>
@@ -642,6 +714,19 @@ func @NotReorderReshapeAddIfHighDim(%arg0: tensor<1x1x1x1x30x96xf32>) -> tensor<
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderReshapeAdd2DConstIfInputIsNotDefinedByFullyConnected
func @NotReorderReshapeAdd2DConstIfInputIsNotDefinedByFullyConnected(%arg0: tensor<8x15xf32>) -> tensor<1x8x3x5xf32> {
%cst = constant dense<2.0> : tensor<3x5xf32>
%shape = constant dense<[1, 8, 3, 5]> : tensor<4xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<8x15xf32>, tensor<4xi32>) -> tensor<1x8x3x5xf32>
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x8x3x5xf32>, tensor<3x5xf32>) -> tensor<1x8x3x5xf32>
return %2 : tensor<1x8x3x5xf32>
// CHECK: %[[rs:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[add:.*]] = "tfl.add"(%[[rs]]
// CHECK: return %[[add]]
}
// CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp
func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
%shape = constant dense<[40, 40]> : tensor<2xi32>
@@ -1679,8 +1764,17 @@ func @ReorderReshapex2Add(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x3x4xf32>
// CHECK: return %[[VAL_1]]
}
// CHECK-LABEL: ConvertSliceToIdentity
func @ConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {
// CHECK-LABEL: ConvertSliceToIdentityI32
func @ConvertSliceToIdentityI32(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {
%begin = constant dense<0> : tensor<4xi32>
%shape = constant dense<[2,3,4,5]> : tensor<4xi32>
%0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<2x3x4x5xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<2x3x4x5xf32>
return %0 : tensor<2x3x4x5xf32>
// CHECK: return %arg0
}
// CHECK-LABEL: ConvertSliceToIdentityI64
func @ConvertSliceToIdentityI64(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {
%begin = constant dense<0> : tensor<4xi64>
%shape = constant dense<[2,3,4,5]> : tensor<4xi64>
%0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<2x3x4x5xf32>, tensor<4xi64>, tensor<4xi64>) -> tensor<2x3x4x5xf32>
@@ -1688,6 +1782,24 @@ func @ConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
// CHECK: return %arg0
}
// CHECK-LABEL: ConvertSliceToIdentityStaticDimWithShapeWithNeg1
func @ConvertSliceToIdentityStaticDimWithShapeWithNeg1(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {
%begin = constant dense<0> : tensor<4xi32>
%shape = constant dense<[-1, 3, -1, 5]> : tensor<4xi32>
%0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<2x3x4x5xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<2x3x4x5xf32>
return %0 : tensor<2x3x4x5xf32>
// CHECK: return %arg0
}
// CHECK-LABEL: ConvertSliceToIdentityDynamicDimAndShapeWithNeg1
func @ConvertSliceToIdentityDynamicDimAndShapeWithNeg1(%arg0: tensor<?x3x?x5xf32>) -> tensor<?x3x?x5xf32> {
%begin = constant dense<0> : tensor<4xi32>
%shape = constant dense<[-1, 3, -1, 5]> : tensor<4xi32>
%0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<?x3x?x5xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<?x3x?x5xf32>
return %0 : tensor<?x3x?x5xf32>
// CHECK: return %arg0
}
// CHECK-LABEL: DontConvertSliceToIdentity
func @DontConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> (tensor<2x3x4x4xf32>, tensor<1x2x3x4xf32>) {
%begin0 = constant dense<0> : tensor<4xi64>
@ -1706,6 +1818,28 @@ func @DontConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> (tensor<2x3x4x4x
// CHECK: return %[[SLICE_0]], %[[SLICE_1]] : tensor<2x3x4x4xf32>, tensor<1x2x3x4xf32>
}
// CHECK-LABEL: DontConvertSliceToIdentityNonConstShape
func @DontConvertSliceToIdentityNonConstShape(%arg0: tensor<?xf32>, %arg1: tensor<1xi32>) -> tensor<?xf32> {
%begin = constant dense<0> : tensor<1xi32>
%0 = "tfl.slice"(%arg0, %begin, %arg1) : (tensor<?xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK: %[[BEGIN:.*]] = constant dense<0> : tensor<1xi32>
// CHECK: %[[SLICE:.*]] = "tfl.slice"(%arg0, %[[BEGIN]], %arg1) : (tensor<?xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xf32>
// CHECK: return %[[SLICE]] : tensor<?xf32>
}
// CHECK-LABEL: DontConvertSliceToIdentityDynamicDimButEqualShape
func @DontConvertSliceToIdentityDynamicDimButEqualShape(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%begin = constant dense<0> : tensor<1xi32>
%shape = constant dense<2> : tensor<1xi32>
%0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<?xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK: %[[BEGIN:.*]] = constant dense<0> : tensor<1xi32>
// CHECK: %[[SHAPE:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: %[[SLICE:.*]] = "tfl.slice"(%arg0, %[[BEGIN]], %[[SHAPE]]) : (tensor<?xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xf32>
// CHECK: return %[[SLICE]] : tensor<?xf32>
}
// CHECK-LABEL: @FuseAddWithFullyConnectedWithBias
func @FuseAddWithFullyConnectedWithBias(%arg: tensor<2x512xf32>) -> tensor<2x1024xf32> {
%cst_add = constant dense<2.0> : tensor<512xf32>

View File

@ -22,11 +22,13 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -82,35 +84,77 @@ template <typename Conv2dOpTy>
LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
Conv2dOpTy op, PatternRewriter& rewriter) const {
if (!op.getResult().hasOneUse()) {
return failure();
return rewriter.notifyMatchFailure(
op, "result for current op has more than 1 use");
}
// Make sure Conv2D has 'VALID' padding.
if (op->template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
return failure();
return rewriter.notifyMatchFailure(op,
"Conv2D op doesn't have valid padding");
}
// Make sure dilations are all ones if set.
const ArrayAttr& dilations =
op->template getAttrOfType<ArrayAttr>("dilations");
if (dilations && !TFIntListIsAllOnes(dilations)) {
return failure();
return rewriter.notifyMatchFailure(op, "dilations should be all 1");
}
if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op))
return failure();
if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op)) {
return rewriter.notifyMatchFailure(
op, "op's input is not float or the data format isn't NHWC");
}
// Allow dynamic width and height dimensions only.
auto result_ty = op.getResult().getType().template cast<TensorType>();
if (!result_ty.hasRank() || result_ty.getRank() != 4 ||
result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3))
return failure();
result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) {
return rewriter.notifyMatchFailure(
op, "only dynamic width and height dimensions are allowed");
}
// Check if the ConvOp is preceded by a `Expand` op and succeeded by a
// `Squeeze` op.
Operation* prev_op = op.getOperation()->getPrevNode();
if (!prev_op) return failure();
if (!prev_op || prev_op->getNumResults() != 1) {
return rewriter.notifyMatchFailure(
op, "op doesn't have a preceding node that has a single result");
}
if (!prev_op->hasOneUse() || *(prev_op->getResult(0).user_begin()) != op) {
return rewriter.notifyMatchFailure(
op, "op's input isn't produced by previous operation");
}
Operation* next_op = op.getOperation()->getNextNode();
if (!next_op) return failure();
auto tryGetNextNode =
[&rewriter](Operation* current) -> std::pair<LogicalResult, Operation*> {
// Check the current operation has a single result.
if (current->getNumResults() != 1) {
return {
rewriter.notifyMatchFailure(current, "op doesn't have single result"),
nullptr};
}
// Check the current operation has a next node.
Operation* next_op = current->getNextNode();
if (!next_op) {
return {rewriter.notifyMatchFailure(current, "op doesn't have next node"),
nullptr};
}
// Check the current operation's result is used by its successor node.
if (!current->hasOneUse() ||
*(current->getResult(0).user_begin()) != next_op) {
return {
rewriter.notifyMatchFailure(
current, "op's result isn't directly consumed by the next op"),
nullptr};
}
return {LogicalResult::success(), next_op};
};
std::pair<LogicalResult, Operation*> maybeNextNode =
tryGetNextNode(op.getOperation());
if (failed(maybeNextNode.first)) {
return maybeNextNode.first;
}
Operation* next_op = maybeNextNode.second;
TF::ExpandDimsOp expand_op;
TF::SqueezeOp squeeze_op;
@ -119,15 +163,19 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
// Expand/Squeeze op must come in pair.
return failure();
return rewriter.notifyMatchFailure(
op, "ExpandDimsOp and SqueezeOp should come in pair");
}
expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
if (!expand_op.getResult().hasOneUse() ||
!squeeze_op.getResult().hasOneUse()) {
return failure();
if (!expand_op.getResult().hasOneUse()) {
return rewriter.notifyMatchFailure(
expand_op, "result for current op has more than 1 use");
}
if (!squeeze_op.getResult().hasOneUse()) {
return rewriter.notifyMatchFailure(
squeeze_op, "result for current op has more than 1 use");
}
// Make sure that the axis in `expand_op` is constant.
if (auto const_op =
llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
@ -141,12 +189,14 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
expand_axis += 4;
}
} else {
return failure();
return rewriter.notifyMatchFailure(
expand_op, "ExpandDimsOp doesn't have a constant axis");
}
// Make sure that the `squeeze_dims` is equal to `expand_axis`.
auto squeeze_dims = squeeze_op.squeeze_dims();
if (squeeze_dims.size() != 1) {
return failure();
return rewriter.notifyMatchFailure(
squeeze_op, "squeeze dims should have exactly 1 dimension specified");
}
int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
if (squeeze_axis < 0) {
@ -154,36 +204,62 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
squeeze_axis += 4;
}
if (squeeze_axis != expand_axis) {
return failure();
return rewriter.notifyMatchFailure(
op, "squeeze axis and expand axis doesn't match");
}
// Update previous/next op pointer.
prev_op = prev_op->getPrevNode();
if (!prev_op) return failure();
next_op = next_op->getNextNode();
if (!next_op) return failure();
Operation* tmp = prev_op->getPrevNode();
if (!tmp || tmp->getNumResults() != 1) {
return rewriter.notifyMatchFailure(
prev_op, "op doesn't have a preceding node that has a single result");
}
if (!tmp->hasOneUse() || *(tmp->getResult(0).user_begin()) != prev_op) {
return rewriter.notifyMatchFailure(
prev_op, "op's input isn't defined by its previous node");
}
prev_op = tmp;
std::pair<LogicalResult, Operation*> maybeNextNode =
tryGetNextNode(next_op);
if (failed(maybeNextNode.first)) {
return maybeNextNode.first;
}
next_op = maybeNextNode.second;
}
// SpaceToBatchND op.
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return failure();
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) {
return rewriter.notifyMatchFailure(prev_op,
"op should be a SpaceToBatchND op");
}
// TODO(b/149936532): Check `padding` input, currently ignored.
TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
if (!stb_op.getResult().hasOneUse()) {
return failure();
return rewriter.notifyMatchFailure(
stb_op, "result for current op has more than 1 use");
}
// Pad op.
TF::PadOp pad_op;
// TODO(b/149936532): Currently we just ignore the PadOp. However, note that
// in real scenarios this may not always be correct: a user can put a PadOp
// here with non-trivial consequences.
ElementsAttr pad_attr;
if (llvm::isa<TF::PadOp>(next_op)) {
pad_op = llvm::cast<TF::PadOp>(next_op);
if (!pad_op.getResult().hasOneUse()) {
return failure();
return rewriter.notifyMatchFailure(
pad_op, "result for current op has more than 1 use");
}
std::pair<LogicalResult, Operation*> maybeNextNode =
tryGetNextNode(next_op);
if (failed(maybeNextNode.first)) {
return maybeNextNode.first;
}
next_op = maybeNextNode.second;
if (!matchPattern(pad_op.paddings(), m_Constant(&pad_attr))) {
// If the padding value isn't constant, we can't determine the padding
// scheme for Conv2D below, in this case just reject the pattern.
return rewriter.notifyMatchFailure(
pad_op, "PadOp's padding value isn't constant");
}
next_op = next_op->getNextNode();
if (!next_op) return failure();
}
// BatchToSpaceND + BiasAdd.
@ -194,33 +270,53 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
// Must be BiasAdd + BatchToSpaceND.
biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
if (!biasadd_op.getResult().hasOneUse()) {
return failure();
return rewriter.notifyMatchFailure(
biasadd_op, "result for current op has more than 1 use");
}
next_op = next_op->getNextNode();
if (!next_op || !llvm::isa<TF::BatchToSpaceNDOp>(next_op)) return failure();
std::pair<LogicalResult, Operation*> maybeNextNode =
tryGetNextNode(next_op);
if (failed(maybeNextNode.first)) {
return maybeNextNode.first;
}
if (!llvm::isa<TF::BatchToSpaceNDOp>(maybeNextNode.second)) {
return rewriter.notifyMatchFailure(
next_op, "op's next node isn't BatchToSpaceND op");
}
next_op = maybeNextNode.second;
bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
} else if (llvm::isa<TF::BatchToSpaceNDOp>(next_op)) {
// BatchToSpaceND + (optional) BiasAdd.
bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
next_op = next_op->getNextNode();
if (next_op && llvm::isa<TF::BiasAddOp>(next_op)) {
Operation* tmp = next_op->getNextNode();
if (tmp && llvm::isa<TF::BiasAddOp>(tmp)) {
if (!bts_op.getResult().hasOneUse()) {
return failure();
return rewriter.notifyMatchFailure(
bts_op, "result for current op has more than 1 use");
}
if (!next_op->hasOneUse() ||
*(next_op->getResult(0).user_begin()) != tmp) {
return rewriter.notifyMatchFailure(
next_op, "op's result isn't directly consumed by the next op");
}
next_op = tmp;
biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
final_op_is_bts = false;
}
} else {
return failure();
return rewriter.notifyMatchFailure(
next_op, "next op is neither BiasAdd nor BatchToSpaceND");
}
llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
if (!dilations_attr.hasValue()) return failure();
if (!dilations_attr.hasValue()) {
return rewriter.notifyMatchFailure(op, "failed to extract dilation rate");
}
if (expand_op) {
if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
return failure();
return rewriter.notifyMatchFailure(
stb_op, "SpaceToBatchND op's input should have RankedTensorType");
}
}
@ -255,16 +351,33 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
auto stb_paddings = stb_op.paddings();
auto bts_crops = bts_op.crops();
ElementsAttr stb_paddings_attr, bts_crops_attr;
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) &&
matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
if (stb_paddings_attr.getNumElements() != bts_crops_attr.getNumElements())
return failure();
// padding - crop.
auto paddings = stb_paddings_attr.getValues<IntegerAttr>();
auto crops = bts_crops_attr.getValues<IntegerAttr>();
for (auto it1 = paddings.begin(), it2 = crops.begin();
it1 != paddings.end() && it2 != crops.end(); it1++, it2++) {
if ((*it1).getInt() != (*it2).getInt()) {
if (!matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) ||
!matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
return rewriter.notifyMatchFailure(
op,
"either SpaceToBatchND or BatchToSpaceND "
"doesn't have constant padding/crops value");
}
if (stb_paddings_attr.getType() != bts_crops_attr.getType()) {
return rewriter.notifyMatchFailure(
stb_op,
"SpaceToBatchND op's padding doesn't have same shape/type with "
"BatchToSpaceND op's crops");
}
int64_t m = stb_paddings_attr.getType().getDimSize(0);
// padding - crop.
for (uint64_t i = 0; i < m; ++i) {
for (uint64_t j = 0; j < 2; ++j) {
// `crops` tensor has shape [M, 2], crops[i] = [crop_start, crop_end]
// specifies the amount to crop from input dimension i + 1. If the input
// of `BatchToSpaceND` has been padded explicitly, then we need to
// take into account the additional padding when determining the padding
// scheme for `Conv2D`.
int64_t additional_pad =
pad_attr ? pad_attr.getValue<IntegerAttr>({i + 1, j}).getInt() : 0;
if (stb_paddings_attr.getValue<IntegerAttr>({i, j}).getInt() +
additional_pad !=
bts_crops_attr.getValue<IntegerAttr>({i, j}).getInt()) {
op->setAttr("padding", rewriter.getStringAttr("SAME"));
break;
}
@ -316,7 +429,11 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
}
if (final_op_is_bts) {
bts_op.getResult().replaceAllUsesWith(bts_op.input());
if (bts_op.input().getDefiningOp<TF::PadOp>()) {
bts_op.getResult().replaceAllUsesWith(pad_op.input());
} else {
bts_op.getResult().replaceAllUsesWith(bts_op.input());
}
}
stb_op.getResult().dropAllUses();

View File

@ -816,7 +816,7 @@ struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
ConversionPatternRewriter &rewriter) const override {
Value input = operands[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
op.getAttrs());
op->getAttrs());
return success();
}
};
@ -948,7 +948,7 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
// Create a new while op with new operands and updated result types.
auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
operands, op.getAttrs());
operands, op->getAttrs());
converted.removeAttr("T");
(void)UpdateFunctionTypes(rewriter, converted, tensor_list_args);
@ -972,7 +972,7 @@ struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
// Create a new while op with new operands and updated result types.
auto converted = rewriter.create<TF::WhileRegionOp>(
op.getLoc(), result_types, operands, op.getAttrs());
op.getLoc(), result_types, operands, op->getAttrs());
// Inline the regions from the old while into the new one, and apply
// signature conversion to inlined region.

View File

@ -210,6 +210,50 @@ bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
return true;
}
// Returns true if we can eliminate the SliceOp. When the values of `begin` are
// all 0s and `size[i]` is equal to either -1 or `input.shape[i]`
// for each dim i, the output tensor is identical to `input`.
bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) {
// Checks if `begin` and `size` are i32 or i64.
auto begin_attr = begin.dyn_cast<DenseIntElementsAttr>();
auto size_attr = size.dyn_cast<DenseIntElementsAttr>();
if (!begin_attr || !size_attr) {
return false;
}
auto begin_elem_ty = begin_attr.getType().getElementType();
if (!begin_elem_ty.isInteger(32) && !begin_elem_ty.isInteger(64)) {
return false;
}
auto size_elem_ty = size_attr.getType().getElementType();
if (!size_elem_ty.isInteger(32) && !size_elem_ty.isInteger(64)) {
return false;
}
// Checks if `input` is ranked and its rank is equal to number of elements in
// `begin` and `size`.
auto input_ty = input.getType().cast<ShapedType>();
if (!input_ty.hasRank()) {
return false;
}
int64_t rank = input_ty.getRank();
if (rank != begin_attr.getNumElements() ||
rank != size_attr.getNumElements()) {
return false;
}
// Checks if `begin` is all 0s, and `size[i]` is equal to either -1 or
// `input.shape[i]`.
for (uint64_t i = 0; i < rank; ++i) {
if (begin_attr.getValue<APInt>({i}).getSExtValue() != 0) return false;
int64_t si = size_attr.getValue<APInt>({i}).getSExtValue();
if (si != -1 && si != input_ty.getDimSize(i)) return false;
}
return true;
}
// Expand Attribute 'a' to 4D with all 1s except one dimension.
// Which dimension it is depends on whether 'is_depthwise' is true or false.
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {

View File

@ -367,6 +367,19 @@ def OperandsBroadcastToOutputType : Constraint<CPred<
def IsTailOfShape : Constraint<CPred<
"TFL::IsTailOfShape($0.getType(), $1.getType())">>;
def Flatten : NativeCodeCall<
"$0.cast<DenseElementsAttr>()"
".reshape(RankedTensorType::get({$0.getType().cast<ShapedType>().getNumElements()}, "
"$0.getType().cast<ShapedType>().getElementType()))">;
def IsLastDimEqualToNumElements : Constraint<CPred<
"$0.getType().cast<ShapedType>().getRank() >= 1 && "
"$0.getType().cast<ShapedType>().getDimSize($0.getType().cast<ShapedType>().getRank() - 1) == "
"$1.getType().cast<ShapedType>().getNumElements()">>;
def IsDefinedByFullyConnectedOp : Constraint<CPred<
"$0.getDefiningOp<TFL::FullyConnectedOp>() != nullptr">>;
// Pattern for skipping Tile if it is mainly for broadcasting and the
// Op is already supporting broadcasting.
multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
@ -446,6 +459,29 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
(IsTailOfShape $lhs, $rhs),
(IsTailOfShape $input1, $input2),
(IsTailOfShape $input2, $input1)]>;
// Move binary op before reshape:
// binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
// This is valid only when the last dimension of lhs is equal to the
// number of elements in the constant rhs.
// Therefore, after the transformation the broadcast of the binary op is
// always applied to the last dimension of $input. A small before/after
// sketch follows this foreach block.
def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
(ConstantOp:$rhs ElementsAttr:$rhs_attr), $act_fn),
(TFL_ReshapeOp (BinaryOp $input, (ConstantOp (Flatten $rhs_attr)),
$act_fn),
$shape),
[(AnyStaticShapeTensor $input),
(IsTailOfShape $rhs, $lhs),
(IsLastDimEqualToNumElements $input, $rhs),
(HasOneUse $lhs),
// Restrict operands to have at most rank 4 because TFLite binary
// kernel supports up to 4D broadcast.
(HasRankAtMost<4> $input),
(HasRankAtMost<4> $lhs),
(HasRankAtMost<4> $rhs),
(IsDefinedByFullyConnectedOp $input)]>;
}
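A small before/after sketch of the pattern above, assuming the reshape input comes from a FullyConnected op; the value names, constants, and shapes are illustrative and modeled on the optimize tests in this change. Reordering the ops this way presumably lets later fusions fold the flattened constant into the FullyConnected bias.

// Before: the 2-D constant is added after the reshape.
%fc = "tfl.fully_connected"(%arg0, %filter, %bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<20x64xf32>, tensor<40x64xf32>, none) -> tensor<20x40xf32>
%rs = "tfl.reshape"(%fc, %shape) : (tensor<20x40xf32>, tensor<4xi32>) -> tensor<1x20x4x10xf32>
%add = "tfl.add"(%rs, %cst_4x10) {fused_activation_function = "NONE"} : (tensor<1x20x4x10xf32>, tensor<4x10xf32>) -> tensor<1x20x4x10xf32>

// After: the constant is flattened to tensor<40xf32>, the add broadcasts over the
// last dimension of the FullyConnected output, and the reshape moves last.
%add = "tfl.add"(%fc, %cst_40) {fused_activation_function = "NONE"} : (tensor<20x40xf32>, tensor<40xf32>) -> tensor<20x40xf32>
%rs = "tfl.reshape"(%add, %shape) : (tensor<20x40xf32>, tensor<4xi32>) -> tensor<1x20x4x10xf32>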
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
@ -491,6 +527,28 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
(IsTailOfShape $lhs, $rhs),
(IsTailOfShape $input1, $input2),
(IsTailOfShape $input2, $input1)]>;
// Move binary op before reshape:
// binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
// This is valid only when the last dimension of lhs is equal to the
// number of elements in the constant rhs.
// Therefore, after the transformation the broadcast of the binary op is
// always applied to the last dimension of $input.
def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
(ConstantOp:$rhs ElementsAttr:$rhs_attr)),
(TFL_ReshapeOp (BinaryOp $input, (ConstantOp (Flatten $rhs_attr))),
$shape),
[(AnyStaticShapeTensor $input),
(IsTailOfShape $rhs, $lhs),
(IsLastDimEqualToNumElements $input, $rhs),
(HasOneUse $lhs),
// Restrict operands to have at most rank 4 because TFLite binary
// kernel supports up to 4D broadcast.
(HasRankAtMost<4> $input),
(HasRankAtMost<4> $lhs),
(HasRankAtMost<4> $rhs),
(IsDefinedByFullyConnectedOp $input)]>;
}
// Reorder the element-wise value operations and the element move operations,
@ -760,17 +818,11 @@ foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in {
(AxesIsLastDimension $axes, $logits)]>;
}
class AllElementsAreI64<string val> : Constraint<CPred<
"($0.isa<DenseElementsAttr>() && "
"$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isInteger(64) && "
"std::all_of($0.cast<DenseElementsAttr>().getValues<int64_t>().begin(), "
"$0.cast<DenseElementsAttr>().getValues<int64_t>().end(), "
"[](int64_t v){ return v == " #val# ";}))">>;
def CanOptimizeIdentitySliceOp : Constraint<CPred<
"TFL::CanOptimizeIdentitySliceOp($0, $1, $2)">>;
// Remove Slice ops slicing the whole input tensor, effectively no-op
// Remove Slice ops slicing the whole input tensor, effectively no-op.
def OptimizeSliceOp : Pat<
(TFL_SliceOp:$output $input, (ConstantOp $begin), $shape),
(TFL_SliceOp:$output $input, (ConstantOp $begin), (ConstantOp $size)),
(replaceWithValue $input),
[(AllElementsAreI64<"0"> $begin),
(IsTailOfShape $input, $output),
(IsTailOfShape $output, $input)]>;
[(CanOptimizeIdentitySliceOp $input, $begin, $size)]>;

View File

@ -254,7 +254,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
new_types.push_back(extra_operand.getType());
auto new_while_op = OpBuilder(while_op).create<WhileOp>(
while_op.getLoc(), new_types, operands, while_op.getAttrs());
while_op.getLoc(), new_types, operands, while_op->getAttrs());
new_while_op.cond().takeBody(while_op.cond());
new_while_op.body().takeBody(while_op.body());
while_op.replaceAllUsesWith(

View File

@ -62,8 +62,7 @@ FuncOp createLstmCompositeFunc(mlir::Builder* builder, bool ln, bool cifg) {
auto func_type = builder->getFunctionType(input_types, output_type);
auto func =
FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"),
builder->getContext()),
FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func")),
"fused_func", func_type, {});
func.addEntryBlock();

View File

@ -38,8 +38,7 @@ FuncOp createMaxUnpoolingFunc(
const SmallVector<mlir::Type, NOutput>& output_types) {
auto func_type = builder->getFunctionType(input_types, output_types);
auto func =
FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"),
builder->getContext()),
FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func")),
"fused_func", func_type, {});
func.addEntryBlock();

View File

@ -163,11 +163,11 @@ gentbl(
tf_ops_category_list = [
{
"name": "ops_a_m",
"include": "tf.[A-M].*$$",
"include": "tf.[A-M].*$",
},
{
"name": "ops_n_z",
"include": "tf.[N-Z].*$$",
"include": "tf.[N-Z].*$",
},
]
@ -177,11 +177,11 @@ tf_ops_category_list = [
compatible_with = get_compatible_with_cloud(),
tbl_outs = [
(
"-gen-op-decls -op-include-regex='" + target["include"] + "'",
"-gen-op-decls -op-include-regex=" + target["include"],
"ir/tf_" + target["name"] + ".h.inc",
),
(
"-gen-op-defs -op-include-regex='" + target["include"] + "'",
"-gen-op-defs -op-include-regex=" + target["include"],
"ir/tf_" + target["name"] + ".cc.inc",
),
],
@ -198,11 +198,11 @@ gentbl(
compatible_with = get_compatible_with_cloud(),
tbl_outs = [
(
"-gen-op-decls -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ",
"-gen-op-decls -op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]),
"ir/tf_remaining_ops.h.inc",
),
(
"-gen-op-defs -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ",
"-gen-op-defs -op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]),
"ir/tf_remaining_ops.cc.inc",
),
],
@ -964,6 +964,7 @@ cc_library(
"transforms/tpu_space_to_depth_pass.cc",
"transforms/tpu_update_embedding_enqueue_op_inputs.cc",
"transforms/tpu_variable_runtime_reformatting.cc",
"transforms/verify_suitable_for_graph_export_pass.cc",
"translate/breakup-islands.cc",
"translate/tf_executor_to_functional.cc",
"translate/tf_functional_to_executor.cc",
@ -1010,6 +1011,7 @@ cc_library(
":translate_utils",
":unroll_batch_matmul_pass",
":verification_utils",
":verify_suitable_for_graph_export",
":visitor_util",
":xla_sharding_util",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
@ -1151,6 +1153,7 @@ cc_library(
":tf_saved_model_passes",
":translate_utils",
":upgrade_graph",
":verify_suitable_for_graph_export",
"//tensorflow/cc/saved_model:bundle_v2",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/cc/saved_model:loader_lite",
@ -2106,3 +2109,14 @@ cc_library(
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "verify_suitable_for_graph_export",
srcs = ["utils/verify_suitable_for_graph_export.cc"],
hdrs = ["utils/verify_suitable_for_graph_export.h"],
deps = [
":tensorflow",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

View File

@ -226,8 +226,7 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) {
if (isa<IdentityOp>(op)) return true;
// For ops in the TensorFlow dialect, query the dialect.
if (op->getName().getDialect() ==
TF::TensorFlowDialect::getDialectNamespace())
if (isa_and_nonnull<TF::TensorFlowDialect>(op->getDialect()))
return !TensorFlowDialect::CanHaveSideEffects(op);
// Otherwise, conservatively assume that there can be side effects.

View File

@ -413,7 +413,7 @@ void Print(ReplicateOp op, OpAsmPrinter* p) {
// Skip derived `operand_segment_sizes` attribute as custom print format of
// operands holds enough information to calculate these variadic operand list
// lengths.
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/ArrayRef<StringRef>{
p->printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/ArrayRef<StringRef>{
kOperandSegmentSizesAttr});
p->printRegion(op.body(), /*printEntryBlockArgs=*/false);
}

View File

@ -213,7 +213,7 @@ LogicalResult Verify(GraphOp graph) {
void Print(GraphOp graph, OpAsmPrinter &p) {
p << graph.getOperationName();
p.printRegion(graph.getOperation()->getRegion(0));
p.printOptionalAttrDict(graph.getAttrs());
p.printOptionalAttrDict(graph->getAttrs());
}
ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
@ -321,7 +321,7 @@ void Print(IslandOp op, OpAsmPrinter &p) {
// Check if we can print the short "wraps" form: that is if the island
// contains a single operation and the result of this operation are perfectly
// forwarded to the yield.
if (op.getAttrs().empty() && op.WrapsSingleOp()) {
if (op->getAttrs().empty() && op.WrapsSingleOp()) {
Operation &wrapped_op = op.GetBody().front();
YieldOp yield_op = op.GetYield();
// The "wraps" syntax only encodes a single location.
@ -335,7 +335,7 @@ void Print(IslandOp op, OpAsmPrinter &p) {
}
}
p.printRegion(op.getOperation()->getRegion(0));
p.printOptionalAttrDict(op.getAttrs());
p.printOptionalAttrDict(op->getAttrs());
}
ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
@ -449,7 +449,7 @@ void Print(SwitchOp switch_op, OpAsmPrinter &p) {
} else {
p << switch_op.getType(0);
}
p.printOptionalAttrDict(switch_op.getAttrs());
p.printOptionalAttrDict(switch_op->getAttrs());
}
} // anonymous namespace
@ -525,7 +525,7 @@ void Print(SwitchNOp switchn, OpAsmPrinter &p) {
p << ")";
}
p << " : " << switchn.getType(0);
p.printOptionalAttrDict(switchn.getAttrs(), {"num_outs"});
p.printOptionalAttrDict(switchn->getAttrs(), {"num_outs"});
}
ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) {
@ -655,7 +655,7 @@ void Print(MergeOp merge, OpAsmPrinter &p) {
p << output_type;
}
p.printOptionalAttrDict(merge.getAttrs());
p.printOptionalAttrDict(merge->getAttrs());
}
ParseResult ParseMergeOp(OpAsmParser &parser, OperationState &result) {
@ -723,7 +723,7 @@ void Print(EnterOp enter, OpAsmPrinter &p) {
p << enter.getType(0);
}
p.printOptionalAttrDict(enter.getAttrs(),
p.printOptionalAttrDict(enter->getAttrs(),
{"frame_name", "parallel_iterations", "is_constant"});
}
@ -843,7 +843,7 @@ void Print(ExitOp exit, OpAsmPrinter &p) {
p << exit.getOperationName() << ' ';
p.printOperands(exit.getOperands());
p << " : " << exit.getType(0);
p.printOptionalAttrDict(exit.getAttrs());
p.printOptionalAttrDict(exit->getAttrs());
}
ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) {
@ -887,7 +887,7 @@ void Print(LoopCondOp loop_cond, OpAsmPrinter &p) {
p << " : " << loop_cond.input().getType();
}
p.printOptionalAttrDict(loop_cond.getAttrs());
p.printOptionalAttrDict(loop_cond->getAttrs());
}
ParseResult ParseLoopCondOp(OpAsmParser &parser, OperationState &result) {

View File

@ -3887,6 +3887,8 @@ Returns the input tensor otherwise.
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasFolder = 1;
}
def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> {
@ -8563,6 +8565,28 @@ the result here is consistent with a truncating divide. E.g.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ModelDatasetOp : TF_Op<"ModelDataset", [NoSideEffect]> {
let summary = "Identity transformation that models performance.";
let description = [{
Identity transformation that models performance.
}];
let arguments = (ins
Arg<TF_VariantTensor, [{A variant tensor representing the input dataset.}]>:$input_dataset,
DefaultValuedAttr<I64Attr, "0">:$algorithm,
DefaultValuedAttr<I64Attr, "0">:$cpu_budget,
DefaultValuedAttr<I64Attr, "0">:$ram_budget,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
let results = (outs
TF_VariantTensor:$handle
);
}
def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x * y element-wise.";
@ -9217,6 +9241,31 @@ def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAnd
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_OptimizeDatasetV2Op : TF_Op<"OptimizeDatasetV2", [NoSideEffect]> {
let summary = [{
Creates a dataset by applying related optimizations to `input_dataset`.
}];
let description = [{
Creates a dataset by applying related optimizations to `input_dataset`.
}];
let arguments = (ins
Arg<TF_VariantTensor, [{A variant tensor representing the input dataset.}]>:$input_dataset,
Arg<TF_StrTensor, [{A `tf.string` vector `tf.Tensor` identifying user enabled optimizations.}]>:$optimizations_enabled,
Arg<TF_StrTensor, [{A `tf.string` vector `tf.Tensor` identifying user disabled optimizations.}]>:$optimizations_disabled,
Arg<TF_StrTensor, [{A `tf.string` vector `tf.Tensor` identifying optimizations by default.}]>:$optimizations_default,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
DefaultValuedAttr<StrArrayAttr, "{}">:$optimization_configs
);
let results = (outs
TF_VariantTensor:$handle
);
}
def TF_OptionalGetValueOp : TF_Op<"OptionalGetValue", [NoSideEffect]> {
let summary = [{
Returns the value stored in an Optional variant or raises an error if none exists.
@ -9305,6 +9354,8 @@ This is the opposite of `unpack`.
return Verify(*this);
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
}

View File

@ -888,7 +888,7 @@ class CaseOrIfRegionEliminatePassThrough
// Create new case/if region op.
auto new_op = rewriter.create<CaseOrIfRegionOp>(
op.getLoc(), new_result_types, op.getOperand(), op.getAttrs(),
op.getLoc(), new_result_types, op.getOperand(), op->getAttrs(),
op.getNumRegions());
int next_index = 0;
@ -2092,6 +2092,19 @@ static LogicalResult Verify(EmptyTensorListOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// EnsureShapeOp
//===----------------------------------------------------------------------===//
OpFoldResult EnsureShapeOp::fold(llvm::ArrayRef<mlir::Attribute>) {
ShapedType type = input().getType().dyn_cast<ShapedType>();
if (!type || !type.hasRank()) return {};
// If the shape attribute equals the shape of the input operand's type, fold to
// the input.
if (type.getShape() == shape()) return input();
// Otherwise keep the op so the shape check can still fail at runtime.
return {};
}
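As a usage note, here is a minimal sketch mirroring the canonicalization test added in this change: the folder removes the first EnsureShape below because the attribute matches the operand's shape, and keeps the second, mismatching one so it can still fail at runtime.

%ok = "tf.EnsureShape"(%arg0) {shape = #tf.shape<10x20>} : (tensor<10x20xf32>) -> tensor<10x20xf32>   // folds to %arg0
%kept = "tf.EnsureShape"(%arg0) {shape = #tf.shape<20x10>} : (tensor<10x20xf32>) -> tensor<20x10xf32> // retained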
//===----------------------------------------------------------------------===//
// EqualOp
//===----------------------------------------------------------------------===//

View File

@ -302,6 +302,49 @@ OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
return slice_op.input();
}
// Convert Pack to Reshape when there is only one operand to be packed.
// For example,
//
// %0 = tf.Pack(%input) {axis = 0} // %input : tensor<2x3xf32>
//
// can be canonicalized to
//
// %shape = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>}
// %0 = tf.Reshape(%input, %shape)
struct ConvertPackToReshape : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PackOp pack_op,
PatternRewriter &rewriter) const override {
// Check if there is only one operand to be packed.
if (pack_op.N() != 1) {
return failure();
}
// Check if input and output are static.
auto input_ty = pack_op.getOperand(0).getType().cast<ShapedType>();
auto output_ty = pack_op.output().getType().cast<ShapedType>();
if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) {
return failure();
}
// Create constant shape for reshape.
auto type =
RankedTensorType::get(output_ty.getRank(), rewriter.getIntegerType(64));
auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape());
auto shape = rewriter.create<ConstOp>(pack_op.getLoc(), shape_attr);
rewriter.replaceOpWithNewOp<ReshapeOp>(pack_op, output_ty,
pack_op.getOperand(0), shape);
return success();
}
};
void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ConvertPackToReshape>(context);
}
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//
@ -2587,7 +2630,7 @@ class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
TF::FusedBatchNormV3Op::getOperationName(),
tf_fused_batch_norm_op.getOperands(),
new_result_types,
tf_fused_batch_norm_op.getAttrs());
tf_fused_batch_norm_op->getAttrs());
Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
rewriter.replaceOp(tf_fused_batch_norm_op,
@ -3029,9 +3072,9 @@ struct WhileRegionEliminatePassThrough
}
// Create the new while operation.
auto new_while_op =
rewriter.create<WhileRegionOp>(while_op.getLoc(), new_result_types,
new_while_operands, while_op.getAttrs());
auto new_while_op = rewriter.create<WhileRegionOp>(
while_op.getLoc(), new_result_types, new_while_operands,
while_op->getAttrs());
// Move region bodies to the new while.
rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),

View File

@ -1643,3 +1643,39 @@ func @testFoldStridedSliceShapeWithEmptySlice(%arg0: tensor<?x1x2x3xf32>) -> (te
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldEnsureShapeOp
func @testFoldEnsureShapeOp(%arg0: tensor<10x20xf32>) -> (tensor<10x20xf32>, tensor<20x10xf32>) {
%0 = "tf.EnsureShape"(%arg0) {shape = #tf.shape<10x20>} : (tensor<10x20xf32>) -> tensor<10x20xf32>
// Failing case which should not be folded.
// CHECK: %[[NF:.*]] = "tf.EnsureShape"(%arg0) {shape = #tf.shape<20x10>}
%1 = "tf.EnsureShape"(%arg0) {shape = #tf.shape<20x10>} : (tensor<10x20xf32>) -> tensor<20x10xf32>
// CHECK: return %arg0, %[[NF]]
return %0, %1: tensor<10x20xf32>, tensor<20x10xf32>
}
// CHECK-LABEL: testConvertPackToReshapeAxis0
func @testConvertPackToReshapeAxis0(%arg0: tensor<2x3xf32>) -> tensor<1x2x3xf32> {
%0 = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
return %0 : tensor<1x2x3xf32>
// CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
// CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x3xf32>, tensor<3xi64>) -> tensor<1x2x3xf32>
// CHECK: return %[[RESHAPE]] : tensor<1x2x3xf32>
}
// CHECK-LABEL: testConvertPackToReshapeAxis1
func @testConvertPackToReshapeAxis1(%arg0: tensor<2x3xf32>) -> tensor<2x1x3xf32> {
%0 = "tf.Pack"(%arg0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1x3xf32>
return %0 : tensor<2x1x3xf32>
// CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[2, 1, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
// CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x3xf32>, tensor<3xi64>) -> tensor<2x1x3xf32>
// CHECK: return %[[RESHAPE]] : tensor<2x1x3xf32>
}
// CHECK-LABEL: testDontConvertPackToReshapeDynamicShape
func @testDontConvertPackToReshapeDynamicShape(%arg0: tensor<2x?xf32>) -> tensor<1x2x?xf32> {
%0 = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x?xf32>) -> tensor<1x2x?xf32>
return %0 : tensor<1x2x?xf32>
// CHECK: %[[PACK:.*]] = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x?xf32>) -> tensor<1x2x?xf32>
// CHECK: return %[[PACK]] : tensor<1x2x?xf32>
}

View File

@ -1,10 +1,29 @@
// RUN: tf-opt %s -tf-drop-while-shape-invariant | FileCheck %s
// RUN: tf-opt %s -tf-drop-while-shape-invariant-in-device-cluster | FileCheck -check-prefix=IN-CLUSTER %s
// CHECK-LABEL: while_shape_invariant
func @while_cond(%arg0: tensor<*xf32>) -> tensor<i1> {
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body(%arg0: tensor<*xf32>) -> (tensor<*xf32>) {
%0 = "tf.SomeOp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// Test that -tf-drop-while-shape-invariant-in-device-cluster pass does not drop
// the shape_invariant attribute from While/WhileRegion ops outside the device
// cluster, while the other pass drops them.
// CHECK-LABEL: while_shape_invariant_outside_cluster
// CHECK-NOT: shape_invariant
func @while_shape_invariant(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
// IN-CLUSTER-LABEL: while_shape_invariant_outside_cluster
func @while_shape_invariant_outside_cluster(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
// IN-CLUSTER: shape_invariant
%0 = "tf.While"(%arg0) {cond = @while_cond, body = @while_body, is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
// IN-CLUSTER: shape_invariant
%1 = "tf.WhileRegion"(%arg0) ( {
^cond(%carg0: tensor<*xf32>):
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@ -18,12 +37,28 @@ func @while_shape_invariant(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf3
return %0, %1 : tensor<*xf32>, tensor<*xf32>
}
func @while_cond(%arg0: tensor<*xf32>) -> tensor<i1> {
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
return %0 : tensor<i1>
}
// Test that both passes drop the shape_invariant attribute from
// While/WhileRegion ops within a cluster.
func @while_body(%arg0: tensor<*xf32>) -> (tensor<*xf32>) {
%0 = "tf.SomeOp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: while_shape_invariant_within_cluster
// CHECK-NOT: shape_invariant
// IN-CLUSTER-LABEL: while_shape_invariant_within_cluster
// IN-CLUSTER-NOT: shape_invariant
func @while_shape_invariant_within_cluster(%arg0: tensor<4xf32>) {
"tf_device.cluster"() ( {
%0 = "tf.While"(%arg0) {cond = @while_cond, body = @while_body, is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
%1 = "tf.WhileRegion"(%arg0) ( {
^cond(%carg0: tensor<*xf32>):
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
"tf.Yield"(%2) : (tensor<i1>) -> ()
}, {
^body(%barg0: tensor<*xf32>):
%2 = "tf.SomeOp"(%barg0) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}) {is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
tf_device.return
}) {} : () -> ()
return
}

View File

@ -213,11 +213,11 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) {
} : (tensor<*xf32>) -> (tensor<*xf32>)
// CHECK: [[Result0:%.*]] = "tf.WhileRegion"
// CHECK: [[ResultCast0:%.*]] = "tf.Cast"
// CHECK: [[Result1:%.*]] = call @testWhileCond([[ResultCast0]])
// CHECK: ^bb0(%[[CARG0:.*]]: tensor<4xf32>
// CHECK: [[Result1:%.*]] = call @testWhileCond(%[[CARG0]])
// CHECK: "tf.Yield"([[Result1]])
// CHECK: [[ResultCast1:%.*]] = "tf.Cast"
// CHECK: [[Result2:%.*]] = call @testWhileBody([[ResultCast1]])
// CHECK: ^bb0(%[[BARG0:.*]]: tensor<4xf32>
// CHECK: [[Result2:%.*]] = call @testWhileBody(%[[BARG0]])
// CHECK: "tf.Yield"([[Result2]])
// CHECK: return [[Result0]]
return %1 : tensor<*xf32>

View File

@ -0,0 +1,34 @@
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
licenses(["notice"])
glob_lit_tests(
data = [
":debug_info_files",
":test_utilities",
],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["pbtxt"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-mlir-translate",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)
# Bundle together all the debug info files that are used by the tests.
filegroup(
name = "debug_info_files",
srcs = glob(
[
"**/*.debug",
],
),
)

View File

@ -0,0 +1,137 @@
# RUN: tf-mlir-translate --savedmodel-signaturedefs-to-mlir-lite -tf-savedmodel-tags=serve,tpu %p | FileCheck %s
# Test importing a saved model with 2 signatures that use the same
# BatchFunction Op, which references the same inference_func from the graph_def
# library. The result should be that both signatures use the same
# BatchFunction Op (the shared_name is the same) and the same copy of
# inference_func.
# CHECK: func @predict0
# CHECK: f = @inference_func[[post_fix:[^,]*]]
# CHECK-SAME: shared_name = "batch"
# CHECK: func @predict1
# CHECK: f = @inference_func[[post_fix]],
# CHECK-SAME: shared_name = "batch"
meta_graphs: {
meta_info_def: {
tags: ["serve", "tpu"]
}
graph_def: {
node: {
name: "input0"
op: "Placeholder"
attr: {
key: "dtype"
value: {
type: DT_STRING
}
}
}
node: {
name: "input1"
op: "Placeholder"
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "batch_func"
op: "BatchFunction"
input: ["input1"]
attr: {
key: "Tcaptured"
value: {
list: {
type: []
}
}
}
attr: {
key: "Tin"
value: {
list: {
type: [DT_INT32]
}
}
}
attr: {
key: "Tout"
value: {
list: {
type: [DT_FLOAT, DT_FLOAT]
}
}
}
attr: {
key: "f"
value: {
func: {
name: "inference_func"
}
}
}
attr: {
key: "shared_name"
value: {
s: "batch"
}
}
}
library: {
function {
signature {
name: "inference_func"
input_arg {
name: "arg0"
type: DT_FLOAT
}
}
ret {
key: "retval0"
value: "arg0"
}
}
}
}
signature_def: {
key: "predict0"
value: {
inputs: {
key: "inputs"
value: {
name: "input0"
dtype: DT_STRING
}
}
outputs: {
key: "outputs"
value: {
name: "batch_func:0"
}
}
}
}
signature_def: {
key: "predict1"
value: {
inputs: {
key: "tf_example_input"
value: {
name: "input0"
dtype: DT_STRING
}
}
outputs: {
key: "outputs"
value: {
name: "batch_func:1"
}
}
}
}
}

View File

@ -0,0 +1,39 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s
node {
name: "input"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
}
node {
name: "func0"
op: "func_name"
input: "input"
}
library {
function {
signature {
name: "func_name"
input_arg {
name: "arg0"
type: DT_BOOL
}
}
ret {
key: "retval0"
value: "arg0"
}
attr: {
key: "_input_shapes"
value: {
}
}
}
}

View File

@ -0,0 +1,76 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Sub -o - | FileCheck %s
node {
name: "Add"
op: "Add"
input: "input0"
input: "input1"
# If device type or id doesn't exist, assign a default one (device:CPU:0).
device: "/job:localhost/replica:0/task:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "Mul"
op: "Mul"
input: "Add"
input: "Add"
# Empty device name should be kept untouched.
device: ""
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "Sub"
op: "Sub"
input: "Add"
input: "Mul"
# Device name is not modified if complete
device: "/job:localhost/replica:0/task:0/device:CPU:1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
node {
name: "input1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
versions {
producer: 27
}
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<*xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input0,input1"
# CHECK-SAME: outputs = "Sub"
# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
# CHECK: %[[mul:.*]], %[[mul_control:.*]] = tf_executor.island wraps "tf.Mul"(%[[add]], %[[add]]) {device = ""}
# CHECK: %[[sub:.*]], %[[sub_control:.*]] = tf_executor.island wraps "tf.Sub"(%[[add]], %[[mul]]) {device = "/job:localhost/replica:0/task:0/device:CPU:1"}

View File

@ -126,3 +126,20 @@ func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<
return %2#0 : tensor<1x112x112x64xf32>
}
// CHECK-LABEL: func @fold_into_pad_with_extra_uses
func @fold_into_pad_with_extra_uses(%arg0: tensor<1x2x4x4x3xf32>) -> (tensor<1x2x3x4x4xf32>, tensor<1x2x3x6x6xf32>) {
// CHECK: %[[PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>}
// CHECK: %[[TRANSPOSE_OP:[0-9]*]] = "tf.Transpose"(%arg0, %[[PERM]])
// CHECK: %[[PADDING:[0-9]*]] = "tf.Const"() {value = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<5x2xi32>}
// CHECK: %[[PAD_OP:[0-9]*]] = "tf.Pad"(%arg0, %[[PADDING]])
// CHECK: %[[DUP_TRANSPOSE_OP:[0-9]*]] = "tf.Transpose"(%[[PAD_OP]], %[[PERM]])
// CHECK: return %[[TRANSPOSE_OP]], %[[DUP_TRANSPOSE_OP]]
%0 = "tf.Const"() {value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
%1 = "tf.Const"() {value = dense<[[0, 0], [0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<5x2xi32>} : () -> tensor<5x2xi32>
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x2x4x4x3xf32>, tensor<5xi32>) -> tensor<1x2x3x4x4xf32>
%3 = "tf.Pad"(%2, %1) : (tensor<1x2x3x4x4xf32>, tensor<5x2xi32>) -> tensor<1x2x3x6x6xf32>
return %2, %3 : tensor<1x2x3x4x4xf32>, tensor<1x2x3x6x6xf32>
}

View File

@ -1910,6 +1910,21 @@ func @convert_gather_nd(%arg0: tensor<98x128xf32>, %arg1: tensor<4x64xi32>) -> t
return %0 : tensor<4x64x128xf32>
}
// CHECK-LABEL: func @convert_gather_transpose(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x256xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>) -> tensor<4x128xf32> {
// CHECK: %[[VAL_2:.*]] = "tf.Const"{{.*}}value = dense<[1, 0]> : tensor<2xi64>
// CHECK: %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<256x128xf32>
// CHECK: %[[VAL_4:.*]] = "tf.GatherNd"(%[[VAL_3]], %[[VAL_1]]) : {{.*}} -> tensor<4x128xf32>
// CHECK: return %[[VAL_4]]
// CHECK: }
// Test the case when start_index_map isn't an iota, which requires a transpose
// to convert it to tf.GatherNd.
func @convert_gather_transpose(%arg0: tensor<128x256xf32>, %arg1: tensor<4x1xi32>) -> tensor<4x128xf32> {
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<1> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<1> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[128, 1]> : tensor<2xi64>} : (tensor<128x256xf32>, tensor<4x1xi32>) -> tensor<4x128xf32>
return %0 : tensor<4x128xf32>
}
// CHECK-LABEL: func @convert_dynamic_slice(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<7x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<i32>,

View File

@ -0,0 +1,20 @@
// RUN: tf-mlir-translate -mlir-to-graphdef -tf-export-entry-func-to-flib %s -o - 2>&1 | FileCheck %s
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 458 : i32}} {
func @main() {
tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.Const"() {device = "TPU:0", name = "const", dtype = "tfdtype$DT_INT32", value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
tf_executor.fetch
}
return
}
}
// CHECK-NOT: node
// CHECK: library
// CHECK-NEXT: function
// CHECK-NEXT: signature
// CHECK-NEXT: name: "main"
// CHECK: node_def
// CHECK: op: "Const"

View File

@ -9,7 +9,7 @@ func @main() {
return
}
// CHECK: Functions must be of a single Graph with single op Islands: only single block functions are supported.
// CHECK: functions must be of a single Graph with single op Islands: only single block functions are supported
// -----
@ -19,7 +19,7 @@ func @main() {
return
}
// CHECK: Functions must be of a single Graph with single op Islands: first op in function is not a tf_executor.graph.
// CHECK: functions must be of a single Graph with single op Islands: first op in function is not a tf_executor.graph
// -----
@ -33,7 +33,7 @@ func @main() {
return
}
// CHECK: Functions must be of a single Graph with single op Islands: function does not only contain a single tf_executor.graph.
// CHECK: functions must be of a single Graph with single op Islands: function does not only contain a single tf_executor.graph
// -----
@ -47,7 +47,7 @@ func @main() {
return
}
// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op.
// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op
// -----
@ -63,7 +63,7 @@ func @main() {
return
}
// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op.
// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op
// -----
@ -78,7 +78,7 @@ func @main() {
return
}
// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op.
// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op
// -----
@ -93,4 +93,4 @@ func @main(%arg0: tensor<i32>, %arg1: tensor<i32>) {
return
}
// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op.
// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op

View File

@ -470,8 +470,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK: tf.TensorListSetItem{{.*}}: (tensor<!tf.variant<tensor<2x2xf32>>>, tensor<i32>, tensor<2x2xf32>) -> tensor<!tf.variant<tensor<2x2xf32>>>
%6 = "tf.TensorListSetItem"(%3, %4, %5) {device = ""} : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>, tensor<2x2xf32>)-> tensor<*x!tf.variant>
%7 = "tf.Const"() {device = "", value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%8 = "tf.StopGradient"(%6) : (tensor<*x!tf.variant>) -> tensor<*x!tf.variant>
// CHECK: tf.TensorListStack{{.*}}: (tensor<!tf.variant<tensor<2x2xf32>>>, tensor<i32>) -> tensor<?x2x2xf32>
%8 = "tf.TensorListStack"(%6, %7) {device = "", num_elements = -1 : i64} : (tensor<*x!tf.variant>, tensor<i32>) -> tensor<*xf32>
%9 = "tf.TensorListStack"(%8, %7) {device = "", num_elements = -1 : i64} : (tensor<*x!tf.variant>, tensor<i32>) -> tensor<*xf32>
tf_executor.yield
}
tf_executor.fetch

View File

@ -27,19 +27,16 @@ from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# CHECK: func {{.*}} tf_saved_model.exported_names = ["key_1"]
# CHECK: "tf.If"
# CHECK-SAME: else_branch = @[[else_1:"key_1/[a-zA-Z_0-9]+"]]
# CHECK-SAME: then_branch = @[[then_1:"key_1/[a-zA-Z_0-9]+"]]
# CHECK-SAME: else_branch = @[[else:[a-zA-Z_0-9]+]]
# CHECK-SAME: then_branch = @[[then:[a-zA-Z_0-9]+]]
# CHECK: func {{.*}} tf_saved_model.exported_names = ["key_2"]
# CHECK: "tf.If"
# CHECK-SAME: else_branch = @[[else_2:"key_2/[a-zA-Z_0-9]+"]]
# CHECK-SAME: then_branch = @[[then_2:"key_2/[a-zA-Z_0-9]+"]]
# CHECK-SAME: else_branch = @[[else]]
# CHECK-SAME: then_branch = @[[then]]
# CHECK: func private @[[else_1]](
# CHECK: func private @[[then_1]](
# CHECK: func private @[[else_2]](
# CHECK: func private @[[then_2]](
# CHECK: func private @[[else]](
# CHECK: func private @[[then]](
def Test():

View File

@ -39,21 +39,8 @@ void EnableLogging(PassManager *pm) {
} // namespace
namespace TFTPU {
namespace {
void AddGraphExportLoweringPasses(OpPassManager &pm) {
auto add_pass = [&](std::unique_ptr<Pass> pass) {
pm.addNestedPass<FuncOp>(std::move(pass));
pm.addPass(CreateBreakUpIslandsPass());
};
add_pass(CreateFunctionalToExecutorDialectConversionPass());
add_pass(TFDevice::CreateReplicateToIslandPass());
add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
pm.addNestedPass<FuncOp>(CreateTPUDevicePropagationPass());
pm.addPass(createSymbolDCEPass());
}
tensorflow::Status RunTPUBridge(
ModuleOp module, bool enable_logging,
llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
@ -68,7 +55,7 @@ tensorflow::Status RunTPUBridge(
pipeline_builder(bridge);
// Add set of passes to lower back to graph (from tf_executor).
AddGraphExportLoweringPasses(bridge);
TF::AddGraphExportLoweringPasses(bridge);
// Run the bridge on the module, in case of failure, the `diag_handler`
// converts MLIR errors emitted to the MLIRContext into a tensorflow::Status.
@ -110,13 +97,19 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
func_pm.addPass(CreateTPUHostComputationExpansionPass());
func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
}
// Run another shape inference pass because resource decomposition might have
// created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass());
// Note that the region-based control-flow produced here still contains
// function call ops which get inlined by the subsequent inliner pass.
pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
pm.addPass(mlir::createInlinerPass());
pm.addNestedPass<FuncOp>(
TF::CreateDropWhileShapeInvariantInDeviceClusterPass());
// Run another shape inference pass because resource decomposition might have
// created new partial types. Also, dropping the `shape_invariant` attribute
// from While/WhileRegion ops within the cluster allows shape inference to
// infer more precise shapes.
pm.addPass(TF::CreateTFShapeInferencePass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addPass(CreateTPUClusterCleanupAttributesPass());
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
pm.addNestedPass<FuncOp>(createCSEPass());
@ -166,6 +159,21 @@ tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging) {
namespace TF {
void AddGraphExportLoweringPasses(OpPassManager &pm) {
auto add_pass = [&](std::unique_ptr<Pass> pass) {
pm.addNestedPass<FuncOp>(std::move(pass));
pm.addPass(CreateBreakUpIslandsPass());
};
add_pass(CreateFunctionalToExecutorDialectConversionPass());
add_pass(TFDevice::CreateReplicateToIslandPass());
add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
pm.addPass(TFTPU::CreateTPUDevicePropagationPass());
pm.addPass(createSymbolDCEPass());
pm.addPass(CreateVerifySuitableForExportPass());
}
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
bool enable_logging,
bool enable_inliner) {

View File

@ -20,20 +20,32 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
namespace mlir {
namespace TFTPU {
extern void AddGraphExportLoweringPasses(OpPassManager &pm);
} // namespace TFTPU
} // namespace mlir
namespace {
// Registers an existing pipeline builder function.
// Registers a pipeline builder function for TF TPU bridge.
mlir::PassPipelineRegistration<> tpu_pipeline(
"tf-tpu-bridge",
"Run all the passes involved in transforming the graph before execution so "
"that it is suitable for targeting TPUs.",
mlir::TFTPU::CreateTPUBridgePipeline);
// Registers an existing pipeline builder function.
// Registers a pipeline builder function for TF TPU V1 bridge.
mlir::PassPipelineRegistration<> tpu_pipeline_v1(
"tf-tpu-bridge-v1",
"Run all the passes involved in transforming a TensorFlow V1 graph before "
"execution so that it is suitable for targeting TPUs.",
mlir::TFTPU::CreateTPUBridgePipelineV1);
// Registers a pipeline builder function for TF Graph export.
mlir::PassPipelineRegistration<> tpu_export(
"tf-graph-export",
"Run passes to prepare for exporting module back to TF Graph.",
mlir::TF::AddGraphExportLoweringPasses);
} // anonymous namespace

View File

@ -112,7 +112,7 @@ void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
builder->setInsertionPoint(cluster_op);
auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
cluster_op.getLoc(), outlined_func.getType().getResults(),
live_ins.getArrayRef(), cluster_op.getAttrs());
live_ins.getArrayRef(), cluster_op->getAttrs());
cluster_op.replaceAllUsesWith(cluster_func_op);
cluster_op.erase();
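
The `cluster_op->getAttrs()` form above reflects a pattern repeated throughout this commit: attribute accessors are invoked on the underlying mlir::Operation* (->getAttrs(), ->removeAttr(), ->setAttr()) rather than on the generated op wrapper. A tiny hedged sketch of that pattern, with invented helper and attribute names:

// Illustrative only: manipulate attributes through the Operation* API.
#include "mlir/IR/Operation.h"

void CopyAndTrimAttrs(mlir::Operation* from, mlir::Operation* to) {
  to->setAttrs(from->getAttrs());  // same list cluster_op->getAttrs() returns
  to->removeAttr(mlir::Identifier::get("example.attr", to->getContext()));
}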

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
@ -32,22 +33,50 @@ class DropWhileShapeInvariantPass
void runOnFunction() override;
};
// Drops the `shape_invariant` attribute from tf.While and tf.WhileRegion ops
// only inside device clusters. This allows the shape inference pass to further
// refine operand/result shapes of these ops. This is only safe to do when
// compiling to XLA.
class DropWhileShapeInvariantInDeviceClusterPass
: public PassWrapper<DropWhileShapeInvariantInDeviceClusterPass,
FunctionPass> {
void runOnFunction() override;
};
void DropWhileShapeInvariantAttr(Operation* op) {
if (llvm::isa<WhileOp, WhileRegionOp>(op))
op->removeAttr(kShapeInvariantAttr);
}
void DropWhileShapeInvariantPass::runOnFunction() {
getFunction().walk([](Operation* op) {
if (llvm::isa<WhileOp, WhileRegionOp>(op))
op->removeAttr(kShapeInvariantAttr);
getFunction().walk([](Operation* op) { DropWhileShapeInvariantAttr(op); });
}
void DropWhileShapeInvariantInDeviceClusterPass::runOnFunction() {
getFunction().walk([](tf_device::ClusterOp cluster) {
cluster.walk([](Operation* op) { DropWhileShapeInvariantAttr(op); });
});
}
static PassRegistration<DropWhileShapeInvariantPass> pass(
static PassRegistration<DropWhileShapeInvariantPass> drop_shape_invariant_pass(
"tf-drop-while-shape-invariant",
"Drop `shape_invariant` attrbute from While/WhileRegion ops.");
static PassRegistration<DropWhileShapeInvariantInDeviceClusterPass>
drop_shape_invariant_in_cluster_pass(
"tf-drop-while-shape-invariant-in-device-cluster",
"Drop `shape_invariant` attrbute from While/WhileRegion ops inside "
"device cluster.");
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateDropWhileShapeInvariantPass() {
return std::make_unique<DropWhileShapeInvariantPass>();
}
std::unique_ptr<OperationPass<FuncOp>>
CreateDropWhileShapeInvariantInDeviceClusterPass() {
return std::make_unique<DropWhileShapeInvariantInDeviceClusterPass>();
}
} // namespace TF
} // namespace mlir
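
For context, a hedged sketch of how the new cluster-scoped pass is intended to be scheduled, mirroring the TPU bridge change earlier in this commit (the helper name and pipeline contents are illustrative):

// Illustrative only: drop the attribute inside clusters, then re-infer shapes.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

void AddShapeInvariantCleanup(mlir::OpPassManager& pm) {
  pm.addNestedPass<mlir::FuncOp>(
      mlir::TF::CreateDropWhileShapeInvariantInDeviceClusterPass());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
}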

View File

@ -63,7 +63,7 @@ YieldOp CreateCall(Operation* op, FuncOp func, Region& caller_region,
Block* entry = builder.createBlock(&caller_region);
if (use_region_args) {
entry->addArguments(args.getType());
entry->addArguments(func.getType().getInputs());
args = entry->getArguments();
}
llvm::SmallVector<Value, 4> casted_args;

View File

@ -156,7 +156,7 @@ class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
// The fused contraction has the same attributes as the original
// contraction, with two additions: the list of ops which have been fused
// together; epsilon (only with FusedBatchNorm).
std::vector<NamedAttribute> attrs = contraction.getAttrs();
std::vector<NamedAttribute> attrs = contraction->getAttrs();
ArrayAttr fused_ops_attr = ArrayAttr::get(context, fused_ops);
attrs.push_back(
NamedAttribute(Identifier::get("fused_ops", context), fused_ops_attr));

View File

@ -96,7 +96,7 @@ struct ReluToFusedBatchNorm : public OpRewritePattern<ReluOp> {
state.addOperands(batch_norm.getOperands());
if (side_input) state.operands.push_back(side_input);
state.addTypes(batch_norm.getResultTypes());
state.addAttributes(batch_norm.getAttrs());
state.addAttributes(batch_norm->getAttrs());
Operation *op = rewriter.createOperation(state);
rewriter.replaceOp(batch_norm, op->getResults());

View File

@ -387,7 +387,7 @@ void MoveTransposeAfter(Operation* op, SmallVector<Operation*, 8>* work_list,
// Maybe add Transpose nodes for layout dependent results
// (or reuse existing transposes).
OpBuilder builder(op);
builder.setInsertionPoint(op);
builder.setInsertionPointAfter(op);
for (unsigned idx : layout_dependent_results) {
OpResult result = op->getResult(idx);

View File

@ -1198,18 +1198,22 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
return failure();
}
// Verify that start_index_map and collapsed_slice_dims are both an iota
// with the same number of elements as the last dimension of start_indices.
// Verify that start_index_map and collapsed_slice_dims contain the same
// values.
auto start_index_map = gather_op.dimension_numbers().start_index_map();
auto collapsed_slice_dims =
gather_op.dimension_numbers().collapsed_slice_dims();
if (!IsIotaAttr(start_index_map, start_indices_type.getShape().back()) ||
!IsIotaAttr(collapsed_slice_dims,
start_indices_type.getShape().back())) {
// TODO(tberghammer): Transform start_indices to support non-standard
// start_index_maps.
if (start_index_map.getNumElements() !=
collapsed_slice_dims.getNumElements()) {
return rewriter.notifyMatchFailure(
gather_op, "unsupported start index map and/or collapsed slice dims");
gather_op,
"different size for start index map and collapsed slice dims");
}
for (auto c : collapsed_slice_dims) {
if (llvm::count(start_index_map, c) == 0) {
return rewriter.notifyMatchFailure(
gather_op, "collapsed slice dim isn't present in start index map");
}
}
// Verify that slice_sizes is 1 for the indexed dimensions and the full
@ -1217,7 +1221,7 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
auto slice_sizes = gather_op.slice_sizes();
int64_t index = 0;
for (int64_t s : slice_sizes.getValues<int64_t>()) {
if (index < start_indices_type.getShape().back()) {
if (llvm::count(start_index_map, index)) {
if (s != 1) {
return rewriter.notifyMatchFailure(gather_op,
"unsupported slice sizes");
@ -1242,6 +1246,25 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
++offset;
}
// Transpose the operand to handle non-iota start index map.
llvm::SmallVector<int64_t, 4> transpose_dimensions;
llvm::SmallVector<int64_t, 4> transpose_shape;
for (auto s : start_index_map) {
transpose_dimensions.push_back(s.getZExtValue());
transpose_shape.push_back(operand_type.getShape()[s.getZExtValue()]);
}
for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) {
if (llvm::count(start_index_map, i) == 0) {
transpose_dimensions.push_back(i);
transpose_shape.push_back(operand_type.getShape()[i]);
}
}
operand_type =
RankedTensorType::get(transpose_shape, operand_type.getElementType());
operand = rewriter.create<mhlo::TransposeOp>(
gather_op.getLoc(), operand_type, operand,
rewriter.getI64TensorAttr(transpose_dimensions));
rewriter.replaceOpWithNewOp<TF::GatherNdOp>(gather_op, result_type, operand,
start_indices);
return success();
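
The transpose emitted above reorders the operand so that the dimensions named in start_index_map come first and the remaining dimensions keep their relative order, which is what lets the op lower to tf.GatherNd. A self-contained sketch of that permutation computation (plain C++, all names invented for illustration):

// Illustrative only: build the permutation used for the TransposeOp above.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> GatherOperandPermutation(
    const std::vector<int64_t>& start_index_map, int64_t operand_rank) {
  std::vector<int64_t> perm(start_index_map.begin(), start_index_map.end());
  for (int64_t i = 0; i < operand_rank; ++i)
    if (std::count(perm.begin(), perm.end(), i) == 0) perm.push_back(i);
  return perm;  // e.g. start_index_map = {2, 0}, rank 3  ->  {2, 0, 1}
}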

View File

@ -67,7 +67,6 @@ LogicalResult LiftVariablesFromSession(
ModuleOp module, Session* session,
const SmallSet<StringRef, 4>& resource_names) {
OpBuilder builder(module.getBodyRegion());
MLIRContext* context = module.getContext();
if (!session) return module.emitOpError() << "no session provided";
@ -137,7 +136,7 @@ LogicalResult LiftVariablesFromSession(
ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
builder.create<tf_saved_model::GlobalTensorOp>(
NameLoc::get(builder.getIdentifier(name.str()), context),
NameLoc::get(builder.getIdentifier(name.str())),
builder.getStringAttr(name), tensor_attr,
TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
}

View File

@ -68,7 +68,7 @@ class ResourceAnalyzer {
public:
explicit ResourceAnalyzer(ModuleOp module) {
for (auto func : module.getOps<FuncOp>()) {
(void)AnalyzeFunc(func);
(void)AnalyzeRegion(func.getRegion());
}
}
@ -82,18 +82,18 @@ class ResourceAnalyzer {
}
private:
// Analyze the specified func for resource mutating operations, namely
// Analyze the specified region for resource mutating operations, namely
// TF::AssignVariableOp, if so, set the resource associated as "potentially
// written". Do this recursively across the chain of funcs via call or control
// flow ops.
// written". Do this recursively across the chain of regions via call or
// control flow ops.
// TODO(ashwinm): Move to iterative traversal.
LogicalResult AnalyzeFunc(FuncOp func) {
LogicalResult AnalyzeRegion(Region& region) {
// Avoid infinite recursion.
if (!discovered_.insert(func).second) {
if (!discovered_.insert(&region).second) {
return success();
}
func.walk([&](Operation* op) {
region.walk([&](Operation* op) {
if (isa<TF::ReadVariableOp, ReturnOp>(op)) {
return;
}
@ -103,23 +103,40 @@ class ResourceAnalyzer {
}
if (auto call = dyn_cast<CallOpInterface>(op)) {
if (auto func = dyn_cast<FuncOp>(call.resolveCallable())) {
PropagatePotentiallyWrittenUpFromCallee(func, call.getArgOperands());
PropagatePotentiallyWrittenUpFromCallee(func.getRegion(),
call.getArgOperands());
}
return;
}
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
for (auto callee : {if_op.then_function(), if_op.else_function()}) {
PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input());
PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
if_op.input());
}
return;
}
if (auto if_op = dyn_cast<TF::IfRegionOp>(op)) {
PropagatePotentiallyWrittenUpFromCallee(if_op.then_branch(),
if_op.getODSOperands(1));
PropagatePotentiallyWrittenUpFromCallee(if_op.else_branch(),
if_op.getODSOperands(1));
return;
}
if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
for (auto callee :
{while_op.cond_function(), while_op.body_function()}) {
PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input());
PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
while_op.input());
}
return;
}
if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
PropagatePotentiallyWrittenUpFromCallee(while_op.cond(),
while_op.input());
PropagatePotentiallyWrittenUpFromCallee(while_op.body(),
while_op.input());
return;
}
// For all other ops, we assume it mutates all resources it uses, so
// this errs on the side of being conservative. We should improve
// this by using either a property or a trait that clearly
@ -144,14 +161,14 @@ class ResourceAnalyzer {
});
}
// Given a FuncOp associated with the callee and operands from the
// Given a Region associated with the callee and operands from the
// corresponding callOp, propagate the potentially written decision to the
// callOp's operands, if the corresponding func's arguments are potentially
// callOp's operands, if the corresponding region's arguments are potentially
// written resources.
void PropagatePotentiallyWrittenUpFromCallee(
FuncOp func, Operation::operand_range propagate_to) {
(void)AnalyzeFunc(func);
for (auto t : llvm::zip(func.getArguments(), propagate_to)) {
Region& region, Operation::operand_range propagate_to) {
(void)AnalyzeRegion(region);
for (auto t : llvm::zip(region.getArguments(), propagate_to)) {
if (!IsResource(std::get<0>(t))) {
continue;
}
@ -172,8 +189,8 @@ class ResourceAnalyzer {
// Value: Information we know about that Value.
// Note that these Value's are in general in different functions.
DenseMap<Value, ResourceInfo> resource_infos_;
// The set of func's we already discovered.
DenseSet<FuncOp> discovered_;
// The set of regions we already discovered.
DenseSet<Region*> discovered_;
};
bool IsImmutable(GlobalTensorOp global_tensor,
@ -231,7 +248,7 @@ void MarkGlobalTensorsImmutable(
auto global_tensor = kv.first;
const auto& global_tensor_uses = kv.second;
if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
global_tensor.removeAttr("is_mutable");
global_tensor->removeAttr("is_mutable");
}
}
}
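
The propagation step is now uniform across calls, If/While, and their region-based forms: the callee region's block arguments are zipped with the operands that feed them, and the "potentially written" bit flows from argument back to operand. A simplified, hedged sketch of that step (the map type and helper name are illustrative):

// Illustrative only: mirror of PropagatePotentiallyWrittenUpFromCallee.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"

void PropagateWrittenBit(mlir::Region& callee,
                         mlir::Operation::operand_range operands,
                         llvm::DenseMap<mlir::Value, bool>& written) {
  for (auto arg_and_operand : llvm::zip(callee.getArguments(), operands)) {
    if (written.lookup(std::get<0>(arg_and_operand)))
      written[std::get<1>(arg_and_operand)] = true;
  }
}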

View File

@ -43,6 +43,11 @@ namespace TF {
// ops.
std::unique_ptr<OperationPass<FuncOp>> CreateDropWhileShapeInvariantPass();
// Creates a pass that drops `shape_invariant` attribute from While/WhileRegion
// ops within device cluster.
std::unique_ptr<OperationPass<FuncOp>>
CreateDropWhileShapeInvariantInDeviceClusterPass();
// Transforms functional control flow operations in the TensorFlow dialect to
// MLIR Control Flow Graph (CFG) form.
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();
@ -203,6 +208,15 @@ std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass();
// will replicate the tf.Const op once for each device.
std::unique_ptr<OperationPass<ModuleOp>> CreateConstantOpDeviceAssignmentPass();
// Populates the supplied pass manager with the passes required to export
// to TensorFlow Graph.
void AddGraphExportLoweringPasses(OpPassManager& pm);
// Returns a pass that verifies whether all functions in the module consist of
// a single tf_executor.graph and each tf_executor.island in that graph wraps
// only a single op.
std::unique_ptr<OperationPass<ModuleOp>> CreateVerifySuitableForExportPass();
} // namespace TF
namespace tf_executor {

View File

@ -931,7 +931,7 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
while_op.getLoc(), body.getType().getResults(),
FilterRange<Value, OperandRange>(while_op.getOperands(),
resource_arg_uses),
while_op.getAttrs());
while_op->getAttrs());
// Prepare for AddLoadsStoresOutsideControlFlowOp().
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
arg_data_type_and_updated_output_index;
@ -1035,7 +1035,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<FuncOp> branches) {
FuncOp first_func = branches.front();
auto new_op =
builder.create<CaseOrIfOp>(op.getLoc(), first_func.getType().getResults(),
new_operands, op.getAttrs());
new_operands, op->getAttrs());
// Prepare for AddLoadsStoresOutsideControlFlowOp()
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
arg_data_type_and_updated_output_index;
@ -1179,7 +1179,7 @@ void UpdatePartitionedCallOpWithNewCallee(
FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
auto new_call = builder.create<CallOpType>(
call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
new_operands, call_op.getAttrs());
new_operands, call_op->getAttrs());
new_call->setAttr(
"f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
AddLoadsStoresOutsideControlFlowOp(

View File

@ -172,121 +172,124 @@ RankedTensorType DropFirstDimension(Type type) {
bool CanInferTensorListElementType(Value tensorlist,
Value initial_element_shape,
RankedTensorType* potential_element_type) {
DCOMMENT("CanInferTensorListElementType " << tensorlist << " with initial "
<< initial_element_shape);
// Verifies that the new element type has a static shape and matches the
// potential type passed from the caller. Updates potential_element_type if it
// is not defined yet.
auto verify_and_update_potential_element_type =
[&](RankedTensorType new_element_type) -> bool {
DCOMMENT("\t\tConsidering " << new_element_type << " with old "
<< *potential_element_type);
if (!new_element_type || !new_element_type.hasStaticShape()) return false;
if (!*potential_element_type) {
DCOMMENT("\t\tUpdating potential_element_type " << new_element_type);
*potential_element_type = new_element_type;
return true;
}
return *potential_element_type == new_element_type;
};
// TensorLists are semantically immutable. For example, TensorListSetItem
// takes a TensorList as input and produces a TensorList as output. So to
// traverse modifications to TensorList and verify that all elements written
// to it have the same shape, we need to follow use-def chain of ops that
// (conceptually) modify it i.e., ops that take an input TensorList and
// produce an output TensorList.
for (auto& use : tensorlist.getUses()) {
if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
auto element_type = push.tensor().getType().dyn_cast<RankedTensorType>();
if (!verify_and_update_potential_element_type(element_type)) return false;
if (!CanInferTensorListElementType(push.output_handle(),
initial_element_shape,
potential_element_type))
return false;
continue;
}
if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>(
use.getOwner())) {
// For scatter op we can get the element shape by dropping the first
// dimension of the input tensor.
RankedTensorType element_type =
DropFirstDimension(scatter.tensor().getType());
if (!verify_and_update_potential_element_type(element_type)) return false;
if (!CanInferTensorListElementType(scatter.output_handle(),
initial_element_shape,
potential_element_type))
return false;
continue;
}
if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
auto element_type =
set_item.item().getType().dyn_cast<RankedTensorType>();
if (!verify_and_update_potential_element_type(element_type)) return false;
if (!CanInferTensorListElementType(set_item.output_handle(),
initial_element_shape,
potential_element_type))
return false;
continue;
}
if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) {
if (!CanInferTensorListElementType(pop.output_handle(),
initial_element_shape,
potential_element_type))
return false;
continue;
}
if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) {
if (!CanInferTensorListElementType(resize.output_handle(),
initial_element_shape,
potential_element_type))
return false;
continue;
}
// WhileRegionOp can explicitly capture TensorList value to be used inside
// its regions. So we check the uses of corresponding block argument in each
// region and the use of TensorList returned using YieldOp.
if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) {
for (auto branch : while_region.getRegions()) {
if (!CanInferTensorListElementType(
branch->getArgument(use.getOperandNumber()),
initial_element_shape, potential_element_type))
return false;
}
continue;
}
if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
Operation* parent = yield->getParentOp();
if (!CanInferTensorListElementType(
parent->getResult(use.getOperandNumber()), initial_element_shape,
potential_element_type))
return false;
continue;
}
// Refining the tensor list element type might change the output of
// TensorListElementShape which is expected to be the originally assigned
// shape to TensorList init ops. So replace it with the original element
// shape value.
if (auto tl_element_shape =
dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
// If element types match, we can do a direct replacement.
if (getElementTypeOrSelf(tl_element_shape.getResult()) ==
getElementTypeOrSelf(initial_element_shape.getType())) {
tl_element_shape.replaceAllUsesWith(initial_element_shape);
} else {
OpBuilder b(use.getOwner());
auto cast_op = b.create<TF::CastOp>(
use.getOwner()->getLoc(), tl_element_shape.getResult().getType(),
initial_element_shape,
/*truncate=*/b.getBoolAttr(false));
tl_element_shape.replaceAllUsesWith(cast_op.getResult());
}
continue;
}
// Ignore ops that just consume a TensorList and do not output another
// TensorList.
if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op,
TensorListLengthOp, TensorListGetItemOp>(use.getOwner()))
continue;
std::stack<Value> worklist;
worklist.emplace(tensorlist);
// For any other unknown users of the TensorList, we are conservative and
// stop element shape inference.
return false;
while (!worklist.empty()) {
tensorlist = worklist.top();
worklist.pop();
// TensorLists are semantically immutable. For example, TensorListSetItem
// takes a TensorList as input and produces a TensorList as output. So to
// traverse modifications to TensorList and verify that all elements written
// to it have the same shape, we need to follow use-def chain of ops that
// (conceptually) modify it i.e., ops that take an input TensorList and
// produce an output TensorList.
for (auto& use : tensorlist.getUses()) {
if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
auto element_type =
push.tensor().getType().dyn_cast<RankedTensorType>();
if (!verify_and_update_potential_element_type(element_type))
return false;
worklist.emplace(push.output_handle());
continue;
}
if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>(
use.getOwner())) {
// For scatter op we can get the element shape by dropping the first
// dimension of the input tensor.
RankedTensorType element_type =
DropFirstDimension(scatter.tensor().getType());
if (!verify_and_update_potential_element_type(element_type))
return false;
worklist.emplace(scatter.output_handle());
continue;
}
if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
auto element_type =
set_item.item().getType().dyn_cast<RankedTensorType>();
DCOMMENT("\tTensorListSetItemOp " << element_type);
if (!verify_and_update_potential_element_type(element_type))
return false;
worklist.emplace(set_item.output_handle());
continue;
}
if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) {
worklist.emplace(pop.output_handle());
continue;
}
if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) {
worklist.emplace(resize.output_handle());
continue;
}
// WhileRegionOp can explicitly capture TensorList value to be used inside
// its regions. So we check the uses of corresponding block argument in
// each region and the use of TensorList returned using YieldOp.
if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) {
DCOMMENT("\tTL WhileRegion");
for (auto branch : while_region.getRegions())
worklist.emplace(branch->getArgument(use.getOperandNumber()));
continue;
}
if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
Operation* parent = yield->getParentOp();
worklist.emplace(parent->getResult(use.getOperandNumber()));
continue;
}
// TODO(jpienaar): This can be generalized.
if (isa<IdentityOp, IdentityNOp, StopGradientOp>(use.getOwner())) {
worklist.emplace(use.getOwner()->getResult(use.getOperandNumber()));
continue;
}
// Refining the tensor list element type might change the output of
// TensorListElementShape which is expected to be the originally assigned
// shape to TensorList init ops. So replace it with the original element
// shape value.
if (auto tl_element_shape =
dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
// If element types match, we can do a direct replacement.
if (getElementTypeOrSelf(tl_element_shape.getResult()) ==
getElementTypeOrSelf(initial_element_shape.getType())) {
tl_element_shape.replaceAllUsesWith(initial_element_shape);
} else {
OpBuilder b(use.getOwner());
auto cast_op = b.create<TF::CastOp>(
use.getOwner()->getLoc(), tl_element_shape.getResult().getType(),
initial_element_shape,
/*truncate=*/b.getBoolAttr(false));
tl_element_shape.replaceAllUsesWith(cast_op.getResult());
}
continue;
}
// Ignore ops that just consume a TensorList and do not output another
// TensorList.
if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op,
TensorListLengthOp, TensorListGetItemOp>(use.getOwner()))
continue;
// For any other unknown users of the TensorList, we are conservative and
// stop element shape inference.
DCOMMENT("TensorListType infer, unknown op " << *use.getOwner());
return false;
}
}
return true;
}
@ -778,8 +781,12 @@ bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
if (!element_type || !element_type.hasStaticShape()) return false;
}
if (!CanInferTensorListElementType(handle, initial_element_shape,
&element_type))
&element_type)) {
DCOMMENT("InferShapeForListInitOps " << op << " could not infer");
return false;
}
DCOMMENT("InferShapeForListInitOps " << op << " could be inferred "
<< element_type);
if (!element_type || !element_type.hasStaticShape()) return false;
auto variant_type = VariantType::get(element_type, op->getContext());
auto tensor_type = RankedTensorType::get({}, variant_type);
@ -1049,7 +1056,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
// The shape function of these ops sometimes does not propagate subtypes
// (handle shapes) for resource and variant types. We use a simple passthrough
// to make sure they are preserved in the output.
if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp>(op)) {
if (isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp, TF::ZerosLikeOp>(
op)) {
return RefineTypeForPassThroughOperands(op, op->getOperands(),
op->getResults());
}
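
The key structural change above is that following the TensorList use-def chain no longer recurses on every produced handle; handles are pushed onto an explicit std::stack and processed in a loop, so long chains cannot blow the call stack. A minimal, self-contained illustration of that rewrite (plain C++, independent of MLIR; assumes an acyclic successor graph, as the TensorList chain is):

// Illustrative only: check a predicate over everything reachable from `root`
// using an explicit worklist instead of recursion, with early exit on failure.
#include <functional>
#include <stack>
#include <vector>

bool AllReachableSatisfy(int root, const std::vector<std::vector<int>>& succ,
                         const std::function<bool(int)>& pred) {
  std::stack<int> worklist;
  worklist.push(root);
  while (!worklist.empty()) {
    int node = worklist.top();
    worklist.pop();
    if (!pred(node)) return false;  // like the element-type mismatch bail-out
    for (int next : succ[node]) worklist.push(next);
  }
  return true;
}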

View File

@ -204,7 +204,7 @@ LogicalResult HandleWhileOp(
}
auto new_while =
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
new_while_operands, while_op.getAttrs());
new_while_operands, while_op->getAttrs());
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
.isa<TF::ResourceType>()) {
@ -257,7 +257,7 @@ LogicalResult HandleIfOp(
}
auto new_if = OpBuilder(if_op).create<TF::IfOp>(
if_op.getLoc(), then_func.getType().getResults(), new_if_operands,
if_op.getAttrs());
if_op->getAttrs());
for (auto result : if_op.getResults()) {
if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
continue;
@ -306,7 +306,7 @@ LogicalResult HandlePartitionedCallOp(
OpBuilder builder(call);
auto new_call = builder.create<CallOp>(
call.getLoc(), info.decomposed_callee.getType().getResults(),
new_operands, call.getAttrs());
new_operands, call->getAttrs());
new_call->setAttr(
"f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(info.decomposed_callee).getName()));

View File

@ -625,7 +625,7 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
OpBuilder builder(while_op);
auto new_while =
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
operands, while_op.getAttrs());
operands, while_op->getAttrs());
for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
if (ta_arg_buffer_type(i)) {
while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
@ -692,7 +692,7 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
OpBuilder builder(if_op);
auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
then_branch.getType().getResults(),
operands, if_op.getAttrs());
operands, if_op->getAttrs());
auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
auto retval = f.front().getTerminator()->getOperand(ret_ind);
auto arg = retval.dyn_cast<BlockArgument>();
@ -751,7 +751,7 @@ LogicalResult HandlePartitionedCallOp(
OpBuilder builder(call);
auto new_call = builder.create<CallOp>(
call.getLoc(), info.decomposed_callee.getType().getResults(),
new_operands, call.getAttrs());
new_operands, call->getAttrs());
new_call->setAttr(
"f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(info.decomposed_callee).getName()));

View File

@ -208,7 +208,7 @@ LogicalResult HandleWhileOp(
}
auto new_while =
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
new_while_operands, while_op.getAttrs());
new_while_operands, while_op->getAttrs());
for (const auto& entry : output_buffer_to_size) {
(*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
@ -268,7 +268,7 @@ LogicalResult HandleCaseOrIfOp(
FuncOp first_branch = branches.front();
auto new_op = OpBuilder(op).create<CaseOrIfOp>(
op.getLoc(), first_branch.getType().getResults(), new_operands,
op.getAttrs());
op->getAttrs());
for (const auto& entry : output_buffer_to_size) {
(*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
@ -329,7 +329,7 @@ LogicalResult HandleWhileRegionOp(
}
auto new_while = builder.create<TF::WhileRegionOp>(
while_op.getLoc(), body_region.front().getTerminator()->getOperandTypes(),
new_while_operands, while_op.getAttrs());
new_while_operands, while_op->getAttrs());
new_while.body().takeBody(body_region);
new_while.cond().takeBody(cond_region);
for (const auto& entry : output_buffer_to_size) {
@ -369,7 +369,7 @@ LogicalResult HandleIfRegionOp(
// Recreate the op.
auto new_op = OpBuilder(if_op).create<TF::IfRegionOp>(
if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(),
if_op.getOperand(), if_op.getAttrs());
if_op.getOperand(), if_op->getAttrs());
for (const auto& entry : output_buffer_to_size) {
(*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
@ -415,7 +415,7 @@ LogicalResult HandleCaseRegionOp(
auto new_op = OpBuilder(case_op).create<TF::CaseRegionOp>(
case_op.getLoc(),
first_branch->front().getTerminator()->getOperandTypes(),
case_op.getOperand(), case_op.getAttrs(), case_op.getNumRegions());
case_op.getOperand(), case_op->getAttrs(), case_op.getNumRegions());
for (const auto& entry : output_buffer_to_size) {
(*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
@ -456,7 +456,7 @@ LogicalResult HandlePartitionedCallOp(
OpBuilder builder(call);
auto new_call = builder.create<CallOp>(
call.getLoc(), info.decomposed_callee.getType().getResults(),
new_operands, call.getAttrs());
new_operands, call->getAttrs());
new_call->setAttr(
"f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(info.decomposed_callee).getName()));
@ -818,7 +818,7 @@ LogicalResult DecomposeTensorListOpsInternal(
decomposed_partitioned_call_callees) {
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
// TODO(yuanzx): Add a pass to remove identities in device computation.
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp>(&op)) {
op.replaceAllUsesWith(op.getOperands());
op.erase();
} else if (auto list = llvm::dyn_cast<TF::EmptyTensorListOp>(&op)) {

View File

@ -761,3 +761,12 @@ func @cluster_oplist(%arg0: tensor<f32>, %arg1: tensor<i32>) -> tensor<i32> {
];
}
def VerifySuitableForExportPass : Pass<"tf-verify-for-export", "ModuleOp"> {
let summary = "Verify module is suitable for export back to TF Graph";
let description = [{
    Verifies that all functions in the module consist of a single
    tf_executor.graph and that each tf_executor.island in the graph wraps only
    a single op.
}];
let constructor = "TF::CreateVerifySuitableForExportPass()";
}

View File

@ -582,7 +582,7 @@ LogicalResult FormClustersInBlock(
cluster->setAttrs(
cluster_metadata->second.getDictionary(cluster.getContext()));
// Exclude `num_replicas` as cluster should be replicated if necessary.
cluster.removeAttr(kNumReplicasAttr);
cluster->removeAttr(kNumReplicasAttr);
}
return success();

View File

@ -225,24 +225,26 @@ void PropagateDevicesToResults(
}
struct TPUDevicePropagation
: public PassWrapper<TPUDevicePropagation, FunctionPass> {
void runOnFunction() override;
: public PassWrapper<TPUDevicePropagation, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
void TPUDevicePropagation::runOnFunction() {
FuncOp func = getFunction();
if (!IsSupportedGraph(func)) return;
void TPUDevicePropagation::runOnOperation() {
ModuleOp m = getOperation();
m.walk([&](FuncOp func) {
if (!IsSupportedGraph(func)) return;
llvm::DenseMap<Value, llvm::StringRef> value_to_device;
PropagateDevicesFromArguments(func, value_to_device);
auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
PropagateDevicesInGraph(graph, value_to_device);
PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
llvm::DenseMap<Value, llvm::StringRef> value_to_device;
PropagateDevicesFromArguments(func, value_to_device);
auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
PropagateDevicesInGraph(graph, value_to_device);
PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
});
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateTPUDevicePropagationPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDevicePropagationPass() {
return std::make_unique<TPUDevicePropagation>();
}
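
Since the pass is now anchored on ModuleOp rather than on individual functions, it is scheduled at module level, as the export lowering pipeline earlier in this commit already does with pm.addPass. A small hedged sketch (header path and helper name assumed):

// Illustrative only: schedule the module-level device propagation pass.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

void ScheduleTpuDevicePropagation(mlir::OpPassManager& pm) {
  pm.addPass(mlir::TFTPU::CreateTPUDevicePropagationPass());
}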

View File

@ -326,7 +326,7 @@ tf_device::ClusterOp UpdateClusterResults(
auto new_cluster = builder->create<tf_device::ClusterOp>(
cluster.getLoc(), new_cluster_result_types,
/*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
/*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
new_cluster.body().takeBody(cluster.body());
auto operand_not_in_cluster = [&](OpOperand& operand) {
@ -400,7 +400,7 @@ void RemoveClusterAliasedOutputs(OpBuilder* builder,
builder->setInsertionPoint(cluster);
auto new_cluster = builder->create<tf_device::ClusterOp>(
cluster.getLoc(), new_cluster_result_types,
/*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
/*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
new_cluster.body().takeBody(cluster.body());
new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);

View File

@ -94,13 +94,13 @@ LogicalResult ReorderReplicateAndPartitionedInputs(
for (const auto& operands_per_replica : operands_per_replica_per_core) {
auto replicate_op = builder.create<TF::TPUReplicatedInputOp>(
replicated_input.getLoc(), replicated_input.getType(),
operands_per_replica, replicated_input.getAttrs());
operands_per_replica, replicated_input->getAttrs());
operands_per_core.push_back(replicate_op);
}
auto pi = builder.create<TF::TPUPartitionedInputOp>(
first_partitioned_input.getLoc(), replicated_input.getType(),
operands_per_core, first_partitioned_input.getAttrs());
operands_per_core, first_partitioned_input->getAttrs());
replicated_input.replaceAllUsesWith(pi.output());
return success();
}

View File

@ -0,0 +1,43 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
namespace mlir {
namespace TF {
namespace {
class VerifySuitableForExportPass
: public VerifySuitableForExportPassBase<VerifySuitableForExportPass> {
public:
void runOnOperation() override {
if (failed(tensorflow::VerifyExportSuitable(getOperation())))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateVerifySuitableForExportPass() {
return std::make_unique<VerifySuitableForExportPass>();
}
} // namespace TF
} // namespace mlir
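
A caller can now fail fast before attempting a graph export by running this pass standalone; a minimal hedged sketch (the wrapper name is an assumption, the pass constructor is from this change):

// Illustrative only: verify a module is export-ready as a standalone step.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

mlir::LogicalResult VerifyExportReady(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::TF::CreateVerifySuitableForExportPass());
  return pm.run(module);
}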

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
@ -46,8 +47,10 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -74,9 +77,6 @@ using stream_executor::port::StatusOr;
namespace {
constexpr char kInvalidExecutorGraphMsg[] =
"Functions must be of a single Graph with single op Islands: ";
constexpr char kDeviceAttr[] = "tf.device";
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
@ -91,52 +91,6 @@ class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
}
};
// Checks functions in module are of single tf_executor.graph and each
// tf_executor.island in tf_executor.graph only has a single op.
Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) {
Status status = Status::OK();
module.walk([&](mlir::FuncOp function) {
if (!llvm::hasSingleElement(function)) {
status = errors::FailedPrecondition(
kInvalidExecutorGraphMsg,
"only single block functions are supported.");
return mlir::WalkResult::interrupt();
}
auto block = function.front().without_terminator();
auto graph = llvm::dyn_cast<mlir::tf_executor::GraphOp>(block.begin());
if (!graph) {
status = errors::FailedPrecondition(
kInvalidExecutorGraphMsg,
"first op in function is not a tf_executor.graph.");
return mlir::WalkResult::interrupt();
}
if (!hasSingleElement(block)) {
status = errors::FailedPrecondition(
kInvalidExecutorGraphMsg,
"function does not only contain a single tf_executor.graph.");
return mlir::WalkResult::interrupt();
}
for (Operation& op : graph.GetBody()) {
auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
if (!island) continue;
if (!island.WrapsSingleOp()) {
status = errors::FailedPrecondition(
kInvalidExecutorGraphMsg,
"tf_executor.island must perfectly wrap a single op.");
return mlir::WalkResult::interrupt();
}
}
return mlir::WalkResult::advance();
});
return status;
}
// Finds first inner op if `op` is a tf_executor.island. Otherwise `op` is
// returned.
Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) {
@ -159,10 +113,10 @@ class Exporter {
// Converts a given FuncOp to a FunctionDef and adds it to the function
// definition library
static Status ConvertLibFunction(const GraphExportConfig& configs,
const Dialect* tf_dialect,
mlir::FuncOp function,
FunctionDefLibrary* flib);
static Status ConvertLibFunction(
const GraphExportConfig& configs, const Dialect* tf_dialect,
mlir::FuncOp function, FunctionDefLibrary* flib,
llvm::SmallDenseSet<mlir::FuncOp>& visited_functions);
// Converts the given FuncOp to a Graph. The arguments and returns of the
// function are added to the graph with special op names kArgOp and kRetOp.
// Later on, this graph can be converted to a function definition and added to
@ -170,6 +124,7 @@ class Exporter {
static StatusOr<std::unique_ptr<Graph>> Convert(
const GraphExportConfig& configs, const Dialect* tf_dialect,
mlir::FuncOp function, FunctionDefLibrary* flib,
llvm::SmallDenseSet<mlir::FuncOp>& visited_functions,
absl::flat_hash_set<Node*>* control_ret_nodes);
private:
@ -451,6 +406,7 @@ Status Exporter::GetControlRetNodes(
StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
const GraphExportConfig& configs, const Dialect* tf_dialect,
mlir::FuncOp function, FunctionDefLibrary* flib,
llvm::SmallDenseSet<mlir::FuncOp>& visited_functions,
absl::flat_hash_set<Node*>* control_ret_nodes) {
mlir::Block& block = function.front();
@ -550,7 +506,8 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
function->getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
name);
if (func != nullptr) {
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib,
visited_functions));
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
}
return Status::OK();
@ -610,22 +567,21 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
return graph;
}
Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
const Dialect* tf_dialect,
mlir::FuncOp function,
FunctionDefLibrary* flib) {
// First look for the function in the current function library. If found,
// nothing needs to be done.
OpRegistry empty_registry;
FunctionLibraryDefinition flib_def(&empty_registry, *flib);
Status Exporter::ConvertLibFunction(
const GraphExportConfig& configs, const Dialect* tf_dialect,
mlir::FuncOp function, FunctionDefLibrary* flib,
llvm::SmallDenseSet<mlir::FuncOp>& visited_functions) {
// Return early if the function has already been exported.
bool is_new_function = visited_functions.insert(function).second;
if (!is_new_function) return Status::OK();
auto function_name = function.getName().str();
if (flib_def.Find(function_name)) return Status::OK();
// TODO(fengliuai): use a small flib_def to reduce overhead
absl::flat_hash_set<Node*> control_ret_nodes;
TF_ASSIGN_OR_RETURN(auto sub_graph,
Exporter::Convert(configs, tf_dialect, function, flib,
&control_ret_nodes));
visited_functions, &control_ret_nodes));
const auto control_ret = [&](const Node* n) -> absl::optional<string> {
return control_ret_nodes.contains(n)
? absl::make_optional<string>(n->name())
@ -652,8 +608,8 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
auto grad_func =
function->getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
attr.getValue());
TF_RETURN_IF_ERROR(
ConvertLibFunction(configs, tf_dialect, grad_func, flib));
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, grad_func, flib,
visited_functions));
GradientDef grad;
grad.set_function_name(function_name);
grad.set_gradient_func(grad_func.getName().str());
@ -709,26 +665,32 @@ Status Exporter::Convert(mlir::ModuleOp module,
mlir::Identifier::get("main", module.getContext());
absl::optional<mlir::FuncOp> entry_func;
FunctionDefLibrary flib;
llvm::SmallDenseSet<mlir::FuncOp> visited_functions;
auto tf_dialect = module.getContext()->getLoadedDialect("tf");
for (auto function : module.getOps<mlir::FuncOp>()) {
if (function.isExternal())
return errors::FailedPrecondition("External functions not supported");
if (function.getName() == entry_func_id) {
if (function.getName() == entry_func_id &&
!configs.export_entry_func_to_flib) {
entry_func.emplace(function);
} else {
TF_RETURN_IF_ERROR(
ConvertLibFunction(configs, tf_dialect, function, &flib));
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, function,
&flib, visited_functions));
}
}
if (!entry_func.has_value())
return errors::FailedPrecondition("entry function `main` must be present");
if (!configs.export_entry_func_to_flib) {
if (!entry_func.has_value())
return errors::FailedPrecondition(
"entry function `main` must be present");
// Updates the graph and the function library definition.
TF_ASSIGN_OR_RETURN(
*graph, Exporter::Convert(configs, tf_dialect, entry_func.value(),
&flib, visited_functions, control_ret_nodes));
}
// Updates the graph and the function library definition.
TF_ASSIGN_OR_RETURN(
*graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib,
control_ret_nodes));
for (auto& func_def : flib.function()) {
TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
}
@ -744,8 +706,10 @@ Status ConvertMlirToGraph(mlir::ModuleOp module,
std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def,
absl::flat_hash_set<Node*>* control_ret_nodes) {
TF_RETURN_IF_ERROR(HasSingleGraphSingleOpIslandsFunctions(module));
return Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes);
mlir::StatusScopedDiagnosticHandler sh(module.getContext());
if (failed(VerifyExportSuitable(module))) return sh.ConsumeStatus();
return sh.Combine(
Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes));
}
Status ConvertMlirToGraph(mlir::ModuleOp module,
@ -761,8 +725,16 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
mlir::ModuleOp module, const GraphExportConfig& configs) {
FunctionLibraryDefinition flib_def(OpRegistry::Global(),
FunctionDefLibrary());
auto graph = absl::make_unique<Graph>(flib_def);
std::unique_ptr<Graph> graph;
TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));
// If the entry function is exported to flib, then no graph is constructed.
// Construct one in that case.
if (configs.export_entry_func_to_flib) {
graph = std::make_unique<Graph>(OpRegistry::Global());
}
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib_def.ToProto()));
auto graphdef = absl::make_unique<GraphDef>();
graph->ToGraphDef(graphdef.get());
if (!configs.export_library) graphdef->clear_library();
@ -784,8 +756,9 @@ stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
FunctionDef* function_def) {
Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
FunctionDefLibrary flib;
TF_RETURN_IF_ERROR(
Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib));
llvm::SmallDenseSet<mlir::FuncOp> visited_functions;
TF_RETURN_IF_ERROR(Exporter::ConvertLibFunction(configs, tf_dialect, func,
&flib, visited_functions));
for (auto& func_def : flib.function()) {
if (func_def.signature().name() == func.getName()) {
*function_def = func_def;

View File

@ -113,6 +113,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/stream_executor/lib/statusor.h"
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
@ -1291,17 +1292,23 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
if (name_and_value.first == "_input_shapes") {
auto& list = name_and_value.second.list();
auto& signature = func_def->signature();
if (list.shape_size() != signature.input_arg_size()) {
// Some models have "_input_shapes" attribute, but with its value empty
if (list.shape_size() > 0 &&
list.shape_size() != signature.input_arg_size()) {
return errors::FailedPrecondition(
"Number of input arguments must be equal to the length of "
"_input_shapes attribute in function '",
StringRefToView(func_name), "'.");
}
for (int i = 0; i < list.shape_size(); i++) {
for (int i = 0; i < signature.input_arg_size(); i++) {
auto& input_arg = signature.input_arg(i);
auto& array_info = specs.inputs[input_arg.name()];
array_info.imported_dtype = input_arg.type();
array_info.shape = list.shape(i);
// Set the shape to unranked for an empty "_input_shapes" attribute.
if (list.shape_size() > 0)
array_info.shape = list.shape(i);
else
array_info.shape.set_unknown_rank(true);
}
}
}
@ -1661,7 +1668,7 @@ mlir::Location ImporterBase::GetLocation(const Node& node) {
// If there are no locations in the stack trace, fall back to just a
// NameLoc with no child.
if (locations.empty()) return mlir::NameLoc::get(name_loc_id, context_);
if (locations.empty()) return mlir::NameLoc::get(name_loc_id);
// Use the front FileLineColLoc to generate a NameLoc.
mlir::Location node_name_loc =
@ -1984,8 +1991,28 @@ Status ImporterBase::ConvertNode(const Node& node) {
}
const auto& node_def = node.def();
// A NodeDef can contain a partial TF device name. In such cases, canonicalize
// it. Note that in current TF, the placer assigns a full device name to each
// node.
DeviceNameUtils::ParsedName parsed_name;
if (!DeviceNameUtils::ParseFullName(node_def.device(), &parsed_name)) {
return errors::InvalidArgument(
"Op ", op_name, " has invalid device name: ", node_def.device());
}
// Keep the parsed name untouched if the device name is empty.
if (!node_def.device().empty()) {
if (!parsed_name.has_type) {
parsed_name.type = "CPU";
parsed_name.has_type = true;
}
if (!parsed_name.has_id) {
parsed_name.id = 0;
parsed_name.has_id = true;
}
}
result.attributes.push_back(builder_.getNamedAttr(
"device", builder_.getStringAttr(std::string(node_def.device()))));
"device", builder_.getStringAttr(
DeviceNameUtils::ParsedNameToString(parsed_name))));
// Map user function calls to LegacyCall ops and add the user function name
// as an attribute.
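
The canonicalization above can be pictured in isolation: parse the (possibly partial) device string, default the missing type and id, and print it back. A hedged sketch using the same DeviceNameUtils API (the standalone helper and the example string are assumptions):

// Illustrative only: default missing type/id in a TF device name, as the
// importer above does (error handling omitted, empty names left untouched).
#include <string>
#include "tensorflow/core/util/device_name_utils.h"

std::string CanonicalizeDevice(const std::string& device) {
  tensorflow::DeviceNameUtils::ParsedName parsed;
  if (device.empty() ||
      !tensorflow::DeviceNameUtils::ParseFullName(device, &parsed))
    return device;
  if (!parsed.has_type) { parsed.type = "CPU"; parsed.has_type = true; }
  if (!parsed.has_id) { parsed.id = 0; parsed.has_id = true; }
  // e.g. "/job:worker/task:0" -> "/job:worker/task:0/device:CPU:0" (assumed)
  return tensorflow::DeviceNameUtils::ParsedNameToString(parsed);
}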
@ -2165,7 +2192,8 @@ class GraphDefImporter : public ImporterBase {
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
llvm::StringRef func_name);
llvm::StringRef func_name,
std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);
private:
explicit GraphDefImporter(
@ -2206,11 +2234,11 @@ class GraphDefImporter : public ImporterBase {
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
const GraphImportConfig& specs, llvm::StringRef func_name) {
const GraphImportConfig& specs, llvm::StringRef func_name,
std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
LoadImporterDialects(*context);
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
NameUniquifier function_name_uniquifier(flib_def);
GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
@ -3499,12 +3527,14 @@ class SavedModelSignatureDefImporterLite {
const std::string& name,
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
const std::vector<std::string> control_outputs);
const std::vector<std::string> control_outputs,
std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);
// Moves the functions in `sub_module` to `module_` and skips the duplicate
// functions.
Status MoveConvertedFunctionsToModule(absl::string_view name,
mlir::ModuleOp sub_module);
Status MoveConvertedFunctionsToModule(
absl::string_view name, mlir::ModuleOp sub_module,
const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);
GraphImportConfig::InputArrays ParseInputArrays(
llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);
@ -3545,15 +3575,25 @@ SavedModelSignatureDefImporterLite::ConvertAssets() {
}
Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
absl::string_view name, mlir::ModuleOp sub_module) {
absl::string_view name, mlir::ModuleOp sub_module,
const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
mlir::Builder builder(sub_module.getContext());
mlir::SymbolTable sub_module_symbol_table(sub_module);
// Functions originally from the graphdef library might have a different name
// after conversion, so we build the set of converted names.
absl::flat_hash_set<std::string> original_func_mlir_names;
for (const auto& kv : tf_name_to_mlir_name)
original_func_mlir_names.insert(kv.second);
// Prefix private functions with the unique signature name, so that they
// cannot collide with private functions used in the other signatures.
for (auto func : sub_module.getOps<mlir::FuncOp>()) {
if (mlir::tf_saved_model::IsExported(func)) continue;
// Skip the original functions from the graphdef library.
if (original_func_mlir_names.count(func.sym_name().str())) continue;
std::string new_sym_name = absl::StrCat(name, "/", func.sym_name().str());
if (mlir::failed(sub_module_symbol_table.replaceAllSymbolUses(
func, new_sym_name, sub_module)))
@ -3567,7 +3607,7 @@ Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
// Copy all functions used by this signature to the final MLIR module.
for (auto func : sub_module.getOps<mlir::FuncOp>()) {
DCHECK(symbol_table_.lookup(func.sym_name()) == nullptr);
// The insert here is a NO-OP if the function already exists.
symbol_table_.insert(func.clone());
}
@ -3586,13 +3626,15 @@ Status SavedModelSignatureDefImporterLite::ConvertInitializer(
inputs.push_back({asset.tensor_name, tensor_info});
}
TF_ASSIGN_OR_RETURN(auto sub_module, ConvertGraph(target_node_name, inputs,
{}, {target_node_name}));
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
TF_ASSIGN_OR_RETURN(auto sub_module,
ConvertGraph(target_node_name, inputs, {},
{target_node_name}, tf_name_to_mlir_name));
mlir::SymbolTable sub_symbol_table(*sub_module);
auto init_func_op = sub_symbol_table.lookup<mlir::FuncOp>(target_node_name);
init_func_op.removeAttr("tf.entry_function");
init_func_op->removeAttr("tf.entry_function");
mlir::OpBuilder builder(module_->getBodyRegion());
@ -3612,7 +3654,8 @@ Status SavedModelSignatureDefImporterLite::ConvertInitializer(
"__tf_saved_model_session_initializer_", target_node_name)}));
// Move the converted functions to top level MLIR module.
return MoveConvertedFunctionsToModule(target_node_name, *sub_module);
return MoveConvertedFunctionsToModule(target_node_name, *sub_module,
tf_name_to_mlir_name);
}
StatusOr<mlir::OwningModuleRef>
@ -3620,7 +3663,8 @@ SavedModelSignatureDefImporterLite::ConvertGraph(
const std::string& name,
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
const std::vector<std::string> control_outputs) {
const std::vector<std::string> control_outputs,
std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
VLOG(1) << "Importing Signature: " << name;
GraphImportConfig specs;
@ -3634,7 +3678,7 @@ SavedModelSignatureDefImporterLite::ConvertGraph(
// Convert sub-graph to MLIR module.
return GraphDefImporter::Convert(module_->getContext(), *subgraph,
input_.debug_info(), subgraph->flib_def(),
specs, name);
specs, name, tf_name_to_mlir_name);
}
Status SavedModelSignatureDefImporterLite::ConvertSignature(
@ -3654,9 +3698,12 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature(
return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
});
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
// Convert sub-graph to MLIR module.
TF_ASSIGN_OR_RETURN(auto sub_module,
ConvertGraph(sig_def_key, inputs, outputs, {}));
TF_ASSIGN_OR_RETURN(
auto sub_module,
ConvertGraph(sig_def_key, inputs, outputs, {}, tf_name_to_mlir_name));
mlir::OpBuilder builder(sub_module->getBodyRegion());
// Find the FuncOp which corresponds to current SignatureDef.
@ -3682,7 +3729,8 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature(
}
// Move the converted functions to top level MLIR module.
return MoveConvertedFunctionsToModule(sig_def_key, *sub_module);
return MoveConvertedFunctionsToModule(sig_def_key, *sub_module,
tf_name_to_mlir_name);
}
GraphImportConfig::InputArrays
@ -3820,7 +3868,7 @@ class SavedModelSignatureDefImporter {
(*module)->setAttr("tf_saved_model.under_construction",
builder.getUnitAttr());
TF_RETURN_IF_ERROR(LiftVariables(bundle, *module));
module->removeAttr("tf_saved_model.under_construction");
(*module)->removeAttr("tf_saved_model.under_construction");
return module;
}
@ -3895,8 +3943,9 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
const_cast<FunctionLibraryDefinition*>(&flib_def),
specs.restrict_functionalization_to_tpu_nodes));
}
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
/*func_name=*/"main");
/*func_name=*/"main", tf_name_to_mlir_name);
}
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
@ -3908,9 +3957,10 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
specs.graph_as_function = true;
for (const auto* control_ret_node : fbody->control_ret_nodes)
specs.control_outputs.push_back(control_ret_node->name());
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
flib_def, specs,
fbody->fdef.signature().name());
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
return GraphDefImporter::Convert(
context, *fbody->graph, dummy_debug_info, flib_def, specs,
fbody->fdef.signature().name(), tf_name_to_mlir_name);
}
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(

View File

@ -75,6 +75,9 @@ struct GraphExportConfig {
bool export_library = true;
// Whether to export debug original node name in the GraphDef.
bool export_debug_info = true;
// Whether to export the entry function to the function library instead of
// the graph.
bool export_entry_func_to_flib = false;
};
// Parses the command line flag strings to the specification of nodes in

View File

@ -21,6 +21,7 @@ limitations under the License.
using llvm::cl::opt;
// Import options.
// NOLINTNEXTLINE
opt<std::string> input_arrays(
"tf-input-arrays", llvm::cl::desc("Input tensor names, separated by ','"),
@ -115,3 +116,11 @@ opt<bool> enable_shape_inference(
"tf-enable-shape-inference-on-import",
llvm::cl::desc("Enable shape inference on import (temporary)"),
llvm::cl::init(false));
// Export options.
// NOLINTNEXTLINE
opt<bool> export_entry_func_to_flib(
"tf-export-entry-func-to-flib",
llvm::cl::desc(
"Export entry function to function library instead of graph"),
llvm::cl::init(false));
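
The flag above is the command-line surface of the new GraphExportConfig field; programmatically the equivalent is setting export_entry_func_to_flib before calling the exporter. A hedged sketch (wrapper name invented; the config field and converter are from this change):

// Illustrative only: export the entry function into the function library.
#include <memory>
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"

stream_executor::port::StatusOr<std::unique_ptr<tensorflow::GraphDef>>
ExportEntryFuncToFlib(mlir::ModuleOp module) {
  tensorflow::GraphExportConfig configs;
  configs.export_entry_func_to_flib = true;  // default is false
  return tensorflow::ConvertMlirToGraphdef(module, configs);
}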

View File

@ -26,6 +26,7 @@ limitations under the License.
// Please see the implementation file for documentation of these options.
// Import options.
extern llvm::cl::opt<std::string> input_arrays;
extern llvm::cl::opt<std::string> input_dtypes;
extern llvm::cl::opt<std::string> input_shapes;
@ -42,4 +43,7 @@ extern llvm::cl::opt<bool> upgrade_legacy;
// TODO(jpienaar): Temporary flag, flip default and remove.
extern llvm::cl::opt<bool> enable_shape_inference;
// Export options.
extern llvm::cl::opt<bool> export_entry_func_to_flib;
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_CL_H_

View File

@ -75,6 +75,7 @@ static LogicalResult MlirToGraphdefTranslateFunction(
// TODO(fengliuai): Add exporter flags.
tensorflow::GraphExportConfig confs;
confs.export_entry_func_to_flib = export_entry_func_to_flib;
StatusOr<std::unique_ptr<tensorflow::GraphDef>> graphdef_or(
tensorflow::ConvertMlirToGraphdef(module, confs));
if (!graphdef_or.status().ok()) {

View File

@ -100,15 +100,15 @@ struct WritableFileRawStream : public llvm::raw_ostream {
struct CrashReproducerStream : public mlir::PassManager::ReproducerStream {
CrashReproducerStream(llvm::StringRef name,
std::unique_ptr<WritableFile> file)
std::unique_ptr<llvm::raw_ostream> file)
: name(name), ostream(std::move(file)) {}
llvm::StringRef description() override { return name; }
raw_ostream& os() override { return ostream; }
raw_ostream& os() override { return *ostream; }
private:
std::string name;
WritableFileRawStream ostream;
std::unique_ptr<llvm::raw_ostream> ostream;
};
} // namespace
@ -225,25 +225,32 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) {
}
}
auto* env = tensorflow::Env::Default();
auto status = env->RecursivelyCreateDir(path);
if (!status.ok()) {
LOG(WARNING) << "cannot create directory '" + path +
"': " + status.error_message();
return;
}
if (path != "-") {
auto* env = tensorflow::Env::Default();
auto status = env->RecursivelyCreateDir(path);
if (!status.ok()) {
LOG(WARNING) << "cannot create directory '" + path +
"': " + status.error_message();
return;
}
path += "/mlir_reproducer_";
path += "/mlir_reproducer_";
if (!tensorflow::Env::Default()->CreateUniqueFileName(&path, ".mlir")) {
LOG(WARNING)
<< "cannot create unique filename, won't enable MLIR crash reproducer.";
return;
if (!tensorflow::Env::Default()->CreateUniqueFileName(&path, ".mlir")) {
LOG(WARNING) << "cannot create unique filename, won't enable MLIR crash "
"reproducer.";
return;
}
}
mlir::PassManager::ReproducerStreamFactory factory =
[path](std::string& error)
-> std::unique_ptr<mlir::PassManager::ReproducerStream> {
// Use the stderr stream.
if (path == "-")
return std::make_unique<CrashReproducerStream>(
"(stderr)", std::make_unique<LogInfoRawStream>());
// Try to open the file and generate a raw_ostream.
std::unique_ptr<WritableFile> file;
Status status = tensorflow::Env::Default()->NewWritableFile(path, &file);
@ -252,7 +259,8 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) {
"': ", status.error_message());
return nullptr;
}
return std::make_unique<CrashReproducerStream>(path, std::move(file));
return std::make_unique<CrashReproducerStream>(
path, std::make_unique<WritableFileRawStream>(std::move(file)));
};
pm.enableCrashReproducerGeneration(factory, /*genLocalReproducer=*/false);
}
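Editor's note, not part of this commit: after this change, passing "-" as dir_path makes SetCrashReproducer emit the reproducer through the logging stream (labelled "(stderr)") instead of creating a file under a directory. A hedged usage sketch; the enclosing tensorflow namespace and an in-scope MLIRContext named `context` are assumptions.
#include "mlir/Pass/PassManager.h"  // from @llvm-project
mlir::PassManager pm(&context);
// "-" selects the log/stderr reproducer stream added above; a real directory
// path still yields <dir>/mlir_reproducer_<unique>.mlir as before.
SetCrashReproducer(pm, /*dir_path=*/"-");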

View File

@ -177,7 +177,7 @@ LogicalResult InferReturnTypeComponentsForTFOp(
OpResultAsShapeFn op_result_as_shape_fn,
ResultElementTypeFn result_element_type_fn,
SmallVectorImpl<ShapedTypeComponents>& inferred_return_shapes) {
assert(op->getName().getDialect() ==
assert(op->getName().getDialectNamespace() ==
TensorFlowDialect::getDialectNamespace());
auto op_name_or =

View File

@ -0,0 +1,68 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
namespace tensorflow {
namespace {
constexpr char kInvalidExecutorGraphMsg[] =
"functions must be of a single Graph with single op Islands: ";
} // namespace
mlir::LogicalResult VerifyExportSuitable(mlir::ModuleOp module) {
mlir::WalkResult result = module.walk([&](mlir::FuncOp function) {
if (!llvm::hasSingleElement(function)) {
function.emitError(kInvalidExecutorGraphMsg)
<< "only single block functions are supported";
return mlir::WalkResult::interrupt();
}
auto block = function.front().without_terminator();
auto graph = llvm::dyn_cast<mlir::tf_executor::GraphOp>(block.begin());
if (!graph) {
block.begin()->emitError(kInvalidExecutorGraphMsg)
<< "first op in function is not a tf_executor.graph";
return mlir::WalkResult::interrupt();
}
if (!hasSingleElement(block)) {
function.emitError(kInvalidExecutorGraphMsg)
<< "function does not only contain a single tf_executor.graph";
return mlir::WalkResult::interrupt();
}
for (mlir::Operation& op : graph.GetBody()) {
auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
if (!island) continue;
if (!island.WrapsSingleOp()) {
island.emitError(kInvalidExecutorGraphMsg)
<< "tf_executor.island must perfectly wrap a single op";
return mlir::WalkResult::interrupt();
}
}
return mlir::WalkResult::advance();
});
return mlir::failure(result.wasInterrupted());
}
} // namespace tensorflow
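Editor's note, not part of this commit: a short sketch of how an exporter might call the new helper before emitting a GraphDef. The errors::InvalidArgument wrapper is an assumption; VerifyExportSuitable itself already emits diagnostics on the offending ops.
// Illustrative caller-side check.
if (mlir::failed(tensorflow::VerifyExportSuitable(module)))
  return errors::InvalidArgument(
      "module is not a single tf_executor.graph with single-op islands");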

View File

@ -0,0 +1,30 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
namespace tensorflow {
// Returns whether all functions in module are of single tf_executor.graph and
// each tf_executor.island in tf_executor.graph only has a single op.
mlir::LogicalResult VerifyExportSuitable(mlir::ModuleOp module);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_

View File

@ -22,10 +22,10 @@ from __future__ import print_function
import os
import sys
from absl import app
from tensorflow.compiler.mlir.tfr.python.composite import Composite
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
FLAGS = flags.FLAGS

View File

@ -22,13 +22,13 @@ from __future__ import print_function
import os
import sys
from absl import app
from tensorflow.compiler.mlir.tfr.python import composite
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_array_ops as array_ops
from tensorflow.python.platform import app
from tensorflow.python.platform import flags

View File

@ -22,6 +22,7 @@ from __future__ import print_function
import os
import sys
from absl import app
import tensorflow as tf
@ -31,7 +32,6 @@ from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
Composite = composite.Composite

View File

@ -23,13 +23,13 @@ from __future__ import print_function
import os
import sys
from absl import app
import tensorflow as tf
from tensorflow.compiler.mlir.tfr.python import composite
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
Composite = composite.Composite

View File

@ -441,7 +441,7 @@ def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">,
// non-derived ones.
llvm::StringSet<> getDefinedAttributeNames() {
llvm::StringSet<> all_attrs;
for (auto& attr : getAttrs()) {
for (auto& attr : (*this)->getAttrs()) {
all_attrs.insert(attr.first.strref());
}
for (const auto& operand : llvm::enumerate(getType().getInputs())) {
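Editor's note, not part of this commit: this hunk follows MLIR's move of the generic attribute accessors onto Operation, so op classes now reach them through the implicit Operation*, spelled (*this)->getAttrs(). A minimal sketch of the same pattern, mirroring the loop above.
// Inside a TFR_TFRFuncOp member function.
llvm::StringSet<> names;
for (mlir::NamedAttribute attr : (*this)->getAttrs())
  names.insert(attr.first.strref());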

Some files were not shown because too many files have changed in this diff.