Merge branch 'master' of https://github.com/tensorflow/tensorflow into PR6

commit e2bd33bc62
.bazelrc (31 changed lines)
@@ -50,7 +50,6 @@
# Feature and Third party library support options:
# xla: Build TF with XLA
# tpu: Build TF with TPU support
# using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm).
# mkl: Enable full mkl support.

@@ -92,6 +91,7 @@
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds.
# release_gpu_linux_cuda_11_2: Toolchain and CUDA options for CUDA 11.2 Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.

@@ -190,23 +190,22 @@ build:mkl_aarch64 -c opt

# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
build:using_cuda --action_env TF_NEED_CUDA=1
build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda_base --@local_config_cuda//:enable_cuda
build:cuda_base --action_env TF_NEED_CUDA=1
build:cuda_base --crosstool_top=@local_config_cuda//crosstool:toolchain

# Enable the mlir generated GPU kernels only for cuda builds.
build --define=tensorflow_enable_mlir_generated_gpu_kernels=0
# This is a more specific option, so it takes precedence over the line above for cuda builds.
build:using_cuda --define=tensorflow_enable_mlir_generated_gpu_kernels=1
build:cuda_base --define=tensorflow_enable_mlir_generated_gpu_kernels=1

# This config refers to building CUDA op kernels with nvcc.
build:cuda --config=using_cuda
build:cuda --define=using_cuda_nvcc=true
build:cuda --config=cuda_base
build:cuda --@local_config_cuda//:cuda_compiler=nvcc

# This config refers to building CUDA op kernels with clang.
build:cuda_clang --config=using_cuda
build:cuda_clang --define=using_cuda_clang=true
build:cuda_clang --define=using_clang=true
build:cuda_clang --config=cuda_base
build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
build:cuda_clang --action_env TF_CUDA_CLANG=1

# dbg config, as a shorthand for '--config=opt -c dbg'

@@ -430,6 +429,9 @@ build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platf
build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"

build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --@local_config_cuda//:enable_cuda
build:rbe_linux_cuda_base --@local_config_cuda//:cuda_compiler=nvcc
# TODO(csigg): those probably don't do anything because cuda_config is remote.
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7

@@ -438,7 +440,6 @@ build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"

build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"

@@ -455,7 +456,6 @@ build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo
build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"

build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"

@@ -623,6 +623,13 @@ build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cud
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10"
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7"


build:release_gpu_linux_cuda_11_2 --config=release_gpu_linux
build:release_gpu_linux_cuda_11_2 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2"
build:release_gpu_linux_cuda_11_2 --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain
build:release_gpu_linux_cuda_11_2 --action_env=TF_CUDA_VERSION="11.2"
build:release_gpu_linux_cuda_11_2 --action_env=TF_CUDNN_VERSION="8.1"

# Address sanitizer
# CC=clang bazel build --config asan
build:asan --strip=never
RELEASE.md (10 changed lines)
@@ -119,6 +119,10 @@
      `tf.config.experimental.get_memory_usage` in favor of this new function.
    * Extended `tf.config.experimental.enable_tensor_float_32_execution` to
      control Tensor-Float-32 evaluation in RNNs.
    * Added an `experimental_payloads` field to `tf.errors.OpError` and
      its subclasses to support more detailed error reporting.
      This is inspired by Abseil Status payloads:
      https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h

*   `tf.summary`:
    * New `tf.summary.graph` allows manual write of TensorFlow graph

@@ -140,7 +144,11 @@
      `max_batch_size`. Previously, we issued a warning when the value of
      `rewriter_config_template` is not None. We issued an error when the
      value of `is_dynamic_op` is not True. We didn't use the value for
      `max_batch_size` for building TensorRT engines.
      `max_batch_size` for building TensorRT engines. Add parameters
      `use_dynamic_shape` to enable dynamic shape support. The default is to
      disable dynamic shape support. Add `dynamic_shape_profile_strategy`
      for selecting a dynamic shape profile strategy. The default profile
      strategy is `Range`.
    * Issue a warning when function get_tensorrt_rewriter_config is used.

*   TF XLA
tensorflow/BUILD (102 changed lines)
@@ -197,7 +197,19 @@ config_setting(

config_setting(
    name = "windows",
    values = {"cpu": "x64_windows"},
    # Internal builds query the target OS.
    # copybara:uncomment flag_values = {"//tools/cpp:cc_target_os": "windows"},
    # OSS builds query the CPU type.
    values = {"cpu": "x64_windows"},  # copybara:comment
    visibility = ["//visibility:public"],
)

config_setting(
    name = "msvc_cl_debug",
    values = {
        "compiler": "msvc-cl",
        "compilation_mode": "dbg",
    },
    visibility = ["//visibility:public"],
)

@@ -402,11 +414,12 @@ config_setting(

# Crosses between platforms and file system libraries not supported on those
# platforms due to limitations in nested select() statements.
config_setting(
selects.config_setting_group(
    name = "with_cuda_support_windows_override",
    define_values = {"using_cuda_nvcc": "true"},
    values = {"cpu": "x64_windows"},
    visibility = ["//visibility:public"],
    match_all = [
        ":using_cuda_nvcc",
        ":windows",
    ],
)

config_setting(

@@ -470,9 +483,46 @@ selects.config_setting_group(
    ],
)

config_setting(
# Config setting that is satisfied when TensorFlow is being built with CUDA
# support through e.g. `--config=cuda` (or `--config=cuda_clang` in OSS).
alias(
    name = "is_cuda_enabled",
    actual = if_oss(
        "@local_config_cuda//:is_cuda_enabled",
        "@local_config_cuda//cuda:using_clang",
    ),
)

# Config setting that is satisfied when CUDA device code should be compiled
# with clang. It does not imply that CUDA support has been enabled.
alias(
    name = "is_cuda_compiler_clang",
    actual = if_oss(
        "@local_config_cuda//:is_cuda_compiler_clang",
        "@local_config_cuda//cuda:TRUE",
    ),
)

# Config setting that is satisfied when CUDA device code should be compiled
# with nvcc. It does not imply that CUDA support has been enabled.
alias(
    name = "is_cuda_compiler_nvcc",
    actual = if_oss(
        "@local_config_cuda//:is_cuda_compiler_nvcc",
        "@local_config_cuda//cuda:FALSE",
    ),
)

# Config setting whether TensorFlow is built with CUDA support using clang.
alias(
    name = "using_cuda_clang",
    define_values = {"using_cuda_clang": "true"},
    actual = "@local_config_cuda//cuda:using_clang",
)

# Config setting whether TensorFlow is built with CUDA support using nvcc.
alias(
    name = "using_cuda_nvcc",
    actual = "@local_config_cuda//cuda:using_nvcc",
)

# Config setting to use in select()s to distinguish open source build from

@@ -493,12 +543,12 @@ bool_setting(
    visibility = ["//visibility:private"],
)

config_setting(
selects.config_setting_group(
    name = "using_cuda_clang_with_dynamic_build",
    define_values = {
        "using_cuda_clang": "true",
        "framework_shared_object": "true",
    },
    match_all = [
        ":using_cuda_clang",
        ":framework_shared_object",
    ],
)

selects.config_setting_group(

@@ -519,19 +569,6 @@ config_setting(
    visibility = ["//visibility:public"],
)

config_setting(
    name = "using_cuda_nvcc",
    define_values = {"using_cuda_nvcc": "true"},
)

config_setting(
    name = "using_cuda_nvcc_with_dynamic_build",
    define_values = {
        "using_cuda_nvcc": "true",
        "framework_shared_object": "true",
    },
)

selects.config_setting_group(
    name = "build_oss_using_cuda_nvcc",
    match_all = [

@@ -559,15 +596,6 @@ config_setting(
    visibility = ["//visibility:public"],
)

# This flag is defined for select statements that match both
# on 'windows' and 'api_version_2'. In this case, bazel requires
# having a flag which is a superset of these two.
config_setting(
    name = "windows_and_api_version_2",
    define_values = {"tf_api_version": "2"},
    values = {"cpu": "x64_windows"},
)

# This flag enables experimental MLIR support.
config_setting(
    name = "with_mlir_support",

@@ -934,6 +962,11 @@ tf_cc_shared_object(
            "-z defs",
            "-Wl,--version-script,$(location //tensorflow:tf_version_script.lds)",
        ],
    }) + select({
        "//tensorflow:msvc_cl_debug": [
            "/DEBUG:FASTLINK",
        ],
        "//conditions:default": [],
    }),
    per_os_targets = True,
    soversion = VERSION,

@@ -952,7 +985,6 @@ tf_cc_shared_object(
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:client_session",
        "//tensorflow/cc:scope",
        "//tensorflow/cc/profiler",
        "//tensorflow/core:tensorflow",
    ],
)
@ -283,7 +283,6 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/c/experimental/gradients:not_differentiable",
|
||||
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
||||
"//tensorflow/c/experimental/ops",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -909,7 +908,6 @@ tf_cuda_cc_test(
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
@ -935,7 +933,6 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -974,7 +971,6 @@ tf_cc_test(
|
||||
":custom_device_testutil",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/cc/profiler/profiler.h"
|
||||
#include "tensorflow/core/lib/monitoring/collection_registry.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
@@ -79,6 +79,7 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status) {
      assert(0);
      break;
  }
  tf_status->status.ReplaceAllPayloads(status.GetAllPayloads());
}

Status StatusFromTF_Status(const TF_Status* tf_status) {

@@ -24,6 +24,8 @@ namespace {
TEST(StatusHelper, TestStatusHelper) {
  TF_Status* s = TF_NewStatus();
  Status cc_status(errors::InvalidArgument("some error"));
  cc_status.SetPayload("key1", "value1");
  cc_status.SetPayload("key2", "value2");
  Set_TF_Status_from_Status(s, cc_status);
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
  ASSERT_EQ(std::string("some error"), TF_Message(s));

@@ -32,6 +34,9 @@ TEST(StatusHelper, TestStatusHelper) {
  ASSERT_FALSE(another_cc_status.ok());
  ASSERT_EQ(std::string("some error"), another_cc_status.error_message());
  ASSERT_EQ(error::INVALID_ARGUMENT, another_cc_status.code());
  // Ensure the payloads are not lost during conversions
  ASSERT_EQ(cc_status.GetPayload("key1"), another_cc_status.GetPayload("key1"));
  ASSERT_EQ(cc_status.GetPayload("key2"), another_cc_status.GetPayload("key2"));
  TF_DeleteStatus(s);
}
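The two hunks above add payload propagation to `Set_TF_Status_from_Status` and assert in the test that payloads survive the round trip back through `StatusFromTF_Status`. The following is a minimal sketch of that behavior outside the test harness; it assumes the helpers declared in `tensorflow/c/tf_status_helper.h` and the payload API used above, and the function name `PayloadSurvivesRoundTrip` is illustrative only, not part of this change.

```cpp
// Sketch only: assumes the TF C-API status helpers and the Status payload
// methods shown in the hunks above; not taken verbatim from this PR.
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Returns true if a payload attached to a Status survives conversion to a
// C TF_Status and back, which is what ReplaceAllPayloads above enables.
bool PayloadSurvivesRoundTrip() {
  Status cc_status = errors::InvalidArgument("some error");
  cc_status.SetPayload("key1", "value1");

  TF_Status* c_status = TF_NewStatus();
  Set_TF_Status_from_Status(c_status, cc_status);  // copies code, message, payloads
  Status back = StatusFromTF_Status(c_status);     // payloads restored here
  TF_DeleteStatus(c_status);

  return back.GetPayload("key1") == cc_status.GetPayload("key1");
}

}  // namespace tensorflow
```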
@ -6,7 +6,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"cc_library_with_android_deps",
|
||||
"tf_cc_binary",
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
"transitive_hdrs",
|
||||
@ -748,36 +747,6 @@ tf_gen_op_wrappers_cc(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tutorials_example_trainer",
|
||||
srcs = ["tutorials/example_trainer.cc"],
|
||||
copts = tf_copts(),
|
||||
linkopts = select({
|
||||
"//tensorflow:windows": [],
|
||||
"//tensorflow:macos": [
|
||||
"-lm",
|
||||
"-lpthread",
|
||||
],
|
||||
"//tensorflow:ios": [
|
||||
"-lm",
|
||||
"-lpthread",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"-lm",
|
||||
"-lpthread",
|
||||
"-lrt",
|
||||
],
|
||||
}),
|
||||
deps = [
|
||||
":cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "queue_runner",
|
||||
srcs = ["training/queue_runner.cc"],
|
||||
@ -856,7 +825,6 @@ transitive_hdrs(
|
||||
":queue_runner",
|
||||
":remote_fused_graph_ops",
|
||||
":scope",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/cc/saved_model:loader",
|
||||
"//tensorflow/cc/saved_model:reader",
|
||||
|
@ -1,40 +0,0 @@
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "profiler_test",
|
||||
srcs = ["profiler_test.cc"],
|
||||
tags = [
|
||||
"no_gpu", # b/77649654
|
||||
"no_rocm", # stream level tracing not supported on ROCm
|
||||
],
|
||||
deps = [
|
||||
":profiler",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "profiler",
|
||||
srcs = ["profiler.cc"],
|
||||
hdrs = ["profiler.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler:protos_all_cc",
|
||||
"//tensorflow/core/profiler:tfprof_options",
|
||||
"//tensorflow/core/profiler/internal:tfprof_stats",
|
||||
],
|
||||
)
|
@ -1,57 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/cc/profiler/profiler.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfprof {
|
||||
|
||||
Profiler::Profiler(const GraphDef& graph) {
|
||||
std::unique_ptr<GraphDef> graph_ptr(new GraphDef());
|
||||
*graph_ptr = graph;
|
||||
stats_.reset(new TFStats(std::move(graph_ptr), nullptr, nullptr, nullptr));
|
||||
}
|
||||
|
||||
void Profiler::AddStep(int64 step, const RunMetadata& run_meta) {
|
||||
std::unique_ptr<RunMetadata> run_meta_ptr(new RunMetadata());
|
||||
*run_meta_ptr = run_meta;
|
||||
stats_->AddRunMeta(step, std::move(run_meta_ptr));
|
||||
}
|
||||
|
||||
GraphNodeProto Profiler::ProfileGraph(const Options& options) {
|
||||
stats_->BuildView(kCmds[1]);
|
||||
return stats_->ShowGraphNode(kCmds[1], options);
|
||||
}
|
||||
|
||||
GraphNodeProto Profiler::ProfileNameScope(const Options& options) {
|
||||
stats_->BuildView(kCmds[0]);
|
||||
return stats_->ShowGraphNode(kCmds[0], options);
|
||||
}
|
||||
|
||||
MultiGraphNodeProto Profiler::ProfileOperations(const Options& options) {
|
||||
stats_->BuildView(kCmds[3]);
|
||||
return stats_->ShowMultiGraphNode(kCmds[3], options);
|
||||
}
|
||||
|
||||
Status Profiler::SerializeToString(string* content) {
|
||||
if (!content) {
|
||||
return Status(error::Code::INVALID_ARGUMENT,
|
||||
"Cannot use null string pointer for SerializeToString.");
|
||||
}
|
||||
stats_->SerializeToString(content);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tfprof
|
||||
} // namespace tensorflow
|
@ -1,100 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_PROFILER_PROFILER_H_
|
||||
#define TENSORFLOW_CC_PROFILER_PROFILER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/profiler/internal/tfprof_stats.h"
|
||||
#include "tensorflow/core/profiler/tfprof_options.h"
|
||||
#include "tensorflow/core/profiler/tfprof_output.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfprof {
|
||||
|
||||
/// @addtogroup core
|
||||
/// @{
|
||||
|
||||
/// A `Profiler` object lets the caller profile the execution of a graph.
|
||||
///
|
||||
/// Example:
|
||||
/// // First build a graph and run tracing.
|
||||
/// Scope root = Scope::NewRootScope();
|
||||
/// auto a = Placeholder(root, DT_INT32);
|
||||
/// auto c = Add(root, a, {41});
|
||||
///
|
||||
/// ClientSession session(root);
|
||||
/// std::vector<Tensor> outputs;
|
||||
/// RunOptions run_options;
|
||||
/// run_options.set_trace_level(RunOptions::FULL_TRACE);
|
||||
/// RunMetadata run_meta;
|
||||
/// Status s = session.Run(run_options, { {a, {1}} }, {c}, &outputs,
|
||||
/// &run_meta);
|
||||
/// if (!s.ok()) { ... }
|
||||
///
|
||||
/// // Then create profiler to do profiling.
|
||||
/// GraphDef graph;
|
||||
/// root.ToGraphDef(&graph);
|
||||
/// Profiler profiler(graph);
|
||||
/// profiler.AddStep(0, run_meta);
|
||||
/// Options opts = ... // TODO(xpan): Support option building API.
|
||||
/// MultiGraphNodeProto r = profiler.ProfileOperations(opts);
|
||||
///
|
||||
class Profiler {
|
||||
public:
|
||||
/// `graph` is the model's GraphDef.
|
||||
explicit Profiler(const GraphDef& graph);
|
||||
|
||||
/// Adds tracing information `run_meta` to profiler. A `run_meta` is
|
||||
/// generated by a TensorFlow session run call. `step` is the key
|
||||
/// to the `run_meta`. When calling ProfileXXX methods, caller can specify
|
||||
/// `step` in `options` to selectively profile the corresponding `run_meta`.
|
||||
/// Multiple different `run_meta` can be keyed by the same `step` in order
|
||||
/// to group them together.
|
||||
void AddStep(int64 step, const RunMetadata& run_meta);
|
||||
|
||||
/// Profiles the model by organizing nodes in graph structure.
|
||||
/// Each node is an op and the nodes are connected by the op inputs/outputs.
|
||||
GraphNodeProto ProfileGraph(const Options& options);
|
||||
|
||||
/// Profiles the model by organizing nodes in name scope structure.
|
||||
/// Each node is an op, and nodes are organized by the ops' name
|
||||
/// scope, similar to a file system tree.
|
||||
/// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2.
|
||||
GraphNodeProto ProfileNameScope(const Options& options);
|
||||
|
||||
/// Profiles the model by organizing nodes by operation types.
|
||||
/// Each node is an operation type (e.g. Conv2D or MatMul), containing all
|
||||
/// ops belonging to that type in the model.
|
||||
MultiGraphNodeProto ProfileOperations(const Options& options);
|
||||
|
||||
/// Serialize the profile content (ProfileProto) into a binary string,
|
||||
/// User can write the string to file for offline analysis by
|
||||
/// tfprof command-line tools or graphical user interface.
|
||||
Status SerializeToString(string* content);
|
||||
|
||||
private:
|
||||
std::unique_ptr<TFStats> stats_;
|
||||
};
|
||||
/// @}
|
||||
|
||||
} // namespace tfprof
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_PROFILER_PROFILER_H_
|
@ -1,177 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/cc/profiler/profiler.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/default_device.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfprof {
|
||||
|
||||
class ProfilerTest : public ::testing::Test {
|
||||
protected:
|
||||
ProfilerTest() {}
|
||||
};
|
||||
|
||||
GraphDef CreateGraphDef() {
|
||||
Scope root = Scope::NewRootScope();
|
||||
|
||||
auto a = ops::Const<float>(root, {{3, 2}, {-1, 0}});
|
||||
|
||||
auto x = ops::Const(root.WithOpName("x"), {{1.f}, {1.f}});
|
||||
|
||||
auto y = ops::MatMul(root.WithOpName("y"), a, x);
|
||||
|
||||
auto y2 = ops::Square(root, y);
|
||||
|
||||
auto y2_sum = ops::Sum(root, y2, 0);
|
||||
|
||||
auto y_norm = ops::Sqrt(root, y2_sum);
|
||||
|
||||
auto y_div = ops::Div(root.WithOpName("y_normalized"), y, y_norm);
|
||||
|
||||
GraphDef def;
|
||||
TF_CHECK_OK(root.ToGraphDef(&def));
|
||||
|
||||
return def;
|
||||
}
|
||||
|
||||
Options Default() {
|
||||
Options opts(1000, /* max_depth */
|
||||
0, /* min_bytes */
|
||||
0, /* min_peak_bytes */
|
||||
0, /* min_residual_bytes */
|
||||
0, /* min_output_bytes */
|
||||
0, /* min_micros */
|
||||
0, /* min_accelerator_micros */
|
||||
0, /* min_cpu_micros */
|
||||
0, /* min_params */
|
||||
0, /* min_float_ops */
|
||||
0, /* min_occurrence */
|
||||
0, /* step */
|
||||
"name", /* order_by */
|
||||
{".*"}, /* account_type_regexes */
|
||||
{".*"}, /* start_name_regexes */
|
||||
{}, /* trim_name_regexes */
|
||||
{".*"}, {}, /* hide_name_regexes */
|
||||
false, /* account_displayed_op_only */
|
||||
{"micros"}, /* select */
|
||||
{"none"}, /* output_type */
|
||||
{});
|
||||
return opts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* ExtractNode(const T& pb, const string& name) {
|
||||
if (pb.name() == name) {
|
||||
return &pb;
|
||||
}
|
||||
for (const T& c : pb.children()) {
|
||||
const T* ret = ExtractNode(c, name);
|
||||
if (ret) return ret;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TEST_F(ProfilerTest, Basics) {
|
||||
SessionOptions options;
|
||||
options.config.set_allow_soft_placement(true);
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
GraphDef def = CreateGraphDef();
|
||||
if (options.target.empty()) {
|
||||
graph::SetDefaultDevice("/gpu:0", &def);
|
||||
}
|
||||
|
||||
TF_CHECK_OK(session->Create(def));
|
||||
|
||||
Tensor x(DT_FLOAT, TensorShape({2, 1}));
|
||||
auto x_flat = x.flat<float>();
|
||||
x_flat.setRandom();
|
||||
Eigen::Tensor<float, 0, Eigen::RowMajor> inv_norm =
|
||||
x_flat.square().sum().sqrt().inverse();
|
||||
x_flat = x_flat * inv_norm();
|
||||
|
||||
std::vector<Tensor> outputs;
|
||||
RunOptions run_options;
|
||||
run_options.set_trace_level(RunOptions::FULL_TRACE);
|
||||
RunMetadata run_metadata;
|
||||
outputs.clear();
|
||||
|
||||
Profiler profiler(def);
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
TF_CHECK_OK(session->Run(run_options, {{"x", x}}, {"y:0", "y_normalized:0"},
|
||||
{}, &outputs, &run_metadata));
|
||||
profiler.AddStep(i, run_metadata);
|
||||
CHECK_EQ(size_t{2}, outputs.size());
|
||||
}
|
||||
|
||||
std::vector<DeviceAttributes> resp;
|
||||
TF_CHECK_OK(session->ListDevices(&resp));
|
||||
bool has_gpu = false;
|
||||
for (const auto& dev : resp) {
|
||||
if (dev.device_type() == "GPU") {
|
||||
has_gpu = true;
|
||||
}
|
||||
}
|
||||
|
||||
GraphNodeProto ret = profiler.ProfileNameScope(Default());
|
||||
const GraphNodeProto* matmul = ExtractNode(ret, "y");
|
||||
EXPECT_TRUE(matmul);
|
||||
EXPECT_GT(matmul->exec_micros(), 0);
|
||||
if (has_gpu) {
|
||||
EXPECT_GT(matmul->accelerator_exec_micros(), 0);
|
||||
} else {
|
||||
EXPECT_EQ(matmul->accelerator_exec_micros(), 0);
|
||||
}
|
||||
const GraphNodeProto* square = ExtractNode(ret, "Square");
|
||||
EXPECT_TRUE(square);
|
||||
EXPECT_GT(square->exec_micros(), 0);
|
||||
if (has_gpu) {
|
||||
EXPECT_GT(square->accelerator_exec_micros(), 0);
|
||||
} else {
|
||||
EXPECT_EQ(square->accelerator_exec_micros(), 0);
|
||||
}
|
||||
|
||||
Options opts2 = Default();
|
||||
opts2.output_type = "timeline";
|
||||
string timeline_file = io::JoinPath(testing::TmpDir(), "timeline");
|
||||
opts2.output_options["outfile"] = timeline_file;
|
||||
GraphNodeProto ret2 = profiler.ProfileGraph(opts2);
|
||||
string s;
|
||||
TF_CHECK_OK(ReadFileToString(Env::Default(), timeline_file + "_0", &s));
|
||||
EXPECT_TRUE(s.find("Square") != s.npos);
|
||||
|
||||
MultiGraphNodeProto ret3 = profiler.ProfileOperations(Default());
|
||||
const MultiGraphNodeProto* matmul2 = ExtractNode(ret3, "MatMul");
|
||||
EXPECT_TRUE(matmul2);
|
||||
EXPECT_GT(matmul2->exec_micros(), 0);
|
||||
if (has_gpu) {
|
||||
EXPECT_GT(matmul2->accelerator_exec_micros(), 0);
|
||||
} else {
|
||||
EXPECT_EQ(matmul2->accelerator_exec_micros(), 0);
|
||||
}
|
||||
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
|
||||
} // namespace tfprof
|
||||
} // namespace tensorflow
|
@ -1,234 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdio>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/default_device.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
using tensorflow::string;
|
||||
using tensorflow::int32;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace example {
|
||||
|
||||
struct Options {
|
||||
int num_concurrent_sessions = 1; // The number of concurrent sessions
|
||||
int num_concurrent_steps = 10; // The number of concurrent steps
|
||||
int num_iterations = 100; // Each step repeats this many times
|
||||
bool use_gpu = false; // Whether to use gpu in the training
|
||||
};
|
||||
|
||||
// A = [3 2; -1 0]; x = rand(2, 1);
|
||||
// We want to compute the largest eigenvalue for A.
|
||||
// repeat x = y / y.norm(); y = A * x; end
|
||||
GraphDef CreateGraphDef() {
|
||||
// TODO(jeff,opensource): This should really be a more interesting
|
||||
// computation. Maybe turn this into an mnist model instead?
|
||||
Scope root = Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
// A = [3 2; -1 0]. Using Const<float> means the result will be a
|
||||
// float tensor even though the initializer has integers.
|
||||
auto a = Const<float>(root, {{3, 2}, {-1, 0}});
|
||||
|
||||
// x = [1.0; 1.0]
|
||||
auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}});
|
||||
|
||||
// y = A * x
|
||||
auto y = MatMul(root.WithOpName("y"), a, x);
|
||||
|
||||
// y2 = y.^2
|
||||
auto y2 = Square(root, y);
|
||||
|
||||
// y2_sum = sum(y2). Note that you can pass constants directly as
|
||||
// inputs. Sum() will automatically create a Const node to hold the
|
||||
// 0 value.
|
||||
auto y2_sum = Sum(root, y2, 0);
|
||||
|
||||
// y_norm = sqrt(y2_sum)
|
||||
auto y_norm = Sqrt(root, y2_sum);
|
||||
|
||||
// y_normalized = y ./ y_norm
|
||||
Div(root.WithOpName("y_normalized"), y, y_norm);
|
||||
|
||||
GraphDef def;
|
||||
TF_CHECK_OK(root.ToGraphDef(&def));
|
||||
|
||||
return def;
|
||||
}
|
||||
|
||||
string DebugString(const Tensor& x, const Tensor& y) {
|
||||
CHECK_EQ(x.NumElements(), 2);
|
||||
CHECK_EQ(y.NumElements(), 2);
|
||||
auto x_flat = x.flat<float>();
|
||||
auto y_flat = y.flat<float>();
|
||||
// Compute an estimate of the eigenvalue via
|
||||
// (x' A x) / (x' x) = (x' y) / (x' x)
|
||||
// and exploit the fact that x' x = 1 by assumption
|
||||
Eigen::Tensor<float, 0, Eigen::RowMajor> lambda = (x_flat * y_flat).sum();
|
||||
return strings::Printf("lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]",
|
||||
lambda(), x_flat(0), x_flat(1), y_flat(0), y_flat(1));
|
||||
}
|
||||
|
||||
void ConcurrentSteps(const Options* opts, int session_index) {
|
||||
// Creates a session.
|
||||
SessionOptions options;
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
GraphDef def = CreateGraphDef();
|
||||
if (options.target.empty()) {
|
||||
graph::SetDefaultDevice(opts->use_gpu ? "/device:GPU:0" : "/cpu:0", &def);
|
||||
}
|
||||
|
||||
TF_CHECK_OK(session->Create(def));
|
||||
|
||||
// Spawn M threads for M concurrent steps.
|
||||
const int M = opts->num_concurrent_steps;
|
||||
std::unique_ptr<thread::ThreadPool> step_threads(
|
||||
new thread::ThreadPool(Env::Default(), "trainer", M));
|
||||
|
||||
for (int step = 0; step < M; ++step) {
|
||||
step_threads->Schedule([&session, opts, session_index, step]() {
|
||||
// Randomly initialize the input.
|
||||
Tensor x(DT_FLOAT, TensorShape({2, 1}));
|
||||
auto x_flat = x.flat<float>();
|
||||
x_flat.setRandom();
|
||||
Eigen::Tensor<float, 0, Eigen::RowMajor> inv_norm =
|
||||
x_flat.square().sum().sqrt().inverse();
|
||||
x_flat = x_flat * inv_norm();
|
||||
|
||||
// Iterations.
|
||||
std::vector<Tensor> outputs;
|
||||
for (int iter = 0; iter < opts->num_iterations; ++iter) {
|
||||
outputs.clear();
|
||||
TF_CHECK_OK(
|
||||
session->Run({{"x", x}}, {"y:0", "y_normalized:0"}, {}, &outputs));
|
||||
CHECK_EQ(size_t{2}, outputs.size());
|
||||
|
||||
const Tensor& y = outputs[0];
|
||||
const Tensor& y_norm = outputs[1];
|
||||
// Print out lambda, x, and y.
|
||||
std::printf("%06d/%06d %s\n", session_index, step,
|
||||
DebugString(x, y).c_str());
|
||||
// Copies y_normalized to x.
|
||||
x = y_norm;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Delete the threadpool, thus waiting for all threads to complete.
|
||||
step_threads.reset(nullptr);
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
|
||||
void ConcurrentSessions(const Options& opts) {
|
||||
// Spawn N threads for N concurrent sessions.
|
||||
const int N = opts.num_concurrent_sessions;
|
||||
|
||||
// At the moment our Session implementation only allows
|
||||
// one concurrently computing Session on GPU.
|
||||
CHECK_EQ(1, N) << "Currently can only have one concurrent session.";
|
||||
|
||||
thread::ThreadPool session_threads(Env::Default(), "trainer", N);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
session_threads.Schedule(std::bind(&ConcurrentSteps, &opts, i));
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace example
|
||||
} // end namespace tensorflow
|
||||
|
||||
namespace {
|
||||
|
||||
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
||||
int32* dst) {
|
||||
if (absl::ConsumePrefix(&arg, flag) && absl::ConsumePrefix(&arg, "=")) {
|
||||
char extra;
|
||||
return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
||||
bool* dst) {
|
||||
if (absl::ConsumePrefix(&arg, flag)) {
|
||||
if (arg.empty()) {
|
||||
*dst = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (arg == "=true") {
|
||||
*dst = true;
|
||||
return true;
|
||||
} else if (arg == "=false") {
|
||||
*dst = false;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
tensorflow::example::Options opts;
|
||||
std::vector<char*> unknown_flags;
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
if (string(argv[i]) == "--") {
|
||||
while (i < argc) {
|
||||
unknown_flags.push_back(argv[i]);
|
||||
++i;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (ParseInt32Flag(argv[i], "--num_concurrent_sessions",
|
||||
&opts.num_concurrent_sessions) ||
|
||||
ParseInt32Flag(argv[i], "--num_concurrent_steps",
|
||||
&opts.num_concurrent_steps) ||
|
||||
ParseInt32Flag(argv[i], "--num_iterations", &opts.num_iterations) ||
|
||||
ParseBoolFlag(argv[i], "--use_gpu", &opts.use_gpu)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Unknown flag: %s\n", argv[i]);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Passthrough any unknown flags.
|
||||
int dst = 1; // Skip argv[0]
|
||||
for (char* f : unknown_flags) {
|
||||
argv[dst++] = f;
|
||||
}
|
||||
argv[dst++] = nullptr;
|
||||
argc = static_cast<int>(unknown_flags.size() + 1);
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
tensorflow::example::ConcurrentSessions(opts);
|
||||
}
|
@ -297,6 +297,7 @@ cc_library(
|
||||
# Public visibility is needed for external TF/XLA backends.
|
||||
visibility = ["//visibility:public"],
|
||||
deps = XLA_DEVICE_DEPS + [":xla_compilation_cache"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -342,6 +343,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
@ -355,7 +357,9 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/protobuf:for_core_protos_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
|
@@ -292,6 +292,37 @@ MlirCommonFlags* GetMlirCommonFlags() {
  return mlir_flags;
}

ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
    absl::optional<const ConfigProto> config_proto) {
  // TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs
  // can enable/disable the graph via tf_mlir_enable_mlir_bridge.
  auto tf_mlir_enable_mlir_bridge =
      GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
  if (tf_mlir_enable_mlir_bridge !=
      ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) {
    return tf_mlir_enable_mlir_bridge;
  }

  // If a ConfigProto was not passed in, we can assume the caller is
  // checking if TF2 graph should have the bridge enabled / disabled. In that
  // case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to
  // return UNSPECIFIED here.
  if (!config_proto.has_value()) {
    return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
  }

  // TF1 graphs that do override Session's ConfigProto and set
  // ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not
  // update tf_mlir_enable_mlir_bridge so check their values.

  // ConfigProto's enable_mlir_bridge defaults to false so only respect it
  // when it is true.
  if (config_proto.value().experimental().enable_mlir_bridge()) {
    return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
  }
  return config_proto.value().experimental().mlir_bridge_rollout();
}

void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);

@@ -18,6 +18,7 @@ limitations under the License.

#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"

@@ -156,6 +157,11 @@ GetIntroduceFloatingPointJitterPassFlags();

MlirCommonFlags* GetMlirCommonFlags();

// Returns the effective MLIR bridge rollout state based on the flags and the
// optional configuration.
ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
    absl::optional<const ConfigProto> config_proto);

// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//
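For context, a hedged sketch of how a caller might consult the new helper follows. The header path `tensorflow/compiler/jit/flags.h` and the wrapper function names are assumptions for illustration; the function signature and enum values come from the hunks above: the command-line flag wins first, and an explicitly passed `ConfigProto` is only consulted when the flag is unspecified.

```cpp
// Sketch only: call sites and header path are assumed, not part of this PR.
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {

// For a TF1 session that carries its own ConfigProto, pass the config so its
// enable_mlir_bridge / mlir_bridge_rollout fields can be taken into account.
bool ShouldRunMlirBridge(const ConfigProto& session_config) {
  auto state = GetMlirBridgeRolloutState(
      absl::optional<const ConfigProto>(session_config));
  return state == ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
}

// TF2 graphs have no per-session ConfigProto; passing absl::nullopt defers
// entirely to the tf_mlir_enable_mlir_bridge flag (UNSPECIFIED otherwise).
bool ShouldRunMlirBridgeForTf2() {
  auto state = GetMlirBridgeRolloutState(absl::nullopt);
  return state == ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
}

}  // namespace tensorflow
```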
@ -621,3 +621,6 @@ func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
|
||||
return %1 : tensor<8xi32>
|
||||
}
|
||||
```
|
||||
### `-tf-verify-for-export`: Verify module is suitable for export back to TF Graph
|
||||
Verifies whether all functions in module are of single tf_executor.graph and
|
||||
each tf_executor.island in tf_executor.graph only has a single op.
|
||||
|
@@ -708,4 +708,49 @@ def HLOClient_BroadcastSelectOp : HLOClient_Op<
  }];
}

//===----------------------------------------------------------------------===//
// Helper ops
//===----------------------------------------------------------------------===//

def HLOClient_MinimumBroadcastShapesOp :
    HLOClient_Op<"minimum_broadcast_shapes", [NoSideEffect]> {
  string summary = "Minimizes the rank of two or more shapes to be broadcasted";

  string description = [{
    Given two or more 1D tensors representing shapes, returns one 1D tensor for
    each operand, where operand `i` corresponds to output `i`.

    The returned tensors have the property that they specify a shape which is a
    reshape of the corresponding input shape, and the broadcasted output shape
    (using shape::BroadcastOp) of the returned shapes is a reshape of the
    broadcasted output shape of the input shapes. Among all possibilities with
    this property, the one is chosen which minimizes the rank of each returned
    shape.

    The general idea of this op is that it can be used for ops which have a
    broadcasting semantic to operate on shapes with a possibly smaller rank
    while preserving equivalence of the computed values. After computing the
    result of the op using reshaped operands, the result can be reshaped to the
    result that would have been originally computed.

    Here is an example with two input shapes:

    ```mlir
    chlo.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1],
        [1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3]
    ```

    The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the
    broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are
    reshapes of each other, and also each output is a reshape of the
    corresponding input.
  }];

  let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes);
  let results = (outs Variadic<1DTensorOf<[Index]>>:$results);

  let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)";

}

#endif // CHLO_OPS
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
@ -888,6 +888,11 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate",
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
|
||||
return succeeded(mlir::verifyCompatibleShapes(l, r));
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
|
||||
|
@ -337,6 +337,26 @@ static LogicalResult Verify(ConstantLikeOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MinimumBroadcastShapesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
static LogicalResult Verify(MinimumBroadcastShapesOp op) {
|
||||
// Check that the number of operands matches the number of outputs.
|
||||
unsigned result_shapes_count = op.results().size();
|
||||
unsigned operand_shapes_count = op.shapes().size();
|
||||
if (operand_shapes_count != result_shapes_count) {
|
||||
return op.emitOpError()
|
||||
<< "number of operand shapes (" << operand_shapes_count
|
||||
<< ") does not match number of result shapes ("
|
||||
<< result_shapes_count << ")";
|
||||
}
|
||||
if (operand_shapes_count < 2) {
|
||||
return op.emitOpError() << "number of operand shapes ("
|
||||
<< operand_shapes_count << ") should be >= 2";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
|
@ -51,6 +51,7 @@ struct ChloLegalizeToHloPass
|
||||
conversionTarget.addLegalDialect<
|
||||
MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
|
||||
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
|
||||
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
|
||||
|
||||
if (broadcast_only_) {
|
||||
chlo::PopulateChloBroadcastingPatterns(&getContext(),
|
||||
|
@ -424,7 +424,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
||||
buffer_args.push_back(InsertAlloc(loc, result, &rewriter));
|
||||
}
|
||||
auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
|
||||
op.getAttrs());
|
||||
op->getAttrs());
|
||||
|
||||
// Copy over the operations inside the region.
|
||||
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
|
||||
@ -607,8 +607,9 @@ struct HloLegalizeToLhlo
|
||||
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
|
||||
populateFuncOpTypeConversionPattern(patterns, &context, converter);
|
||||
populateCallOpTypeConversionPattern(patterns, &context, converter);
|
||||
populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
|
||||
patterns, &context, converter);
|
||||
populateBranchOpInterfaceTypeConversionPattern(patterns, &context,
|
||||
converter);
|
||||
populateReturnOpTypeConversionPattern(patterns, &context, converter);
|
||||
populateEliminateBufferizeMaterializationsPatterns(&context, converter,
|
||||
patterns);
|
||||
|
||||
|
@ -517,14 +517,16 @@ class HloDynamicBroadcastInDimConverter
|
||||
auto shape_type = shape.getType().cast<RankedTensorType>();
|
||||
int64_t result_rank = shape_type.getDimSize(0);
|
||||
|
||||
auto result_type = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!result_type) return failure();
|
||||
|
||||
SmallVector<Value, 2> dyn_dims;
|
||||
Location loc = op.getLoc();
|
||||
for (int i = 0; i < result_rank; ++i) {
|
||||
if (!result_type.isDynamicDim(i)) continue;
|
||||
Value index = rewriter.create<ConstantIndexOp>(loc, i);
|
||||
dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
|
||||
}
|
||||
auto result_type = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!result_type) return failure();
|
||||
|
||||
int64_t nloops = result_type.getRank();
|
||||
Value init = rewriter.create<linalg::InitTensorOp>(
|
||||
@ -1146,8 +1148,7 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
|
||||
return dyn_shape;
|
||||
}
|
||||
|
||||
template <typename InputElType, int input_bit_width, typename OutputElType,
|
||||
int output_bit_width, DotOperationType op_type, typename LinalgOp>
|
||||
template <DotOperationType op_type, typename LinalgOp>
|
||||
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
|
||||
public:
|
||||
using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
|
||||
@ -1157,28 +1158,13 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
|
||||
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
|
||||
return failure();
|
||||
}
|
||||
if (GetDotOperationType(op) != op_type) return failure();
|
||||
|
||||
mhlo::DotOp::Adaptor adaptor(args);
|
||||
|
||||
auto lhs_el_type =
|
||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
||||
auto rhs_el_type =
|
||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
||||
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
|
||||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto output_type = op.getType().cast<ShapedType>();
|
||||
auto output_el_type = output_type.getElementType();
|
||||
if (!output_el_type.isa<OutputElType>() ||
|
||||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (GetDotOperationType(op) != op_type) return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto zero_attr = rewriter.getZeroAttr(output_el_type);
|
||||
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
|
||||
SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
|
||||
@ -1205,8 +1191,6 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
|
||||
return dyn_shape;
|
||||
}
|
||||
|
||||
template <typename InputElType, int input_bit_width, typename OutputElType,
|
||||
int output_bit_width, typename LinalgOp>
|
||||
class DotGeneralOpOnTensorsConversion
|
||||
: public OpConversionPattern<mhlo::DotGeneralOp> {
|
||||
public:
|
||||
@ -1245,23 +1229,10 @@ class DotGeneralOpOnTensorsConversion
|
||||
}
|
||||
|
||||
mhlo::DotGeneralOp::Adaptor adaptor(args);
|
||||
auto lhs_el_type =
|
||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
||||
auto rhs_el_type =
|
||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
||||
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
|
||||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto output_type = op.getType().cast<ShapedType>();
|
||||
auto output_el_type = output_type.getElementType();
|
||||
if (!output_el_type.isa<OutputElType>() ||
|
||||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto output_type = op.getType().cast<ShapedType>();
|
||||
auto output_el_type = output_type.getElementType();
|
||||
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
|
||||
rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
|
||||
auto zero_attr = rewriter.getZeroAttr(output_el_type);
|
||||
@ -1269,7 +1240,7 @@ class DotGeneralOpOnTensorsConversion
|
||||
auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
|
||||
Value zero_tensor =
|
||||
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
|
||||
Operation* linalg_op = rewriter.create<LinalgOp>(
|
||||
Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
|
||||
loc, /*resultTensorTypes=*/TypeRange{op.getType()},
|
||||
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
|
||||
/*outputBuffers=*/ValueRange{zero_tensor});
|
||||
@ -1476,9 +1447,14 @@ struct NormalConvOpOnTensorsConversion
|
||||
|
||||
// The output shape is N spatial_dims F.
|
||||
SmallVector<Value, 8> dyn_sizes;
|
||||
for (int64_t i = 0, e = rank - 1; i < e; ++i) {
|
||||
if (!result_type.isDynamicDim(i)) continue;
|
||||
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, i));
|
||||
if (result_type.isDynamicDim(0)) {
|
||||
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0));
|
||||
}
|
||||
for (int64_t i = 1, e = rank - 1; i < e; ++i) {
|
||||
if (result_type.isDynamicDim(i)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected output spatial dims to be static shapes");
|
||||
}
|
||||
}
|
||||
if (result_type.isDynamicDim(rank - 1)) {
|
||||
dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
|
||||
@ -1702,49 +1678,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
ReverseConverter<mhlo::ReverseOp, false>,
SliceConverter<mhlo::SliceOp, false>,
TransposeConverter<mhlo::TransposeOp, false>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI32I32I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI32I32I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI32I32I32Op>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kMatrixMatrix,
DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
linalg::MatmulOp>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kMatrixVector,
DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
linalg::MatvecOp>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kVectorDot, linalg::DotOp>,
DotGeneralOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
linalg::BatchMatmulI8I8I32Op>,
DotGeneralOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
linalg::BatchMatmulI16I16I32Op>,
DotGeneralOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
linalg::BatchMatmulI32I32I32Op>,
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
linalg::BatchMatmulOp>,
DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
DotGeneralOpOnTensorsConversion,
NormalConvOpOnTensorsConversion,
ReduceOnTensorsConversion,
PadOpOnTensorsConversion>(context);
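The hunk above folds the per-element-type Dot patterns (linalg::MatmulI8I8I32Op, linalg::MatvecI16I16I32Op, ...) into three patterns keyed only on the dot kind and the named linalg op; the accumulator element type now comes from the op's result, as the updated linalg.matmul test checks further down show (i8 inputs, i32 out tensor). A minimal sketch of what such a consolidated pattern could look like, reusing the helpers visible in this diff (GetInitTensor, DotOperationType); the GetDotOperationType classifier and the body are illustrative assumptions, not a copy of the upstream implementation:

template <DotOperationType op_type, typename LinalgOp>
struct DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DotOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    // Only fire for the dot kind this instantiation handles (assumed helper).
    if (GetDotOperationType(op) != op_type) return failure();
    mhlo::DotOp::Adaptor adaptor(args);
    Location loc = op.getLoc();
    auto output_type = op.getType().cast<ShapedType>();
    // Zero of the *result* element type; this is what lets one pattern cover
    // the i8->i32, i16->i32, i32 and f32 cases alike.
    Value zero = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(output_type.getElementType()));
    SmallVector<Value, 2> dyn_sizes;  // left empty for the static-shape case
    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_sizes);
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
    // Same create-then-replace sequence as the BatchMatmul path above.
    Operation* linalg_op = rewriter.create<LinalgOp>(
        loc, /*resultTensorTypes=*/TypeRange{op.getType()},
        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
        /*outputBuffers=*/ValueRange{zero_tensor});
    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};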
@ -126,7 +126,7 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||
Type flatResultTy =
|
||||
RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
|
||||
Value flatResult =
|
||||
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
|
||||
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op->getAttrs());
|
||||
|
||||
// Restore original shape.
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
|
||||
@ -192,7 +192,7 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||
rhs_is_scalar ? rhs : reshaped};
|
||||
Value computed = rewriter.create<ChloOpTy>(
|
||||
loc, TypeRange{RankedTensorType::get({-1}, result_element_type)},
|
||||
new_operands, op.getAttrs());
|
||||
new_operands, op->getAttrs());
|
||||
|
||||
// Reshape the result back into an unranked tensor.
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
|
||||
@ -223,9 +223,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
}
|
||||
|
||||
static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
|
||||
Value value, int targeted_rank) {
|
||||
Value shape, int targeted_rank) {
|
||||
auto loc = op.getLoc();
|
||||
Value shape = builder.create<shape::ShapeOfOp>(loc, value);
|
||||
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
||||
@ -246,6 +245,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
|
||||
ChloOpTy op,
|
||||
ValueRange operands,
|
||||
ValueRange operand_shapes,
|
||||
int targeted_rank) {
|
||||
auto loc = op.getLoc();
|
||||
SmallVector<Value, 2> reshaped_operands;
|
||||
@ -253,10 +253,12 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
|
||||
targeted_rank, RankedTensorType::kDynamicSize);
|
||||
|
||||
for (Value operand : operands) {
|
||||
for (auto it : llvm::zip(operands, operand_shapes)) {
|
||||
Value operand, shape;
|
||||
std::tie(operand, shape) = it;
|
||||
// Handle shape broadcasting and inference.
|
||||
Value extended_operand_casted =
|
||||
createBroadcastToKnownRank(if_builder, op, operand, targeted_rank);
|
||||
createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
|
||||
|
||||
// 1. Reshape operands to the given rank (with the same number of
|
||||
// elements)
|
||||
@ -278,7 +280,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
auto result_type =
|
||||
RankedTensorType::get(dynamic_dimensions, result_element_type);
|
||||
Value result = if_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, reshaped_operands, op.getAttrs());
|
||||
loc, ArrayRef<Type>{result_type}, reshaped_operands, op->getAttrs());
|
||||
Value reshaped_result = if_builder.create<tensor::CastOp>(
|
||||
loc, UnrankedTensorType::get(result_element_type), result);
|
||||
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
||||
@ -290,13 +292,37 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
ValueRange operands) {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Find the larger rank of the operands.
|
||||
// Get the minimum broadcast shapes of the operands.
|
||||
SmallVector<Value> shapes;
|
||||
shapes.reserve(operands.size());
|
||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
Value greater_rank;
|
||||
for (Value operand : operands) {
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
|
||||
loc, extent_tensor_type, shapes, nullptr);
|
||||
SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
|
||||
auto reduced_shapes =
|
||||
rewriter
|
||||
.create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
|
||||
.results();
|
||||
SmallVector<Value> reshaped_operands;
|
||||
reshaped_operands.reserve(operands.size());
|
||||
for (auto it : llvm::zip(operands, reduced_shapes)) {
|
||||
Value operand;
|
||||
Value reduced_shape;
|
||||
std::tie(operand, reduced_shape) = it;
|
||||
auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, operand.getType(), operand, reduced_shape);
|
||||
reshaped_operands.push_back(reshaped_operand);
|
||||
}
|
||||
|
||||
// Find the largest rank of the operands.
|
||||
Value greater_rank;
|
||||
for (Value shape : reduced_shapes) {
|
||||
Value rank =
|
||||
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
|
||||
if (!greater_rank) {
|
||||
@ -314,17 +340,19 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
rewriter, op, greater_rank, 1);
|
||||
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1);
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||
reduced_shapes, 1);
|
||||
|
||||
// Put each subsequent rank specialization inside the else statement of the
|
||||
// previous one.
|
||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||
constexpr int kMaxRankSpecialization = 6;
|
||||
constexpr int kMaxRankSpecialization = 5;
|
||||
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
||||
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
else_builder, op, greater_rank, i);
|
||||
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, operands, i);
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||
reduced_shapes, i);
|
||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
||||
}
|
||||
@ -336,12 +364,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
kMaxRankSpecialization),
|
||||
"Input for dynamic binary op lowering was of a rank greater than " +
|
||||
std::to_string(kMaxRankSpecialization));
|
||||
// Add the rank 6 specialization to the innermost else block.
|
||||
createRankSpecializedBroadcastAndOp(else_builder, op, operands,
|
||||
kMaxRankSpecialization);
|
||||
// Add the rank 5 specialization to the innermost else block.
|
||||
createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
|
||||
reduced_shapes, kMaxRankSpecialization);
|
||||
|
||||
// Return the result of the outermost if statement.
|
||||
return if_op.getResult(0);
|
||||
// Return the reshaped result of the outermost if statement.
|
||||
auto result = if_op.getResult(0);
|
||||
auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, result.getType(), result, broadcast_shape);
|
||||
return reshaped_result;
|
||||
}
|
||||
};
|
||||
|
||||
@ -386,7 +417,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
||||
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
||||
op.getAttrs());
|
||||
op->getAttrs());
|
||||
Value extended_if_lhs_scalar_result =
|
||||
extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result,
|
||||
shape_of_lhs, shape_of_rhs);
|
||||
@ -409,7 +440,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||
loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
|
||||
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
||||
op.getAttrs());
|
||||
op->getAttrs());
|
||||
Value extended_if_rhs_scalar_result =
|
||||
extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result,
|
||||
shape_of_lhs, shape_of_rhs);
|
||||
@ -497,16 +528,17 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
|
||||
struct TransformUnrankedHloPass
|
||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
||||
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
|
||||
shape::ShapeDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
// Setup conversion target.
|
||||
MLIRContext &ctx = getContext();
|
||||
ConversionTarget target(ctx);
|
||||
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
|
||||
shape::ShapeDialect, scf::SCFDialect,
|
||||
tensor::TensorDialect>();
|
||||
target.addLegalDialect<chlo::HloClientDialect, mhlo::MhloDialect,
|
||||
StandardOpsDialect, shape::ShapeDialect,
|
||||
scf::SCFDialect, tensor::TensorDialect>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
|
||||
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
|
||||
|
26
tensorflow/compiler/mlir/hlo/tests/chlo_ops.mlir
Normal file
@ -0,0 +1,26 @@
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s

// CHECK-LABEL: func @minimum_broadcast_shapes
func @minimum_broadcast_shapes(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>)
-> (tensor<?xindex>, tensor<?xindex>) {
%0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs :
tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
return %0, %1 : tensor<?xindex>, tensor<?xindex>
}

// -----

func @minimum_broadcast_shapes_mismatch_operand_and_result_count(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>) {
// expected-error @+1{{number of operand shapes (2) does not match number of result shapes (1)}}
%0 = chlo.minimum_broadcast_shapes %lhs, %rhs :
tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
return
}

// -----

func @minimum_broadcast_shapes_one_operand(%arg: tensor<?xindex>) {
// expected-error @+1{{number of operand shapes (1) should be >= 2}}
%0 = chlo.minimum_broadcast_shapes %arg : tensor<?xindex> -> tensor<?xindex>
return
}
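The new chlo.minimum_broadcast_shapes op exercised by this test is what the unranked-broadcast lowering above threads through its rank specialization: operands are reshaped to the reduced shapes before the scf.if ladder, and the final result is reshaped back to the full broadcast shape. A condensed, illustrative C++ sketch of that flow, derived from the ConvertUnrankedDynamicBroadcastOpHelper changes above (function and variable names are simplified, not a literal copy of the pass):

// Reduce every operand to its minimum broadcast shape before rank
// specialization; returns the full broadcast shape used to restore the
// result afterwards. Condensed from the diff above, for illustration only.
static Value ReduceOperandsToMinimumShapes(
    OpBuilder &rewriter, Location loc, ValueRange operands,
    SmallVectorImpl<Value> &reshaped_operands) {
  auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                  rewriter.getIndexType());
  SmallVector<Value> shapes;
  for (Value operand : operands)
    shapes.push_back(
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand));

  // The final result still uses the full broadcast of the original shapes.
  Value broadcast_shape = rewriter.create<shape::BroadcastOp>(
      loc, extent_tensor_type, shapes, /*error=*/nullptr);

  // The rank-specialized branches only ever see the reduced shapes.
  SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
  auto reduced_shapes =
      rewriter.create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
          .results();
  for (auto it : llvm::zip(operands, reduced_shapes)) {
    Value operand, reduced_shape;
    std::tie(operand, reduced_shape) = it;
    reshaped_operands.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
        loc, operand.getType(), operand, reduced_shape));
  }
  return broadcast_shape;
}

With the operands reduced this way, the diff caps the scf.if ladder at kMaxRankSpecialization = 5 instead of 6 and reshapes if_op.getResult(0) back through mhlo.dynamic_reshape with the broadcast shape, which is what the updated CHECK lines in the unranked tests below expect.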
@ -954,6 +954,28 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> ()>
|
||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
|
||||
// CHECK-SAME: [[SHAPE:%.*]]: tensor<2xindex>
|
||||
func @dynamic_broadcast_in_dim(%shape: tensor<2xindex>) -> tensor<?x32xf32> {
|
||||
%cst = mhlo.constant dense<0x7F800000> : tensor<f32>
|
||||
%result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) {
|
||||
broadcast_dimensions = dense<> : tensor<0xi64>
|
||||
} : (tensor<f32>, tensor<2xindex>) -> tensor<?x32xf32>
|
||||
return %result : tensor<?x32xf32>
|
||||
}
|
||||
// CHECK: [[CST:%.*]] = constant
|
||||
// CHECK: [[INIT:%.*]] = linalg.init_tensor
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?x32xf32>)
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
||||
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
func @dot_matmul(%arg0: tensor<2x3xf32>,
|
||||
%arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
|
||||
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,
|
||||
@ -982,7 +1004,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
|
||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||
// CHECK: linalg.matmul_i8_i8_i32
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
||||
|
||||
@ -1000,7 +1022,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
|
||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||
// CHECK: linalg.matmul_i16_i16_i32
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
||||
|
||||
@ -1018,7 +1040,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
|
||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||
// CHECK: linalg.matmul_i32_i32_i32
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
||||
|
||||
@ -1109,7 +1131,7 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
|
||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||
// CHECK: linalg.batch_matmul_i8_i8_i32
|
||||
// CHECK: linalg.batch_matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
|
||||
|
||||
@ -1138,7 +1160,7 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
|
||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||
// CHECK: linalg.batch_matmul_i16_i16_i32
|
||||
// CHECK: linalg.batch_matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
|
||||
|
||||
@ -1444,8 +1466,8 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf3
|
||||
|
||||
// -----
|
||||
|
||||
func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
|
||||
-> tensor<?x?x?xf32> {
|
||||
func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tensor<2x?x?xf32>)
|
||||
-> tensor<?x7x?xf32> {
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
@ -1463,31 +1485,29 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tenso
|
||||
padding = dense<[[0], [0]]> : tensor<2x1xi64>,
|
||||
rhs_dilation = dense<1> : tensor<1xi64>,
|
||||
window_strides = dense<1> : tensor<1xi64>
|
||||
} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
} : (tensor<?x8x?xf32>, tensor<2x?x?xf32>) -> tensor<?x7x?xf32>
|
||||
return %0 : tensor<?x7x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
|
||||
// CHECK: %[[C2:.+]] = constant 2 : index
|
||||
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
|
||||
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
|
||||
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
||||
// CHECK: linalg.conv_1d_input_nwc_filter_wcf
|
||||
// CHECK-SAME: {dilations = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x8x?xf32>, tensor<2x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x7x?xf32>) -> tensor<?x7x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
|
||||
-> tensor<?x?x?x?xf32> {
|
||||
func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3x2x?x?xf32>)
|
||||
-> tensor<?x2x3x?xf32> {
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
@ -1505,33 +1525,29 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?
|
||||
padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
|
||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||
window_strides = dense<1> : tensor<2xi64>
|
||||
} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?xf32>
|
||||
} : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x3x?xf32>
|
||||
return %0 : tensor<?x2x3x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[C2:.+]] = constant 2 : index
|
||||
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
|
||||
// CHECK: %[[C3:.+]] = constant 3 : index
|
||||
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
|
||||
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
|
||||
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
||||
// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
|
||||
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
|
||||
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
|
||||
-> tensor<?x?x?x?x?xf32> {
|
||||
func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tensor<2x2x2x?x?xf32>)
|
||||
-> tensor<?x7x7x7x?xf32> {
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
@ -1549,30 +1565,24 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tens
|
||||
padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
|
||||
rhs_dilation = dense<1> : tensor<3xi64>,
|
||||
window_strides = dense<1> : tensor<3xi64>
|
||||
} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?xf32>
|
||||
} : (tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>) -> tensor<?x7x7x7x?xf32>
|
||||
return %0 : tensor<?x7x7x7x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[C2:.+]] = constant 2 : index
|
||||
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[C3:.+]] = constant 3 : index
|
||||
// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
|
||||
// CHECK: %[[C4:.+]] = constant 4 : index
|
||||
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
|
||||
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
|
||||
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
|
||||
// CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
|
||||
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>
|
||||
// CHECK-SAME: strides = dense<1> : tensor<3xi64>}
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x7x7x7x?xf32>) -> tensor<?x7x7x7x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -199,20 +199,24 @@ func @addUnrankedUnranked(
|
||||
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// Handle rank 1 specialization
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
||||
@ -222,12 +226,12 @@ func @addUnrankedUnranked(
|
||||
// Handle rank 2 specialization
|
||||
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
|
||||
@ -237,12 +241,12 @@ func @addUnrankedUnranked(
|
||||
// Handle rank 3 specialization
|
||||
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
|
||||
@ -252,47 +256,30 @@ func @addUnrankedUnranked(
|
||||
// Handle rank 4 specialization
|
||||
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
|
||||
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
|
||||
// Handle rank 5 specialization
|
||||
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C6]] : index
|
||||
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_6]]
|
||||
// Handle rank 6 specialization
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
@ -300,7 +287,8 @@ func @addUnrankedUnranked(
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
@ -325,13 +313,18 @@ func @selectUnrankedUnrankedUnranked(
|
||||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[PRED_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %c1 = constant 1 : index
|
||||
@ -339,15 +332,15 @@ func @selectUnrankedUnrankedUnranked(
|
||||
// Handle rank 1 specialization
|
||||
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
||||
@ -357,4 +350,3 @@ func @selectUnrankedUnrankedUnranked(
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
|
@ -368,6 +368,16 @@ func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @concat_1D
|
||||
// Verifies that an error is not thrown if the inferred type is compatible with
|
||||
// the result type.
|
||||
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
|
||||
return %0 : tensor<3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> {
|
||||
// expected-error@+1 {{'mhlo.concatenate' op requires the same element type for all operands and results}}
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
|
||||
@ -384,14 +394,6 @@ func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<
|
||||
|
||||
// -----
|
||||
|
||||
func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
|
||||
// expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<3xi32>'}}
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
|
||||
return %0 : tensor<3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> {
|
||||
// expected-error@+1 {{op inferred type(s) 'tensor<3xi32>' are incompatible with return type(s) of operation 'tensor<4xi32>'}}
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
|
||||
|
@ -627,6 +627,7 @@ tf_native_cc_binary(
srcs = [
"converter_gen.cc",
],
compatible_with = get_compatible_with_cloud(),
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
@ -1192,11 +1192,11 @@ void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
auto cond = symbol_table.lookup<mlir::FuncOp>(
while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
AddCallOpInWhileOpRegion(while_op.cond(), cond);
while_op.removeAttr("cond");
while_op->removeAttr("cond");
auto body = symbol_table.lookup<mlir::FuncOp>(
while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
AddCallOpInWhileOpRegion(while_op.body(), body);
while_op.removeAttr("body");
while_op->removeAttr("body");
});
}
} // namespace
@ -1011,9 +1011,7 @@ def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [
|
||||
TFL_OperandHasAtleastRank<0, 2>,
|
||||
TFL_OperandHasAtleastRank<1, 2>,
|
||||
PredOpTrait<"x and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
PredOpTrait<"y and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
|
||||
|
||||
let summary = "Batch Matrix Multiply Operator";
|
||||
|
||||
@ -2774,7 +2772,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
|
||||
def TFL_SelectOp : TFL_Op<"select", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale,
|
||||
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
|
||||
PredOpTrait<"operands have same element type", TFL_TCopVTEtAreSameAt<1, 2>>,
|
||||
PredOpTrait<"operands and result have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
|
||||
let summary = "Select operator";
|
||||
@ -2810,8 +2808,9 @@ def TFL_SelectOp : TFL_Op<"select", [
|
||||
def TFL_SelectV2Op : TFL_Op<"select_v2", [
|
||||
ResultsBroadcastableShape,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 4>,
|
||||
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
|
||||
PredOpTrait<"operands have same element type", TFL_TCopVTEtAreSameAt<1, 2>>,
|
||||
PredOpTrait<"operands and result have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
|
||||
let summary = "SelectV2 operator";
|
||||
|
@ -14,15 +14,16 @@ func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x120x120x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32> {
|
||||
func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x120x120x8xf32>
|
||||
%cst_1 = constant dense<0> : tensor<2x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x64x64x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x64x64x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x60x60x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x60x60x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x120x120x8xf32>
|
||||
return %2 : tensor<1x120x120x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithNonConstantPadAndCrops
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x120x120x8xf32>
|
||||
}
|
||||
@ -73,37 +74,39 @@ func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<4> : tensor<2x2xi32>
|
||||
%cst_1 = constant dense<0> : tensor<2x2xi32>
|
||||
%cst_2 = constant dense<0> : tensor<4x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.Pad"(%1, %cst_2) : (tensor<4x64x64x8xf32>, tensor<4x2xi32>) -> tensor<4x64x64x8xf32>
|
||||
%3 = "tf.BatchToSpaceND"(%2, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
%4 = "tf.BiasAdd"(%3, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
return %4 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithPad
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<4> : tensor<2x2xi32>
|
||||
%cst_1 = constant dense<0> : tensor<2x2xi32>
|
||||
%cst_2 = constant dense<0> : tensor<4x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
|
||||
%1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.Pad"(%1, %cst_2) : (tensor<4x64x64x8xf32>, tensor<4x2xi32>) -> tensor<4x64x64x8xf32>
|
||||
%3 = "tf.BatchToSpaceND"(%2, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
%4 = "tf.BiasAdd"(%3, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
return %4 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithPad
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
@ -235,22 +238,23 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%cst_1 = constant dense<4> : tensor<2x2xi32>
|
||||
%cst_2 = constant dense<0> : tensor<2x2xi32>
|
||||
%cst_3 = constant dense<0> : tensor<3x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
|
||||
%4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
|
||||
%4 = "tf.Pad"(%3, %cst_3) : (tensor<4x64x64xf32>, tensor<3x2xi32>) -> tensor<4x64x64xf32>
|
||||
%5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
%6 = "tf.BiasAdd"(%5, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
return %6 : tensor<1x128x128xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze3
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
@ -259,22 +263,23 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
||||
}
|
||||
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%cst_1 = constant dense<4> : tensor<2x2xi32>
|
||||
%cst_2 = constant dense<0> : tensor<2x2xi32>
|
||||
%cst_3 = constant dense<0> : tensor<3x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
|
||||
%4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
|
||||
%4 = "tf.Pad"(%3, %cst_3) : (tensor<4x64x64xf32>, tensor<3x2xi32>) -> tensor<4x64x64xf32>
|
||||
%5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
|
||||
%6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
%6 = "tf.BiasAdd"(%5, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
return %6 : tensor<1x128x128xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
@ -481,3 +486,27 @@ func @testDilatedConv1DWithMixedPostiveAndNegativeAxis(%arg0: tensor<1x128x3xf32
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
|
||||
}
|
||||
|
||||
func @testPaddedDilatedConv(%arg0 : tensor<2x1920x64xf32>) -> tensor<2x1920x128xf32> {
|
||||
%0 = "tf.Const"() {value = dense<[[0, 0], [2, 0], [0, 0]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||
%1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
|
||||
%3 = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%4 = "tf.Const"() {value = dense<0.0> : tensor<3x1x64x128xf32>} : () -> tensor<3x1x64x128xf32>
|
||||
%5 = "tf.SpaceToBatchND"(%arg0, %1, %3) {device = ""} : (tensor<2x1920x64xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<4x960x64xf32>
|
||||
%6 = "tf.ExpandDims"(%5, %2) {device = ""} : (tensor<4x960x64xf32>, tensor<i32>) -> tensor<4x960x1x64xf32>
|
||||
%7 = "tf.Conv2D"(%6, %4) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<4x960x1x64xf32>, tensor<3x1x64x128xf32>) -> tensor<4x958x1x128xf32>
|
||||
%8 = "tf.Squeeze"(%7) {device = "", squeeze_dims = [2]} : (tensor<4x958x1x128xf32>) -> tensor<4x958x128xf32>
|
||||
%9 = "tf.Pad"(%8, %0) {device = ""} : (tensor<4x958x128xf32>, tensor<3x2xi32>) -> tensor<4x960x128xf32>
|
||||
%10 = "tf.BatchToSpaceND"(%9, %1, %3) {device = ""} : (tensor<4x960x128xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x1920x128xf32>
|
||||
return %10 : tensor<2x1920x128xf32>
|
||||
|
||||
// CHECK-LABEL: testPaddedDilatedConv
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<2x1920x64xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[FILTER:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x1x64x128xf32>} : () -> tensor<3x1x64x128xf32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) {device = ""} : (tensor<2x1920x64xf32>, tensor<i32>) -> tensor<2x1920x1x64xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {data_format = "NHWC", device = "", dilations = [1, 2, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<2x1920x1x64xf32>, tensor<3x1x64x128xf32>) -> tensor<2x1920x1x128xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {device = "", squeeze_dims = [2]} : (tensor<2x1920x1x128xf32>) -> tensor<2x1920x128xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<2x1920x128xf32>
|
||||
}
|
||||
|
@ -1387,6 +1387,15 @@ func @testBatchMatmulQuant(%arg0 : tensor<1x4x384x32x!quant.uniform<i8:f32, 0.06
|
||||
%0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32x!quant.uniform<i8:f32, 0.06:-2>>, tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>>
|
||||
return %0 : tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchMatmulHybridQuant(%arg0 : tensor<1x4x384x32xf32>, %arg1 : tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384xf32> {
|
||||
// CHECK: "tfl.batch_matmul"(%arg0, %arg1)
|
||||
%0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384xf32>
|
||||
return %0 : tensor<1x4x384x384xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConcat(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<2x2xi32> {
|
||||
|
@ -564,6 +564,78 @@ func @FuseFullyConnectedReshapeAddConstWithActivation(%arg0: tensor<40x37xf32>,
|
||||
// FOLD: return %[[fc]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConst
|
||||
func @FuseFullyConnectedReshapeAdd2DConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> {
|
||||
%cst = constant unit
|
||||
%cst2 = constant dense<2.0> : tensor<4x10xf32>
|
||||
%shape = constant dense<[1, 40, 4, 10]> : tensor<4xi32>
|
||||
|
||||
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32>
|
||||
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x4x10xf32>, tensor<4x10xf32>) -> tensor<1x40x4x10xf32>
|
||||
|
||||
return %2 : tensor<1x40x4x10xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<2.000000e+00> : tensor<40xf32>
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}
|
||||
// CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]]
|
||||
// CHECK: return %[[rs]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConstWithActivation
|
||||
func @FuseFullyConnectedReshapeAdd2DConstWithActivation(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> {
|
||||
%cst = constant unit
|
||||
%cst2 = constant dense<2.0> : tensor<4x10xf32>
|
||||
%shape = constant dense<[1, 40, 4, 10]> : tensor<4xi32>
|
||||
|
||||
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32>
|
||||
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x40x4x10xf32>, tensor<4x10xf32>) -> tensor<1x40x4x10xf32>
|
||||
|
||||
return %2 : tensor<1x40x4x10xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<2.000000e+00> : tensor<40xf32>
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
|
||||
// CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]]
|
||||
// CHECK: return %[[rs]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConstWithExistingBias
|
||||
func @FuseFullyConnectedReshapeAdd2DConstWithExistingBias(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> {
|
||||
%cst = constant dense<3.0> : tensor<40xf32>
|
||||
%cst2 = constant dense<2.0> : tensor<4x10xf32>
|
||||
%shape = constant dense<[1, 40, 4, 10]> : tensor<4xi32>
|
||||
|
||||
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40xf32>) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32>
|
||||
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x4x10xf32>, tensor<4x10xf32>) -> tensor<1x40x4x10xf32>
|
||||
|
||||
return %2 : tensor<1x40x4x10xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40xf32>
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
|
||||
// CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]]
|
||||
// CHECK: return %[[rs]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NotFuseFullyConnectedReshapeAdd2DConstIfLastDimIsNotNumElementsOfRhs
|
||||
func @NotFuseFullyConnectedReshapeAdd2DConstIfLastDimIsNotNumElementsOfRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<20x37xf32>) -> tensor<1x20x4x10xf32> {
|
||||
%cst = constant unit
|
||||
%cst2 = constant dense<2.0> : tensor<4x10xf32>
|
||||
%shape = constant dense<[1, 20, 4, 10]> : tensor<4xi32>
|
||||
|
||||
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<20x37xf32>, none) -> (tensor<40x20xf32>)
|
||||
%1 = "tfl.reshape"(%0, %shape) : (tensor<40x20xf32>, tensor<4xi32>) -> tensor<1x20x4x10xf32>
|
||||
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x20x4x10xf32>, tensor<4x10xf32>) -> tensor<1x20x4x10xf32>
|
||||
|
||||
return %2 : tensor<1x20x4x10xf32>
|
||||
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1
|
||||
// CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]]
|
||||
// CHECK: %[[add:.*]] = "tfl.add"(%[[rs]]
|
||||
// CHECK: return %[[add]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastableAfter
|
||||
func @NotReorderReshapeAddIfNotBroadcastableAfter(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<2.0> : tensor<40xf32>
|
||||
@ -642,6 +714,19 @@ func @NotReorderReshapeAddIfHighDim(%arg0: tensor<1x1x1x1x30x96xf32>) -> tensor<
|
||||
// CHECK: return %[[rs2]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NotReorderReshapeAdd2DConstIfInputIsNotDefinedByFullyConnected
|
||||
func @NotReorderReshapeAdd2DConstIfInputIsNotDefinedByFullyConnected(%arg0: tensor<8x15xf32>) -> tensor<1x8x3x5xf32> {
|
||||
%cst = constant dense<2.0> : tensor<3x5xf32>
|
||||
%shape = constant dense<[1, 8, 3, 5]> : tensor<4xi32>
|
||||
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<8x15xf32>, tensor<4xi32>) -> tensor<1x8x3x5xf32>
|
||||
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x8x3x5xf32>, tensor<3x5xf32>) -> tensor<1x8x3x5xf32>
|
||||
return %2 : tensor<1x8x3x5xf32>
|
||||
|
||||
// CHECK: %[[rs:.*]] = "tfl.reshape"(%arg0
|
||||
// CHECK: %[[add:.*]] = "tfl.add"(%[[rs]]
|
||||
// CHECK: return %[[add]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp
|
||||
func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
|
||||
%shape = constant dense<[40, 40]> : tensor<2xi32>
|
||||
@ -1679,8 +1764,17 @@ func @ReorderReshapex2Add(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x3x4xf32>
|
||||
// CHECK: return %[[VAL_1]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ConvertSliceToIdentity
|
||||
func @ConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {
|
||||
// CHECK-LABEL: ConvertSliceToIdentityI32
|
||||
func @ConvertSliceToIdentityI32(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {
|
||||
%begin = constant dense<0> : tensor<4xi32>
|
||||
%shape = constant dense<[2,3,4,5]> : tensor<4xi32>
|
||||
%0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<2x3x4x5xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<2x3x4x5xf32>
|
||||
return %0 : tensor<2x3x4x5xf32>
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ConvertSliceToIdentityI64
|
||||
func @ConvertSliceToIdentityI64(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {
|
||||
%begin = constant dense<0> : tensor<4xi64>
|
||||
%shape = constant dense<[2,3,4,5]> : tensor<4xi64>
|
||||
%0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<2x3x4x5xf32>, tensor<4xi64>, tensor<4xi64>) -> tensor<2x3x4x5xf32>
|
||||
@ -1688,6 +1782,24 @@ func @ConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ConvertSliceToIdentityStaticDimWithShapeWithNeg1
|
||||
func @ConvertSliceToIdentityStaticDimWithShapeWithNeg1(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> {
|
||||
%begin = constant dense<0> : tensor<4xi32>
|
||||
%shape = constant dense<[-1, 3, -1, 5]> : tensor<4xi32>
|
||||
%0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<2x3x4x5xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<2x3x4x5xf32>
|
||||
return %0 : tensor<2x3x4x5xf32>
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ConvertSliceToIdentityDynamicDimAndShapeWithNeg1
|
||||
func @ConvertSliceToIdentityDynamicDimAndShapeWithNeg1(%arg0: tensor<?x3x?x5xf32>) -> tensor<?x3x?x5xf32> {
|
||||
%begin = constant dense<0> : tensor<4xi32>
|
||||
%shape = constant dense<[-1, 3, -1, 5]> : tensor<4xi32>
|
||||
%0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<?x3x?x5xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<?x3x?x5xf32>
|
||||
return %0 : tensor<?x3x?x5xf32>
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: DontConvertSliceToIdentity
|
||||
func @DontConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> (tensor<2x3x4x4xf32>, tensor<1x2x3x4xf32>) {
|
||||
%begin0 = constant dense<0> : tensor<4xi64>
|
||||
@ -1706,6 +1818,28 @@ func @DontConvertSliceToIdentity(%arg0: tensor<2x3x4x5xf32>) -> (tensor<2x3x4x4x
|
||||
// CHECK: return %[[SLICE_0]], %[[SLICE_1]] : tensor<2x3x4x4xf32>, tensor<1x2x3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: DontConvertSliceToIdentityNonConstShape
|
||||
func @DontConvertSliceToIdentityNonConstShape(%arg0: tensor<?xf32>, %arg1: tensor<1xi32>) -> tensor<?xf32> {
|
||||
%begin = constant dense<0> : tensor<1xi32>
|
||||
%0 = "tfl.slice"(%arg0, %begin, %arg1) : (tensor<?xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
// CHECK: %[[BEGIN:.*]] = constant dense<0> : tensor<1xi32>
|
||||
// CHECK: %[[SLICE:.*]] = "tfl.slice"(%arg0, %[[BEGIN]], %arg1) : (tensor<?xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xf32>
|
||||
// CHECK: return %[[SLICE]] : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: DontConvertSliceToIdentityDynamicDimButEqualShape
|
||||
func @DontConvertSliceToIdentityDynamicDimButEqualShape(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%begin = constant dense<0> : tensor<1xi32>
|
||||
%shape = constant dense<2> : tensor<1xi32>
|
||||
%0 = "tfl.slice"(%arg0, %begin, %shape) : (tensor<?xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
// CHECK: %[[BEGIN:.*]] = constant dense<0> : tensor<1xi32>
|
||||
// CHECK: %[[SHAPE:.*]] = constant dense<2> : tensor<1xi32>
|
||||
// CHECK: %[[SLICE:.*]] = "tfl.slice"(%arg0, %[[BEGIN]], %[[SHAPE]]) : (tensor<?xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xf32>
|
||||
// CHECK: return %[[SLICE]] : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseAddWithFullyConnectedWithBias
|
||||
func @FuseAddWithFullyConnectedWithBias(%arg: tensor<2x512xf32>) -> tensor<2x1024xf32> {
|
||||
%cst_add = constant dense<2.0> : tensor<512xf32>
|
||||
|
@ -22,11 +22,13 @@ limitations under the License.
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
@ -82,35 +84,77 @@ template <typename Conv2dOpTy>
|
||||
LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
Conv2dOpTy op, PatternRewriter& rewriter) const {
|
||||
if (!op.getResult().hasOneUse()) {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "result for current op has more than 1 use");
|
||||
}
|
||||
// Make sure Conv2D has 'VALID' padding.
|
||||
if (op->template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Conv2D op doesn't have valid padding");
|
||||
}
|
||||
// Make sure dilations are all ones if set.
|
||||
const ArrayAttr& dilations =
|
||||
op->template getAttrOfType<ArrayAttr>("dilations");
|
||||
if (dilations && !TFIntListIsAllOnes(dilations)) {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(op, "dilations should be all 1");
|
||||
}
|
||||
|
||||
if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op))
|
||||
return failure();
|
||||
if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "op's input is not float or the data format isn't NHWC");
|
||||
}
|
||||
|
||||
// Allow dynamic width and height dimensions only.
|
||||
auto result_ty = op.getResult().getType().template cast<TensorType>();
|
||||
if (!result_ty.hasRank() || result_ty.getRank() != 4 ||
|
||||
result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3))
|
||||
return failure();
|
||||
result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only dynamic width and height dimensions are allowed");
|
||||
}
|
||||
|
||||
// Check if the ConvOp is preceded by a `Expand` op and succeeded by a
|
||||
// `Squeeze` op.
|
||||
Operation* prev_op = op.getOperation()->getPrevNode();
|
||||
if (!prev_op) return failure();
|
||||
if (!prev_op || prev_op->getNumResults() != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "op doesn't have a preceding node that has a single result");
|
||||
}
|
||||
if (!prev_op->hasOneUse() || *(prev_op->getResult(0).user_begin()) != op) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "op's input isn't produced by previous operation");
|
||||
}
|
||||
|
||||
Operation* next_op = op.getOperation()->getNextNode();
|
||||
if (!next_op) return failure();
|
||||
auto tryGetNextNode =
|
||||
[&rewriter](Operation* current) -> std::pair<LogicalResult, Operation*> {
|
||||
// Check the current operation has a single result.
|
||||
if (current->getNumResults() != 1) {
|
||||
return {
|
||||
rewriter.notifyMatchFailure(current, "op doesn't have single result"),
|
||||
nullptr};
|
||||
}
|
||||
// Check the current operation has a next node.
|
||||
Operation* next_op = current->getNextNode();
|
||||
if (!next_op) {
|
||||
return {rewriter.notifyMatchFailure(current, "op doesn't have next node"),
|
||||
nullptr};
|
||||
}
|
||||
// Check the current operation's result is used by its successor node.
|
||||
if (!current->hasOneUse() ||
|
||||
*(current->getResult(0).user_begin()) != next_op) {
|
||||
return {
|
||||
rewriter.notifyMatchFailure(
|
||||
current, "op's result isn't directly consumed by the next op"),
|
||||
nullptr};
|
||||
}
|
||||
return {LogicalResult::success(), next_op};
|
||||
};
|
||||
|
||||
std::pair<LogicalResult, Operation*> maybeNextNode =
|
||||
tryGetNextNode(op.getOperation());
|
||||
if (failed(maybeNextNode.first)) {
|
||||
return maybeNextNode.first;
|
||||
}
|
||||
Operation* next_op = maybeNextNode.second;
|
||||
|
||||
TF::ExpandDimsOp expand_op;
|
||||
TF::SqueezeOp squeeze_op;
|
||||
@ -119,15 +163,19 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
|
||||
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
|
||||
// Expand/Squeeze op must come in pair.
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "ExpandDimsOp and SqueezeOp should come in pair");
|
||||
}
|
||||
expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
|
||||
squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
|
||||
if (!expand_op.getResult().hasOneUse() ||
|
||||
!squeeze_op.getResult().hasOneUse()) {
|
||||
return failure();
|
||||
if (!expand_op.getResult().hasOneUse()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
expand_op, "result for current op has more than 1 use");
|
||||
}
|
||||
if (!squeeze_op.getResult().hasOneUse()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
squeeze_op, "result for current op has more than 1 use");
|
||||
}
|
||||
|
||||
// Make sure that the axis in `expand_op` is constant.
|
||||
if (auto const_op =
|
||||
llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
|
||||
@ -141,12 +189,14 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
expand_axis += 4;
|
||||
}
|
||||
} else {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
expand_op, "ExpandDimsOp doesn't have a constant axis");
|
||||
}
|
||||
// Make sure that the `squeeze_dims` is equal to `expand_axis`.
|
||||
auto squeeze_dims = squeeze_op.squeeze_dims();
|
||||
if (squeeze_dims.size() != 1) {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
squeeze_op, "squeeze dims should have exactly 1 dimension specified");
|
||||
}
|
||||
int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
|
||||
if (squeeze_axis < 0) {
|
||||
@ -154,36 +204,62 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
squeeze_axis += 4;
|
||||
}
|
||||
if (squeeze_axis != expand_axis) {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "squeeze axis and expand axis doesn't match");
|
||||
}
|
||||
|
||||
// Update previous/next op pointer.
|
||||
prev_op = prev_op->getPrevNode();
|
||||
if (!prev_op) return failure();
|
||||
next_op = next_op->getNextNode();
|
||||
if (!next_op) return failure();
|
||||
Operation* tmp = prev_op->getPrevNode();
|
||||
if (!tmp || tmp->getNumResults() != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
prev_op, "op doesn't have a preceding node that has a single result");
|
||||
}
|
||||
if (!tmp->hasOneUse() || *(tmp->getResult(0).user_begin()) != prev_op) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
prev_op, "op's input isn't defined by its previous node");
|
||||
}
|
||||
prev_op = tmp;
|
||||
std::pair<LogicalResult, Operation*> maybeNextNode =
|
||||
tryGetNextNode(next_op);
|
||||
if (failed(maybeNextNode.first)) {
|
||||
return maybeNextNode.first;
|
||||
}
|
||||
next_op = maybeNextNode.second;
|
||||
}
|
||||
|
||||
// SpaceToBatchND op.
|
||||
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return failure();
|
||||
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) {
|
||||
return rewriter.notifyMatchFailure(prev_op,
|
||||
"op should be a SpaceToBatchND op");
|
||||
}
|
||||
// TODO(b/149936532): Check `padding` input, currently ignored.
|
||||
TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
|
||||
if (!stb_op.getResult().hasOneUse()) {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
stb_op, "result for current op has more than 1 use");
|
||||
}
|
||||
|
||||
// Pad op.
|
||||
TF::PadOp pad_op;
|
||||
// TODO(b/149936532): Currently we just ignore the PadOp. However note that
|
||||
// in real scenarios this may not always be correct: user can put a PadOp here
|
||||
// with non-trivial consequences.
|
||||
ElementsAttr pad_attr;
|
||||
if (llvm::isa<TF::PadOp>(next_op)) {
|
||||
pad_op = llvm::cast<TF::PadOp>(next_op);
|
||||
if (!pad_op.getResult().hasOneUse()) {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
pad_op, "result for current op has more than 1 use");
|
||||
}
|
||||
std::pair<LogicalResult, Operation*> maybeNextNode =
|
||||
tryGetNextNode(next_op);
|
||||
if (failed(maybeNextNode.first)) {
|
||||
return maybeNextNode.first;
|
||||
}
|
||||
next_op = maybeNextNode.second;
|
||||
if (!matchPattern(pad_op.paddings(), m_Constant(&pad_attr))) {
|
||||
// If the padding value isn't constant, we can't determine the padding
|
||||
// scheme for Conv2D below, in this case just reject the pattern.
|
||||
return rewriter.notifyMatchFailure(
|
||||
pad_op, "PadOp's padding value isn't constant");
|
||||
}
|
||||
next_op = next_op->getNextNode();
|
||||
if (!next_op) return failure();
|
||||
}
|
||||
|
||||
// BatchToSpaceND + BiasAdd.
|
||||
@ -194,33 +270,53 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
// Must be BiasAdd + BatchToSpaceND.
|
||||
biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
|
||||
if (!biasadd_op.getResult().hasOneUse()) {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
biasadd_op, "result for current op has more than 1 use");
|
||||
}
|
||||
next_op = next_op->getNextNode();
|
||||
if (!next_op || !llvm::isa<TF::BatchToSpaceNDOp>(next_op)) return failure();
|
||||
std::pair<LogicalResult, Operation*> maybeNextNode =
|
||||
tryGetNextNode(next_op);
|
||||
if (failed(maybeNextNode.first)) {
|
||||
return maybeNextNode.first;
|
||||
}
|
||||
if (!llvm::isa<TF::BatchToSpaceNDOp>(maybeNextNode.second)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
next_op, "op's next node isn't BatchToSpaceND op");
|
||||
}
|
||||
next_op = maybeNextNode.second;
|
||||
bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
|
||||
} else if (llvm::isa<TF::BatchToSpaceNDOp>(next_op)) {
|
||||
// BatchToSpaceND + (optional) BiasAdd.
|
||||
bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
|
||||
next_op = next_op->getNextNode();
|
||||
if (next_op && llvm::isa<TF::BiasAddOp>(next_op)) {
|
||||
Operation* tmp = next_op->getNextNode();
|
||||
if (tmp && llvm::isa<TF::BiasAddOp>(tmp)) {
|
||||
if (!bts_op.getResult().hasOneUse()) {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
bts_op, "result for current op has more than 1 use");
|
||||
}
|
||||
if (!next_op->hasOneUse() ||
|
||||
*(next_op->getResult(0).user_begin()) != tmp) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
next_op, "op's result isn't directly consumed by the next op");
|
||||
}
|
||||
next_op = tmp;
|
||||
biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
|
||||
final_op_is_bts = false;
|
||||
}
|
||||
} else {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
next_op, "next op is neither BiasAdd nor BatchToSpaceND");
|
||||
}
|
||||
|
||||
llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
|
||||
stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
|
||||
if (!dilations_attr.hasValue()) return failure();
|
||||
if (!dilations_attr.hasValue()) {
|
||||
return rewriter.notifyMatchFailure(op, "failed to extract dilation rate");
|
||||
}
|
||||
|
||||
if (expand_op) {
|
||||
if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
stb_op, "SpaceToBatchND op's input should have RankedTensorType");
|
||||
}
|
||||
}
|
||||
|
||||
@ -255,16 +351,33 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
auto stb_paddings = stb_op.paddings();
|
||||
auto bts_crops = bts_op.crops();
|
||||
ElementsAttr stb_paddings_attr, bts_crops_attr;
|
||||
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) &&
|
||||
matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
|
||||
if (stb_paddings_attr.getNumElements() != bts_crops_attr.getNumElements())
|
||||
return failure();
|
||||
// padding - crop.
|
||||
auto paddings = stb_paddings_attr.getValues<IntegerAttr>();
|
||||
auto crops = bts_crops_attr.getValues<IntegerAttr>();
|
||||
for (auto it1 = paddings.begin(), it2 = crops.begin();
|
||||
it1 != paddings.end() && it2 != crops.end(); it1++, it2++) {
|
||||
if ((*it1).getInt() != (*it2).getInt()) {
|
||||
if (!matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) ||
|
||||
!matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"either SpaceToBatchND or BatchToSpaceND "
|
||||
"doesn't have constant padding/crops value");
|
||||
}
|
||||
if (stb_paddings_attr.getType() != bts_crops_attr.getType()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
stb_op,
|
||||
"SpaceToBatchND op's padding doesn't have same shape/type with "
|
||||
"BatchToSpaceND op's crops");
|
||||
}
|
||||
int64_t m = stb_paddings_attr.getType().getDimSize(0);
|
||||
// padding - crop.
|
||||
for (uint64_t i = 0; i < m; ++i) {
|
||||
for (uint64_t j = 0; j < 2; ++j) {
|
||||
// `crops` tensor has shape [M, 2], crops[i] = [crop_start, crop_end]
|
||||
// specifies the amount to crop from input dimension i + 1. If the input
|
||||
// of `BatchToSpaceND` has been padded explicitly, then we need to
|
||||
// take into account the additional padding when determining the padding
|
||||
// scheme for `Conv2D`.
|
||||
int64_t additional_pad =
|
||||
pad_attr ? pad_attr.getValue<IntegerAttr>({i + 1, j}).getInt() : 0;
|
||||
if (stb_paddings_attr.getValue<IntegerAttr>({i, j}).getInt() +
|
||||
additional_pad !=
|
||||
bts_crops_attr.getValue<IntegerAttr>({i, j}).getInt()) {
|
||||
op->setAttr("padding", rewriter.getStringAttr("SAME"));
|
||||
break;
|
||||
}
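
The loop above is the heart of the padding decision: the fused Conv2D can only keep "VALID" padding if, for every spatial dimension, the SpaceToBatchND padding plus any explicit tf.Pad amount exactly cancels against the BatchToSpaceND crop; any leftover means the kernel has to compensate, which is what switching the op to "SAME" does. The following is a rough standalone C++ sketch of that rule, not the pass itself: the helper name is hypothetical, the [M, 2] attributes are modeled as plain arrays, and the batch row of the explicit pad is assumed to be stripped off already.

#include <array>
#include <cstdint>
#include <string>
#include <vector>

using Amounts = std::vector<std::array<int64_t, 2>>;  // [M][{before, after}]

// Returns the padding attribute the fused Conv2D would need: "VALID" when
// padding + explicit pad == crop for every entry, otherwise "SAME".
std::string ChooseConvPadding(const Amounts& stb_paddings,
                              const Amounts& bts_crops,
                              const Amounts& explicit_pad /* may be empty */) {
  for (size_t i = 0; i < stb_paddings.size(); ++i) {
    for (int j = 0; j < 2; ++j) {
      int64_t extra = explicit_pad.empty() ? 0 : explicit_pad[i][j];
      if (stb_paddings[i][j] + extra != bts_crops[i][j]) return "SAME";
    }
  }
  return "VALID";
}
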
|
||||
@ -316,7 +429,11 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
}
|
||||
|
||||
if (final_op_is_bts) {
|
||||
bts_op.getResult().replaceAllUsesWith(bts_op.input());
|
||||
if (bts_op.input().getDefiningOp<TF::PadOp>()) {
|
||||
bts_op.getResult().replaceAllUsesWith(pad_op.input());
|
||||
} else {
|
||||
bts_op.getResult().replaceAllUsesWith(bts_op.input());
|
||||
}
|
||||
}
|
||||
|
||||
stb_op.getResult().dropAllUses();
|
||||
|
@ -816,7 +816,7 @@ struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value input = operands[0];
|
||||
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
|
||||
op.getAttrs());
|
||||
op->getAttrs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -948,7 +948,7 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
|
||||
|
||||
// Create a new while op with new operands and updated result types.
|
||||
auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
|
||||
operands, op.getAttrs());
|
||||
operands, op->getAttrs());
|
||||
converted.removeAttr("T");
|
||||
(void)UpdateFunctionTypes(rewriter, converted, tensor_list_args);
|
||||
|
||||
@ -972,7 +972,7 @@ struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
|
||||
|
||||
// Create a new while op with new operands and updated result types.
|
||||
auto converted = rewriter.create<TF::WhileRegionOp>(
|
||||
op.getLoc(), result_types, operands, op.getAttrs());
|
||||
op.getLoc(), result_types, operands, op->getAttrs());
|
||||
|
||||
// Inline the regions from the old while into the new one, and apply
|
||||
// signature conversion to inlined region.
|
||||
|
@ -210,6 +210,50 @@ bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if we can eliminate the SliceOp. When the values of `begin` are
|
||||
// all 0s and `size[i]` is equal to either -1 or `input.shape[i]`
|
||||
// for each dim i, the output tensor is identical to `input`.
|
||||
bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) {
|
||||
// Checks if `begin` and `size` are i32 or i64.
|
||||
auto begin_attr = begin.dyn_cast<DenseIntElementsAttr>();
|
||||
auto size_attr = size.dyn_cast<DenseIntElementsAttr>();
|
||||
if (!begin_attr || !size_attr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto begin_elem_ty = begin_attr.getType().getElementType();
|
||||
if (!begin_elem_ty.isInteger(32) && !begin_elem_ty.isInteger(64)) {
|
||||
return false;
|
||||
}
|
||||
auto size_elem_ty = size_attr.getType().getElementType();
|
||||
if (!size_elem_ty.isInteger(32) && !size_elem_ty.isInteger(64)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Checks if `input` is ranked and its rank is equal to number of elements in
|
||||
// `begin` and `size`.
|
||||
auto input_ty = input.getType().cast<ShapedType>();
|
||||
if (!input_ty.hasRank()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t rank = input_ty.getRank();
|
||||
if (rank != begin_attr.getNumElements() ||
|
||||
rank != size_attr.getNumElements()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Checks if `begin` is all 0s, and `size[i]` is equal to either -1 or
|
||||
// `input.shape[i]`.
|
||||
for (uint64_t i = 0; i < rank; ++i) {
|
||||
if (begin_attr.getValue<APInt>({i}).getSExtValue() != 0) return false;
|
||||
int64_t si = size_attr.getValue<APInt>({i}).getSExtValue();
|
||||
if (si != -1 && si != input_ty.getDimSize(i)) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
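
For reference, the predicate above boils down to a plain shape comparison. Here is a minimal self-contained C++ sketch of the same rule on static integer shapes (a hypothetical helper with no MLIR types involved; dynamic dimensions are not modeled):

#include <cstdint>
#include <vector>

// A slice is a no-op when it starts at 0 in every dimension and its size is
// either -1 (meaning "to the end") or the full extent of that dimension.
bool IsIdentitySlice(const std::vector<int64_t>& input_shape,
                     const std::vector<int64_t>& begin,
                     const std::vector<int64_t>& size) {
  if (begin.size() != input_shape.size() || size.size() != input_shape.size())
    return false;  // ranks must agree, mirroring the NumElements checks above
  for (size_t i = 0; i < input_shape.size(); ++i) {
    if (begin[i] != 0) return false;
    if (size[i] != -1 && size[i] != input_shape[i]) return false;
  }
  return true;
}
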
|
||||
|
||||
// Expand Attribute 'a' to 4D with all 1s except 1 dimension.
|
||||
// Which dimension it is depends on whether 'is_depthwise' is true or false.
|
||||
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
|
||||
|
@ -367,6 +367,19 @@ def OperandsBroadcastToOutputType : Constraint<CPred<
|
||||
def IsTailOfShape : Constraint<CPred<
|
||||
"TFL::IsTailOfShape($0.getType(), $1.getType())">>;
|
||||
|
||||
def Flatten : NativeCodeCall<
|
||||
"$0.cast<DenseElementsAttr>()"
|
||||
".reshape(RankedTensorType::get({$0.getType().cast<ShapedType>().getNumElements()}, "
|
||||
"$0.getType().cast<ShapedType>().getElementType()))">;
|
||||
|
||||
def IsLastDimEqualToNumElements : Constraint<CPred<
|
||||
"$0.getType().cast<ShapedType>().getRank() >= 1 && "
|
||||
"$0.getType().cast<ShapedType>().getDimSize($0.getType().cast<ShapedType>().getRank() - 1) == "
|
||||
"$1.getType().cast<ShapedType>().getNumElements()">>;
|
||||
|
||||
def IsDefinedByFullyConnectedOp : Constraint<CPred<
|
||||
"$0.getDefiningOp<TFL::FullyConnectedOp>() != nullptr">>;
|
||||
|
||||
// Pattern for skipping Tile if it is mainly for broadcasting and the
|
||||
// Op is already supporting broadcasting.
|
||||
multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
|
||||
@ -446,6 +459,29 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
|
||||
(IsTailOfShape $lhs, $rhs),
|
||||
(IsTailOfShape $input1, $input2),
|
||||
(IsTailOfShape $input2, $input1)]>;
|
||||
|
||||
// Move binary op before reshape:
|
||||
// binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
|
||||
// This is valid only when the last dimension of lhs is equal to the
|
||||
// number of elements in constant rhs.
|
||||
// Therefore, after transformation broadcast of binary op is always
|
||||
// applied to the last dimension of $input.
|
||||
def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
|
||||
(ConstantOp:$rhs ElementsAttr:$rhs_attr), $act_fn),
|
||||
(TFL_ReshapeOp (BinaryOp $input, (ConstantOp (Flatten $rhs_attr)),
|
||||
$act_fn),
|
||||
$shape),
|
||||
[(AnyStaticShapeTensor $input),
|
||||
(IsTailOfShape $rhs, $lhs),
|
||||
(IsLastDimEqualToNumElements $input, $rhs),
|
||||
(HasOneUse $lhs),
|
||||
// Restrict operands to have at most rank 4 because TFLite binary
|
||||
// kernel supports up to 4D broadcast.
|
||||
(HasRankAtMost<4> $input),
|
||||
(HasRankAtMost<4> $lhs),
|
||||
(HasRankAtMost<4> $rhs),
|
||||
(IsDefinedByFullyConnectedOp $input)]>;
|
||||
}
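
To see why the rewrite in the comment above is value-preserving, here is a small self-contained C++ check with made-up shapes and values: "original" mirrors add(reshape(input), cst) and "rewritten" mirrors reshape(add(input, flatten(cst))). It only illustrates the numeric equivalence; the extra constraints in the pattern (static shapes, tail-of-shape, fully-connected producer, rank limits) are not modeled here.

#include <cstdio>

int main() {
  // $input is 2x6 and gets reshaped to 2x2x3; the constant is 2x3. The last
  // dim of $input (6) equals the number of elements of the constant (6),
  // which is the IsLastDimEqualToNumElements condition above.
  const int B = 2, R = 2, C = 3, LAST = R * C;
  float input[B][LAST] = {{0, 1, 2, 3, 4, 5}, {10, 11, 12, 13, 14, 15}};
  float cst[R][C] = {{100, 200, 300}, {400, 500, 600}};

  // Original form: reshape to BxRxC, then add the 2-D constant (broadcast over B).
  float original[B][R][C];
  for (int b = 0; b < B; ++b)
    for (int r = 0; r < R; ++r)
      for (int c = 0; c < C; ++c)
        original[b][r][c] = input[b][r * C + c] + cst[r][c];

  // Rewritten form: add flatten(cst) along the last dim of $input, then reshape.
  float rewritten[B][LAST];
  for (int b = 0; b < B; ++b)
    for (int i = 0; i < LAST; ++i)
      rewritten[b][i] = input[b][i] + cst[i / C][i % C];

  // Every element matches, so the reshape can be hoisted past the add.
  for (int b = 0; b < B; ++b)
    for (int i = 0; i < LAST; ++i)
      std::printf("%g == %g\n", original[b][i / C][i % C], rewritten[b][i]);
  return 0;
}
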
|
||||
|
||||
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
|
||||
@ -491,6 +527,28 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
|
||||
(IsTailOfShape $lhs, $rhs),
|
||||
(IsTailOfShape $input1, $input2),
|
||||
(IsTailOfShape $input2, $input1)]>;
|
||||
|
||||
// Move binary op before reshape:
|
||||
// binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
|
||||
// This is valid only when the last dimension of lhs is equal to the
|
||||
// number of elements in constant rhs.
|
||||
// Therefore, after transformation broadcast of binary op is always
|
||||
// applied to the last dimension of $input.
|
||||
def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
|
||||
(ConstantOp:$rhs ElementsAttr:$rhs_attr)),
|
||||
(TFL_ReshapeOp (BinaryOp $input, (ConstantOp (Flatten $rhs_attr))),
|
||||
$shape),
|
||||
[(AnyStaticShapeTensor $input),
|
||||
(IsTailOfShape $rhs, $lhs),
|
||||
(IsLastDimEqualToNumElements $input, $rhs),
|
||||
(HasOneUse $lhs),
|
||||
// Restrict operands to have at most rank 4 because TFLite binary
|
||||
// kernel supports up to 4D broadcast.
|
||||
(HasRankAtMost<4> $input),
|
||||
(HasRankAtMost<4> $lhs),
|
||||
(HasRankAtMost<4> $rhs),
|
||||
(IsDefinedByFullyConnectedOp $input)]>;
|
||||
}
|
||||
|
||||
// Reorder the element-wise value operations and the element move operations,
|
||||
@ -760,17 +818,11 @@ foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in {
|
||||
(AxesIsLastDimension $axes, $logits)]>;
|
||||
}
|
||||
|
||||
class AllElementsAreI64<string val> : Constraint<CPred<
|
||||
"($0.isa<DenseElementsAttr>() && "
|
||||
"$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isInteger(64) && "
|
||||
"std::all_of($0.cast<DenseElementsAttr>().getValues<int64_t>().begin(), "
|
||||
"$0.cast<DenseElementsAttr>().getValues<int64_t>().end(), "
|
||||
"[](int64_t v){ return v == " #val# ";}))">>;
|
||||
def CanOptimizeIdentitySliceOp : Constraint<CPred<
|
||||
"TFL::CanOptimizeIdentitySliceOp($0, $1, $2)">>;
|
||||
|
||||
// Remove Slice ops slicing the whole input tensor, effectively no-op
|
||||
// Remove Slice ops slicing the whole input tensor, effectively no-op.
|
||||
def OptimizeSliceOp : Pat<
|
||||
(TFL_SliceOp:$output $input, (ConstantOp $begin), $shape),
|
||||
(TFL_SliceOp:$output $input, (ConstantOp $begin), (ConstantOp $size)),
|
||||
(replaceWithValue $input),
|
||||
[(AllElementsAreI64<"0"> $begin),
|
||||
(IsTailOfShape $input, $output),
|
||||
(IsTailOfShape $output, $input)]>;
|
||||
[(CanOptimizeIdentitySliceOp $input, $begin, $size)]>;
|
||||
|
@ -254,7 +254,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
new_types.push_back(extra_operand.getType());
|
||||
|
||||
auto new_while_op = OpBuilder(while_op).create<WhileOp>(
|
||||
while_op.getLoc(), new_types, operands, while_op.getAttrs());
|
||||
while_op.getLoc(), new_types, operands, while_op->getAttrs());
|
||||
new_while_op.cond().takeBody(while_op.cond());
|
||||
new_while_op.body().takeBody(while_op.body());
|
||||
while_op.replaceAllUsesWith(
|
||||
|
@ -62,8 +62,7 @@ FuncOp createLstmCompositeFunc(mlir::Builder* builder, bool ln, bool cifg) {
|
||||
auto func_type = builder->getFunctionType(input_types, output_type);
|
||||
|
||||
auto func =
|
||||
FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"),
|
||||
builder->getContext()),
|
||||
FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func")),
|
||||
"fused_func", func_type, {});
|
||||
func.addEntryBlock();
|
||||
|
||||
|
@ -38,8 +38,7 @@ FuncOp createMaxUnpoolingFunc(
|
||||
const SmallVector<mlir::Type, NOutput>& output_types) {
|
||||
auto func_type = builder->getFunctionType(input_types, output_types);
|
||||
auto func =
|
||||
FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"),
|
||||
builder->getContext()),
|
||||
FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func")),
|
||||
"fused_func", func_type, {});
|
||||
|
||||
func.addEntryBlock();
|
||||
|
@ -163,11 +163,11 @@ gentbl(
|
||||
tf_ops_category_list = [
|
||||
{
|
||||
"name": "ops_a_m",
|
||||
"include": "tf.[A-M].*$$",
|
||||
"include": "tf.[A-M].*$",
|
||||
},
|
||||
{
|
||||
"name": "ops_n_z",
|
||||
"include": "tf.[N-Z].*$$",
|
||||
"include": "tf.[N-Z].*$",
|
||||
},
|
||||
]
|
||||
|
||||
@ -177,11 +177,11 @@ tf_ops_category_list = [
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-op-decls -op-include-regex='" + target["include"] + "'",
|
||||
"-gen-op-decls -op-include-regex=" + target["include"],
|
||||
"ir/tf_" + target["name"] + ".h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-defs -op-include-regex='" + target["include"] + "'",
|
||||
"-gen-op-defs -op-include-regex=" + target["include"],
|
||||
"ir/tf_" + target["name"] + ".cc.inc",
|
||||
),
|
||||
],
|
||||
@ -198,11 +198,11 @@ gentbl(
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-op-decls -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ",
|
||||
"-gen-op-decls -op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]),
|
||||
"ir/tf_remaining_ops.h.inc",
|
||||
),
|
||||
(
|
||||
"-gen-op-defs -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ",
|
||||
"-gen-op-defs -op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]),
|
||||
"ir/tf_remaining_ops.cc.inc",
|
||||
),
|
||||
],
|
||||
@ -964,6 +964,7 @@ cc_library(
|
||||
"transforms/tpu_space_to_depth_pass.cc",
|
||||
"transforms/tpu_update_embedding_enqueue_op_inputs.cc",
|
||||
"transforms/tpu_variable_runtime_reformatting.cc",
|
||||
"transforms/verify_suitable_for_graph_export_pass.cc",
|
||||
"translate/breakup-islands.cc",
|
||||
"translate/tf_executor_to_functional.cc",
|
||||
"translate/tf_functional_to_executor.cc",
|
||||
@ -1010,6 +1011,7 @@ cc_library(
|
||||
":translate_utils",
|
||||
":unroll_batch_matmul_pass",
|
||||
":verification_utils",
|
||||
":verify_suitable_for_graph_export",
|
||||
":visitor_util",
|
||||
":xla_sharding_util",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
@ -1151,6 +1153,7 @@ cc_library(
|
||||
":tf_saved_model_passes",
|
||||
":translate_utils",
|
||||
":upgrade_graph",
|
||||
":verify_suitable_for_graph_export",
|
||||
"//tensorflow/cc/saved_model:bundle_v2",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/cc/saved_model:loader_lite",
|
||||
@ -2106,3 +2109,14 @@ cc_library(
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "verify_suitable_for_graph_export",
|
||||
srcs = ["utils/verify_suitable_for_graph_export.cc"],
|
||||
hdrs = ["utils/verify_suitable_for_graph_export.h"],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
@ -226,8 +226,7 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) {
|
||||
if (isa<IdentityOp>(op)) return true;
|
||||
|
||||
// For op's in the Tensorflow dialect, query the dialect.
|
||||
if (op->getName().getDialect() ==
|
||||
TF::TensorFlowDialect::getDialectNamespace())
|
||||
if (isa_and_nonnull<TF::TensorFlowDialect>(op->getDialect()))
|
||||
return !TensorFlowDialect::CanHaveSideEffects(op);
|
||||
|
||||
// Otherwise, conservatively assume that there can be side effects.
|
||||
|
@ -413,7 +413,7 @@ void Print(ReplicateOp op, OpAsmPrinter* p) {
|
||||
// Skip derived `operand_segment_sizes` attribute as custom print format of
|
||||
// operands holds enough information to calculate these variadic operand list
|
||||
// lengths.
|
||||
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/ArrayRef<StringRef>{
|
||||
p->printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/ArrayRef<StringRef>{
|
||||
kOperandSegmentSizesAttr});
|
||||
p->printRegion(op.body(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
@ -213,7 +213,7 @@ LogicalResult Verify(GraphOp graph) {
|
||||
void Print(GraphOp graph, OpAsmPrinter &p) {
|
||||
p << graph.getOperationName();
|
||||
p.printRegion(graph.getOperation()->getRegion(0));
|
||||
p.printOptionalAttrDict(graph.getAttrs());
|
||||
p.printOptionalAttrDict(graph->getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -321,7 +321,7 @@ void Print(IslandOp op, OpAsmPrinter &p) {
|
||||
// Check if we can print the short "wraps" form: that is if the island
|
||||
// contains a single operation and the result of this operation are perfectly
|
||||
// forwarded to the yield.
|
||||
if (op.getAttrs().empty() && op.WrapsSingleOp()) {
|
||||
if (op->getAttrs().empty() && op.WrapsSingleOp()) {
|
||||
Operation &wrapped_op = op.GetBody().front();
|
||||
YieldOp yield_op = op.GetYield();
|
||||
// The "wraps" syntax only encodes a single location.
|
||||
@ -335,7 +335,7 @@ void Print(IslandOp op, OpAsmPrinter &p) {
|
||||
}
|
||||
}
|
||||
p.printRegion(op.getOperation()->getRegion(0));
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -449,7 +449,7 @@ void Print(SwitchOp switch_op, OpAsmPrinter &p) {
|
||||
} else {
|
||||
p << switch_op.getType(0);
|
||||
}
|
||||
p.printOptionalAttrDict(switch_op.getAttrs());
|
||||
p.printOptionalAttrDict(switch_op->getAttrs());
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
@ -525,7 +525,7 @@ void Print(SwitchNOp switchn, OpAsmPrinter &p) {
|
||||
p << ")";
|
||||
}
|
||||
p << " : " << switchn.getType(0);
|
||||
p.printOptionalAttrDict(switchn.getAttrs(), {"num_outs"});
|
||||
p.printOptionalAttrDict(switchn->getAttrs(), {"num_outs"});
|
||||
}
|
||||
|
||||
ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -655,7 +655,7 @@ void Print(MergeOp merge, OpAsmPrinter &p) {
|
||||
p << output_type;
|
||||
}
|
||||
|
||||
p.printOptionalAttrDict(merge.getAttrs());
|
||||
p.printOptionalAttrDict(merge->getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseMergeOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -723,7 +723,7 @@ void Print(EnterOp enter, OpAsmPrinter &p) {
|
||||
p << enter.getType(0);
|
||||
}
|
||||
|
||||
p.printOptionalAttrDict(enter.getAttrs(),
|
||||
p.printOptionalAttrDict(enter->getAttrs(),
|
||||
{"frame_name", "parallel_iterations", "is_constant"});
|
||||
}
|
||||
|
||||
@ -843,7 +843,7 @@ void Print(ExitOp exit, OpAsmPrinter &p) {
|
||||
p << exit.getOperationName() << ' ';
|
||||
p.printOperands(exit.getOperands());
|
||||
p << " : " << exit.getType(0);
|
||||
p.printOptionalAttrDict(exit.getAttrs());
|
||||
p.printOptionalAttrDict(exit->getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) {
|
||||
@ -887,7 +887,7 @@ void Print(LoopCondOp loop_cond, OpAsmPrinter &p) {
|
||||
p << " : " << loop_cond.input().getType();
|
||||
}
|
||||
|
||||
p.printOptionalAttrDict(loop_cond.getAttrs());
|
||||
p.printOptionalAttrDict(loop_cond->getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseLoopCondOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
@ -3887,6 +3887,8 @@ Returns the input tensor otherwise.
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> {
|
||||
@ -8563,6 +8565,28 @@ the result here is consistent with a truncating divide. E.g.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ModelDatasetOp : TF_Op<"ModelDataset", [NoSideEffect]> {
|
||||
let summary = "Identity transformation that models performance.";
|
||||
|
||||
let description = [{
|
||||
Identity transformation that models performance.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_VariantTensor, [{A variant tensor representing the input dataset.}]>:$input_dataset,
|
||||
|
||||
DefaultValuedAttr<I64Attr, "0">:$algorithm,
|
||||
DefaultValuedAttr<I64Attr, "0">:$cpu_budget,
|
||||
DefaultValuedAttr<I64Attr, "0">:$ram_budget,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x * y element-wise.";
|
||||
@ -9217,6 +9241,31 @@ def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAnd
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_OptimizeDatasetV2Op : TF_Op<"OptimizeDatasetV2", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Creates a dataset by applying related optimizations to `input_dataset`.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Creates a dataset by applying related optimizations to `input_dataset`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_VariantTensor, [{A variant tensor representing the input dataset.}]>:$input_dataset,
|
||||
Arg<TF_StrTensor, [{A `tf.string` vector `tf.Tensor` identifying user enabled optimizations.}]>:$optimizations_enabled,
|
||||
Arg<TF_StrTensor, [{A `tf.string` vector `tf.Tensor` identifying user disabled optimizations.}]>:$optimizations_disabled,
|
||||
Arg<TF_StrTensor, [{A `tf.string` vector `tf.Tensor` identifying optimizations by default.}]>:$optimizations_default,
|
||||
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
|
||||
DefaultValuedAttr<StrArrayAttr, "{}">:$optimization_configs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_OptionalGetValueOp : TF_Op<"OptionalGetValue", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Returns the value stored in an Optional variant or raises an error if none exists.
|
||||
@ -9305,6 +9354,8 @@ This is the opposite of `unpack`.
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
@ -888,7 +888,7 @@ class CaseOrIfRegionEliminatePassThrough
|
||||
|
||||
// Create new case/if region op.
|
||||
auto new_op = rewriter.create<CaseOrIfRegionOp>(
|
||||
op.getLoc(), new_result_types, op.getOperand(), op.getAttrs(),
|
||||
op.getLoc(), new_result_types, op.getOperand(), op->getAttrs(),
|
||||
op.getNumRegions());
|
||||
|
||||
int next_index = 0;
|
||||
@ -2092,6 +2092,19 @@ static LogicalResult Verify(EmptyTensorListOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EnsureShapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult EnsureShapeOp::fold(llvm::ArrayRef<mlir::Attribute>) {
|
||||
ShapedType type = input().getType().dyn_cast<ShapedType>();
|
||||
if (!type || !type.hasRank()) return {};
|
||||
// If shape attribute equals input operand's type's shape, fold it to input.
|
||||
if (type.getShape() == shape()) return input();
|
||||
// Else retain to enable failing dynamically.
|
||||
return {};
|
||||
}
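
The folder above only fires when the operand type is ranked and its static shape already equals the shape attribute; otherwise the op is kept so the check can still fail at run time. A minimal standalone C++ sketch of that decision, with hypothetical names and shapes as plain vectors:

#include <cstdint>
#include <vector>

// Returns true when an EnsureShape-style op can be folded to its input: the
// operand must be ranked and its shape must match the expected shape exactly.
// Any mismatch (including a dynamic dimension marker) keeps the op.
bool CanFoldEnsureShape(bool operand_is_ranked,
                        const std::vector<int64_t>& operand_shape,
                        const std::vector<int64_t>& expected_shape) {
  if (!operand_is_ranked) return false;
  return operand_shape == expected_shape;
}
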
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EqualOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -302,6 +302,49 @@ OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
|
||||
return slice_op.input();
|
||||
}
|
||||
|
||||
// Convert Pack to Reshape when there is only one operand to be packed.
|
||||
// For example,
|
||||
//
|
||||
// %0 = tf.Pack(%input) {axis = 0} // %input : tensor<2x3xf32>
|
||||
//
|
||||
// can be canonicalized to
|
||||
//
|
||||
// %shape = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>}
|
||||
// %0 = tf.Reshape(%input, %shape)
|
||||
struct ConvertPackToReshape : public OpRewritePattern<PackOp> {
|
||||
using OpRewritePattern<PackOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(PackOp pack_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check if there is only one operand to be packed.
|
||||
if (pack_op.N() != 1) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check if input and output are static.
|
||||
auto input_ty = pack_op.getOperand(0).getType().cast<ShapedType>();
|
||||
auto output_ty = pack_op.output().getType().cast<ShapedType>();
|
||||
if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Create constant shape for reshape.
|
||||
auto type =
|
||||
RankedTensorType::get(output_ty.getRank(), rewriter.getIntegerType(64));
|
||||
auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape());
|
||||
auto shape = rewriter.create<ConstOp>(pack_op.getLoc(), shape_attr);
|
||||
|
||||
rewriter.replaceOpWithNewOp<ReshapeOp>(pack_op, output_ty,
|
||||
pack_op.getOperand(0), shape);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<ConvertPackToReshape>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2587,7 +2630,7 @@ class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
|
||||
TF::FusedBatchNormV3Op::getOperationName(),
|
||||
tf_fused_batch_norm_op.getOperands(),
|
||||
new_result_types,
|
||||
tf_fused_batch_norm_op.getAttrs());
|
||||
tf_fused_batch_norm_op->getAttrs());
|
||||
Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
|
||||
|
||||
rewriter.replaceOp(tf_fused_batch_norm_op,
|
||||
@ -3029,9 +3072,9 @@ struct WhileRegionEliminatePassThrough
|
||||
}
|
||||
|
||||
// Create the new while operation.
|
||||
auto new_while_op =
|
||||
rewriter.create<WhileRegionOp>(while_op.getLoc(), new_result_types,
|
||||
new_while_operands, while_op.getAttrs());
|
||||
auto new_while_op = rewriter.create<WhileRegionOp>(
|
||||
while_op.getLoc(), new_result_types, new_while_operands,
|
||||
while_op->getAttrs());
|
||||
|
||||
// Move region bodies to the new while.
|
||||
rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
|
||||
|
@ -1643,3 +1643,39 @@ func @testFoldStridedSliceShapeWithEmptySlice(%arg0: tensor<?x1x2x3xf32>) -> (te
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldEnsureShapeOp
|
||||
func @testFoldEnsureShapeOp(%arg0: tensor<10x20xf32>) -> (tensor<10x20xf32>, tensor<20x10xf32>) {
|
||||
%0 = "tf.EnsureShape"(%arg0) {shape = #tf.shape<10x20>} : (tensor<10x20xf32>) -> tensor<10x20xf32>
|
||||
// Failing case which should not be folded.
|
||||
// CHECK: %[[NF:.*]] = "tf.EnsureShape"(%arg0) {shape = #tf.shape<20x10>}
|
||||
%1 = "tf.EnsureShape"(%arg0) {shape = #tf.shape<20x10>} : (tensor<10x20xf32>) -> tensor<20x10xf32>
|
||||
// CHECK: return %arg0, %[[NF]]
|
||||
return %0, %1: tensor<10x20xf32>, tensor<20x10xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testConvertPackToReshapeAxis0
|
||||
func @testConvertPackToReshapeAxis0(%arg0: tensor<2x3xf32>) -> tensor<1x2x3xf32> {
|
||||
%0 = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
|
||||
return %0 : tensor<1x2x3xf32>
|
||||
// CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
|
||||
// CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x3xf32>, tensor<3xi64>) -> tensor<1x2x3xf32>
|
||||
// CHECK: return %[[RESHAPE]] : tensor<1x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testConvertPackToReshapeAxis1
|
||||
func @testConvertPackToReshapeAxis1(%arg0: tensor<2x3xf32>) -> tensor<2x1x3xf32> {
|
||||
%0 = "tf.Pack"(%arg0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1x3xf32>
|
||||
return %0 : tensor<2x1x3xf32>
|
||||
// CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[2, 1, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
|
||||
// CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x3xf32>, tensor<3xi64>) -> tensor<2x1x3xf32>
|
||||
// CHECK: return %[[RESHAPE]] : tensor<2x1x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testDontConvertPackToReshapeDynamicShape
|
||||
func @testDontConvertPackToReshapeDynamicShape(%arg0: tensor<2x?xf32>) -> tensor<1x2x?xf32> {
|
||||
%0 = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x?xf32>) -> tensor<1x2x?xf32>
|
||||
return %0 : tensor<1x2x?xf32>
|
||||
// CHECK: %[[PACK:.*]] = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x?xf32>) -> tensor<1x2x?xf32>
|
||||
// CHECK: return %[[PACK]] : tensor<1x2x?xf32>
|
||||
}
|
||||
|
@ -1,10 +1,29 @@
|
||||
// RUN: tf-opt %s -tf-drop-while-shape-invariant | FileCheck %s
|
||||
// RUN: tf-opt %s -tf-drop-while-shape-invariant-in-device-cluster | FileCheck -check-prefix=IN-CLUSTER %s
|
||||
|
||||
// CHECK-LABEL: while_shape_invariant
|
||||
|
||||
func @while_cond(%arg0: tensor<*xf32>) -> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
func @while_body(%arg0: tensor<*xf32>) -> (tensor<*xf32>) {
|
||||
%0 = "tf.SomeOp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// Test that -tf-drop-while-shape-invariant-in-device-cluster pass does not drop
|
||||
// the shape_invariant attribute from While/WhileRegion ops outside the device
|
||||
// cluster, while the other pass drops them.
|
||||
|
||||
// CHECK-LABEL: while_shape_invariant_outside_cluster
|
||||
// CHECK-NOT: shape_invariant
|
||||
func @while_shape_invariant(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
// IN-CLUSTER-LABEL: while_shape_invariant_outside_cluster
|
||||
func @while_shape_invariant_outside_cluster(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
// IN-CLUSTER: shape_invariant
|
||||
%0 = "tf.While"(%arg0) {cond = @while_cond, body = @while_body, is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
|
||||
|
||||
// IN-CLUSTER: shape_invariant
|
||||
%1 = "tf.WhileRegion"(%arg0) ( {
|
||||
^cond(%carg0: tensor<*xf32>):
|
||||
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
@ -18,12 +37,28 @@ func @while_shape_invariant(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf3
|
||||
return %0, %1 : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
func @while_cond(%arg0: tensor<*xf32>) -> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
// Test that both passes drop the shape_invariant attribute from
|
||||
// While/WhileRegion ops within a cluster.
|
||||
|
||||
func @while_body(%arg0: tensor<*xf32>) -> (tensor<*xf32>) {
|
||||
%0 = "tf.SomeOp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: while_shape_invariant_within_cluster
|
||||
// CHECK-NOT: shape_invariant
|
||||
// IN-CLUSTER-LABEL: while_shape_invariant_within_cluster
|
||||
// IN-CLUSTER-NOT: shape_invariant
|
||||
func @while_shape_invariant_within_cluster(%arg0: tensor<4xf32>) {
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.While"(%arg0) {cond = @while_cond, body = @while_body, is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
|
||||
|
||||
%1 = "tf.WhileRegion"(%arg0) ( {
|
||||
^cond(%carg0: tensor<*xf32>):
|
||||
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
"tf.Yield"(%2) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^body(%barg0: tensor<*xf32>):
|
||||
%2 = "tf.SomeOp"(%barg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}) {is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
|
||||
tf_device.return
|
||||
}) {} : () -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -213,11 +213,11 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) {
|
||||
} : (tensor<*xf32>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: [[Result0:%.*]] = "tf.WhileRegion"
|
||||
// CHECK: [[ResultCast0:%.*]] = "tf.Cast"
|
||||
// CHECK: [[Result1:%.*]] = call @testWhileCond([[ResultCast0]])
|
||||
// CHECK: ^bb0(%[[CARG0:.*]]: tensor<4xf32>
|
||||
// CHECK: [[Result1:%.*]] = call @testWhileCond(%[[CARG0]])
|
||||
// CHECK: "tf.Yield"([[Result1]])
|
||||
// CHECK: [[ResultCast1:%.*]] = "tf.Cast"
|
||||
// CHECK: [[Result2:%.*]] = call @testWhileBody([[ResultCast1]])
|
||||
// CHECK: ^bb0(%[[BARG0:.*]]: tensor<4xf32>
|
||||
// CHECK: [[Result2:%.*]] = call @testWhileBody(%[[BARG0]])
|
||||
// CHECK: "tf.Yield"([[Result2]])
|
||||
// CHECK: return [[Result0]]
|
||||
return %1 : tensor<*xf32>
|
||||
|
@ -0,0 +1,34 @@
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [
|
||||
":debug_info_files",
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = ["pbtxt"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-mlir-translate",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
# Bundle together all the debug info files that are used by the tests.
|
||||
filegroup(
|
||||
name = "debug_info_files",
|
||||
srcs = glob(
|
||||
[
|
||||
"**/*.debug",
|
||||
],
|
||||
),
|
||||
)
|
@ -0,0 +1,137 @@
|
||||
#RUN: tf-mlir-translate --savedmodel-signaturedefs-to-mlir-lite -tf-savedmodel-tags=serve,tpu %p | FileCheck %s
|
||||
|
||||
# Test importing a saved model with 2 signatures that are using a same
|
||||
# BatchFunction Op, which references to a same inference_func from graph_def
|
||||
# library. The result should be that both signatures uses the same
|
||||
# BatchFunction Op (the shared_name is the same) and the same copy of
|
||||
# inference_func.
|
||||
|
||||
# CHECK: func @predict0
|
||||
# CHECK: f = @inference_func[[post_fix:[^,]*]]
|
||||
# CHECK-SAME: shared_name = "batch"
|
||||
|
||||
# CHECK: func @predict1
|
||||
# CHECK: f = @inference_func[[post_fix]],
|
||||
# CHECK-SAME: shared_name = "batch"
|
||||
|
||||
meta_graphs: {
|
||||
meta_info_def: {
|
||||
tags: ["serve", "tpu"]
|
||||
}
|
||||
graph_def: {
|
||||
node: {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "batch_func"
|
||||
op: "BatchFunction"
|
||||
input: ["input1"]
|
||||
attr: {
|
||||
key: "Tcaptured"
|
||||
value: {
|
||||
list: {
|
||||
type: []
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "Tin"
|
||||
value: {
|
||||
list: {
|
||||
type: [DT_INT32]
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "Tout"
|
||||
value: {
|
||||
list: {
|
||||
type: [DT_FLOAT, DT_FLOAT]
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "f"
|
||||
value: {
|
||||
func: {
|
||||
name: "inference_func"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "shared_name"
|
||||
value: {
|
||||
s: "batch"
|
||||
}
|
||||
}
|
||||
}
|
||||
library: {
|
||||
function {
|
||||
signature {
|
||||
name: "inference_func"
|
||||
input_arg {
|
||||
name: "arg0"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "retval0"
|
||||
value: "arg0"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
signature_def: {
|
||||
key: "predict0"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value: {
|
||||
name: "input0"
|
||||
dtype: DT_STRING
|
||||
}
|
||||
}
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value: {
|
||||
name: "batch_func:0"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
signature_def: {
|
||||
key: "predict1"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "tf_example_input"
|
||||
value: {
|
||||
name: "input0"
|
||||
dtype: DT_STRING
|
||||
}
|
||||
}
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value: {
|
||||
name: "batch_func:1"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,39 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node {
|
||||
name: "func0"
|
||||
op: "func_name"
|
||||
input: "input"
|
||||
}
|
||||
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "func_name"
|
||||
input_arg {
|
||||
name: "arg0"
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "retval0"
|
||||
value: "arg0"
|
||||
}
|
||||
attr: {
|
||||
key: "_input_shapes"
|
||||
value: {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,76 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Sub -o - | FileCheck %s
|
||||
|
||||
node {
|
||||
name: "Add"
|
||||
op: "Add"
|
||||
input: "input0"
|
||||
input: "input1"
|
||||
# If device type or id doesn't exist, assign a default one (device:CPU:0).
|
||||
device: "/job:localhost/replica:0/task:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Mul"
|
||||
op: "Mul"
|
||||
input: "Add"
|
||||
input: "Add"
|
||||
# Empty device name should be kept untouched.
|
||||
device: ""
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Sub"
|
||||
op: "Sub"
|
||||
input: "Add"
|
||||
input: "Mul"
|
||||
# Device name is not modified if complete
|
||||
device: "/job:localhost/replica:0/task:0/device:CPU:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<*xi32>
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME: inputs = "input0,input1"
|
||||
# CHECK-SAME: outputs = "Sub"
|
||||
# CHECK: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
|
||||
# CHECK: %[[mul:.*]], %[[mul_control:.*]] = tf_executor.island wraps "tf.Mul"(%[[add]], %[[add]]) {device = ""}
|
||||
# CHECK: %[[sub:.*]], %[[sub_control:.*]] = tf_executor.island wraps "tf.Sub"(%[[add]], %[[mul]]) {device = "/job:localhost/replica:0/task:0/device:CPU:1"}
|
@ -126,3 +126,20 @@ func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<
|
||||
|
||||
return %2#0 : tensor<1x112x112x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_into_pad_with_extra_uses
|
||||
func @fold_into_pad_with_extra_uses(%arg0: tensor<1x2x4x4x3xf32>) -> (tensor<1x2x3x4x4xf32>, tensor<1x2x3x6x6xf32>) {
|
||||
|
||||
// CHECK: %[[PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>}
|
||||
// CHECK: %[[TRANSPOSE_OP:[0-9]*]] = "tf.Transpose"(%arg0, %[[PERM]])
|
||||
// CHECK: %[[PADDING:[0-9]*]] = "tf.Const"() {value = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<5x2xi32>}
|
||||
// CHECK: %[[PAD_OP:[0-9]*]] = "tf.Pad"(%arg0, %[[PADDING]])
|
||||
// CHECK: %[[DUP_TRANSPOSE_OP:[0-9]*]] = "tf.Transpose"(%[[PAD_OP]], %[[PERM]])
|
||||
// CHECK: return %[[TRANSPOSE_OP]], %[[DUP_TRANSPOSE_OP]]
|
||||
|
||||
%0 = "tf.Const"() {value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
|
||||
%1 = "tf.Const"() {value = dense<[[0, 0], [0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<5x2xi32>} : () -> tensor<5x2xi32>
|
||||
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x2x4x4x3xf32>, tensor<5xi32>) -> tensor<1x2x3x4x4xf32>
|
||||
%3 = "tf.Pad"(%2, %1) : (tensor<1x2x3x4x4xf32>, tensor<5x2xi32>) -> tensor<1x2x3x6x6xf32>
|
||||
return %2, %3 : tensor<1x2x3x4x4xf32>, tensor<1x2x3x6x6xf32>
|
||||
}
|
||||
|
@ -1910,6 +1910,21 @@ func @convert_gather_nd(%arg0: tensor<98x128xf32>, %arg1: tensor<4x64xi32>) -> t
|
||||
return %0 : tensor<4x64x128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @convert_gather_transpose(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x256xf32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>) -> tensor<4x128xf32> {
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.Const"{{.*}}value = dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<256x128xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tf.GatherNd"(%[[VAL_3]], %[[VAL_1]]) : {{.*}} -> tensor<4x128xf32>
|
||||
// CHECK: return %[[VAL_4]]
|
||||
// CHECK: }
|
||||
// Test the case when start_index_map isn't an iota what requires a transpose to
|
||||
// convert it to tf.GatherNd.
|
||||
func @convert_gather_transpose(%arg0: tensor<128x256xf32>, %arg1: tensor<4x1xi32>) -> tensor<4x128xf32> {
|
||||
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<1> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<1> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[128, 1]> : tensor<2xi64>} : (tensor<128x256xf32>, tensor<4x1xi32>) -> tensor<4x128xf32>
|
||||
return %0 : tensor<4x128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @convert_dynamic_slice(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<7x3xf32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<i32>,
|
||||
|
@ -0,0 +1,20 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef -tf-export-entry-func-to-flib %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 458 : i32}} {
|
||||
func @main() {
|
||||
tf_executor.graph {
|
||||
%0:2 = tf_executor.island wraps "tf.Const"() {device = "TPU:0", name = "const", dtype = "tfdtype$DT_INT32", value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-NOT: node
|
||||
|
||||
// CHECK: library
|
||||
// CHECK-NEXT: function
|
||||
// CHECK-NEXT: signature
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK: node_def
|
||||
// CHECK: op: "Const"
|
@ -9,7 +9,7 @@ func @main() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: Functions must be of a single Graph with single op Islands: only single block functions are supported.
|
||||
// CHECK: functions must be of a single Graph with single op Islands: only single block functions are supported
|
||||
|
||||
// -----
|
||||
|
||||
@ -19,7 +19,7 @@ func @main() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: Functions must be of a single Graph with single op Islands: first op in function is not a tf_executor.graph.
|
||||
// CHECK: functions must be of a single Graph with single op Islands: first op in function is not a tf_executor.graph
|
||||
|
||||
// -----
|
||||
|
||||
@ -33,7 +33,7 @@ func @main() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: Functions must be of a single Graph with single op Islands: function does not only contain a single tf_executor.graph.
|
||||
// CHECK: functions must be of a single Graph with single op Islands: function does not only contain a single tf_executor.graph
|
||||
|
||||
// -----
|
||||
|
||||
@ -47,7 +47,7 @@ func @main() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op.
|
||||
// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op
|
||||
|
||||
// -----
|
||||
|
||||
@ -63,7 +63,7 @@ func @main() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op.
|
||||
// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op
|
||||
|
||||
// -----
|
||||
|
||||
@ -78,7 +78,7 @@ func @main() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op.
|
||||
// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op
|
||||
|
||||
// -----
|
||||
|
||||
@ -93,4 +93,4 @@ func @main(%arg0: tensor<i32>, %arg1: tensor<i32>) {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: Functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op.
|
||||
// CHECK: functions must be of a single Graph with single op Islands: tf_executor.island must perfectly wrap a single op
|
||||
|
@ -470,8 +470,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
// CHECK: tf.TensorListSetItem{{.*}}: (tensor<!tf.variant<tensor<2x2xf32>>>, tensor<i32>, tensor<2x2xf32>) -> tensor<!tf.variant<tensor<2x2xf32>>>
|
||||
%6 = "tf.TensorListSetItem"(%3, %4, %5) {device = ""} : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>, tensor<2x2xf32>)-> tensor<*x!tf.variant>
|
||||
%7 = "tf.Const"() {device = "", value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%8 = "tf.StopGradient"(%6) : (tensor<*x!tf.variant>) -> tensor<*x!tf.variant>
|
||||
// CHECK: tf.TensorListStack{{.*}}: (tensor<!tf.variant<tensor<2x2xf32>>>, tensor<i32>) -> tensor<?x2x2xf32>
|
||||
%8 = "tf.TensorListStack"(%6, %7) {device = "", num_elements = -1 : i64} : (tensor<*x!tf.variant>, tensor<i32>) -> tensor<*xf32>
|
||||
%9 = "tf.TensorListStack"(%8, %7) {device = "", num_elements = -1 : i64} : (tensor<*x!tf.variant>, tensor<i32>) -> tensor<*xf32>
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
|
@ -27,19 +27,16 @@ from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||
|
||||
# CHECK: func {{.*}} tf_saved_model.exported_names = ["key_1"]
|
||||
# CHECK: "tf.If"
|
||||
# CHECK-SAME: else_branch = @[[else_1:"key_1/[a-zA-Z_0-9]+"]]
|
||||
# CHECK-SAME: then_branch = @[[then_1:"key_1/[a-zA-Z_0-9]+"]]
|
||||
# CHECK-SAME: else_branch = @[[else:[a-zA-Z_0-9]+]]
|
||||
# CHECK-SAME: then_branch = @[[then:[a-zA-Z_0-9]+]]
|
||||
|
||||
# CHECK: func {{.*}} tf_saved_model.exported_names = ["key_2"]
|
||||
# CHECK: "tf.If"
|
||||
# CHECK-SAME: else_branch = @[[else_2:"key_2/[a-zA-Z_0-9]+"]]
|
||||
# CHECK-SAME: then_branch = @[[then_2:"key_2/[a-zA-Z_0-9]+"]]
|
||||
# CHECK-SAME: else_branch = @[[else]]
|
||||
# CHECK-SAME: then_branch = @[[then]]
|
||||
|
||||
# CHECK: func private @[[else_1]](
|
||||
# CHECK: func private @[[then_1]](
|
||||
|
||||
# CHECK: func private @[[else_2]](
|
||||
# CHECK: func private @[[then_2]](
|
||||
# CHECK: func private @[[else]](
|
||||
# CHECK: func private @[[then]](
|
||||
|
||||
|
||||
def Test():
|
||||
|
@ -39,21 +39,8 @@ void EnableLogging(PassManager *pm) {
|
||||
} // namespace
|
||||
|
||||
namespace TFTPU {
|
||||
|
||||
namespace {
|
||||
void AddGraphExportLoweringPasses(OpPassManager &pm) {
|
||||
auto add_pass = [&](std::unique_ptr<Pass> pass) {
|
||||
pm.addNestedPass<FuncOp>(std::move(pass));
|
||||
pm.addPass(CreateBreakUpIslandsPass());
|
||||
};
|
||||
|
||||
add_pass(CreateFunctionalToExecutorDialectConversionPass());
|
||||
add_pass(TFDevice::CreateReplicateToIslandPass());
|
||||
add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
|
||||
add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
|
||||
pm.addNestedPass<FuncOp>(CreateTPUDevicePropagationPass());
|
||||
pm.addPass(createSymbolDCEPass());
|
||||
}
|
||||
|
||||
tensorflow::Status RunTPUBridge(
|
||||
ModuleOp module, bool enable_logging,
|
||||
llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
|
||||
@ -68,7 +55,7 @@ tensorflow::Status RunTPUBridge(
|
||||
pipeline_builder(bridge);
|
||||
|
||||
// Add set of passes to lower back to graph (from tf_executor).
|
||||
AddGraphExportLoweringPasses(bridge);
|
||||
TF::AddGraphExportLoweringPasses(bridge);
|
||||
|
||||
// Run the bridge on the module, in case of failure, the `diag_handler`
|
||||
// converts MLIR errors emitted to the MLIRContext into a tensorflow::Status.
|
||||
@ -110,13 +97,19 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
|
||||
func_pm.addPass(CreateTPUHostComputationExpansionPass());
|
||||
func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
|
||||
}
|
||||
// Run another shape inference pass because resource decomposition might have
|
||||
// created new partial types.
|
||||
pm.addPass(TF::CreateTFShapeInferencePass());
|
||||
|
||||
// Note that the region-based control-flow produced here still contains
|
||||
// function call ops which get inlined by the subsequent inliner pass.
|
||||
pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
|
||||
pm.addPass(mlir::createInlinerPass());
|
||||
pm.addNestedPass<FuncOp>(
|
||||
TF::CreateDropWhileShapeInvariantInDeviceClusterPass());
|
||||
// Run another shape inference pass because resource decomposition might have
|
||||
// created new partial types. Also, after dropping `shape_invariant` attribute
|
||||
// from While/WhileRegion ops within cluster would lead to more precise
|
||||
// shapes.
|
||||
pm.addPass(TF::CreateTFShapeInferencePass());
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addPass(CreateTPUClusterCleanupAttributesPass());
|
||||
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
@ -166,6 +159,21 @@ tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging) {
|
||||
|
||||
namespace TF {
|
||||
|
||||
void AddGraphExportLoweringPasses(OpPassManager &pm) {
|
||||
auto add_pass = [&](std::unique_ptr<Pass> pass) {
|
||||
pm.addNestedPass<FuncOp>(std::move(pass));
|
||||
pm.addPass(CreateBreakUpIslandsPass());
|
||||
};
|
||||
|
||||
add_pass(CreateFunctionalToExecutorDialectConversionPass());
|
||||
add_pass(TFDevice::CreateReplicateToIslandPass());
|
||||
add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
|
||||
add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
|
||||
pm.addPass(TFTPU::CreateTPUDevicePropagationPass());
|
||||
pm.addPass(createSymbolDCEPass());
|
||||
pm.addPass(CreateVerifySuitableForExportPass());
|
||||
}
|
||||
|
||||
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
|
||||
bool enable_logging,
|
||||
bool enable_inliner) {
|
||||
|
@ -20,20 +20,32 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
extern void AddGraphExportLoweringPasses(OpPassManager &pm);
|
||||
} // namespace TFTPU
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
|
||||
// Registers an existing pipeline builder function.
|
||||
// Registers a pipeline builder function for TF TPU bridge.
|
||||
mlir::PassPipelineRegistration<> tpu_pipeline(
|
||||
"tf-tpu-bridge",
|
||||
"Run all the passes involved in transforming the graph before execution so "
|
||||
"that it is suitable for targeting TPUs.",
|
||||
mlir::TFTPU::CreateTPUBridgePipeline);
|
||||
|
||||
// Registers an existing pipeline builder function.
|
||||
// Registers a pipeline builder function for TF TPU V1 bridge.
|
||||
mlir::PassPipelineRegistration<> tpu_pipeline_v1(
|
||||
"tf-tpu-bridge-v1",
|
||||
"Run all the passes involved in transforming a TensorFlow V1 graph before "
|
||||
"execution so that it is suitable for targeting TPUs.",
|
||||
mlir::TFTPU::CreateTPUBridgePipelineV1);
|
||||
|
||||
// Registers a pipeline builder function for TF Graph export.
|
||||
mlir::PassPipelineRegistration<> tpu_export(
|
||||
"tf-graph-export",
|
||||
"Run passes to prepare for exporting module back to TF Graph.",
|
||||
mlir::TF::AddGraphExportLoweringPasses);
|
||||
|
||||
} // anonymous namespace
|
||||
|
@ -112,7 +112,7 @@ void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
|
||||
builder->setInsertionPoint(cluster_op);
|
||||
auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
|
||||
cluster_op.getLoc(), outlined_func.getType().getResults(),
|
||||
live_ins.getArrayRef(), cluster_op.getAttrs());
|
||||
live_ins.getArrayRef(), cluster_op->getAttrs());
|
||||
|
||||
cluster_op.replaceAllUsesWith(cluster_func_op);
|
||||
cluster_op.erase();
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -32,22 +33,50 @@ class DropWhileShapeInvariantPass
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Drop `shape_invariant` attribute from tf.While and tf.WhileRegion op only
|
||||
// inside device cluster. This would allow shape inference pass to further
|
||||
// refine operand/result shapes of these ops. This is only safe to do when
|
||||
// compiling to XLA.
|
||||
class DropWhileShapeInvariantInDeviceClusterPass
|
||||
: public PassWrapper<DropWhileShapeInvariantInDeviceClusterPass,
|
||||
FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
void DropWhileShapeInvariantAttr(Operation* op) {
|
||||
if (llvm::isa<WhileOp, WhileRegionOp>(op))
|
||||
op->removeAttr(kShapeInvariantAttr);
|
||||
}
|
||||
void DropWhileShapeInvariantPass::runOnFunction() {
|
||||
getFunction().walk([](Operation* op) {
|
||||
if (llvm::isa<WhileOp, WhileRegionOp>(op))
|
||||
op->removeAttr(kShapeInvariantAttr);
|
||||
getFunction().walk([](Operation* op) { DropWhileShapeInvariantAttr(op); });
|
||||
}
|
||||
|
||||
void DropWhileShapeInvariantInDeviceClusterPass::runOnFunction() {
|
||||
getFunction().walk([](tf_device::ClusterOp cluster) {
|
||||
cluster.walk([](Operation* op) { DropWhileShapeInvariantAttr(op); });
|
||||
});
|
||||
}
|
||||
|
||||
static PassRegistration<DropWhileShapeInvariantPass> pass(
|
||||
static PassRegistration<DropWhileShapeInvariantPass> drop_shape_invariant_pass(
|
||||
"tf-drop-while-shape-invariant",
|
||||
"Drop `shape_invariant` attrbute from While/WhileRegion ops.");
|
||||
|
||||
static PassRegistration<DropWhileShapeInvariantInDeviceClusterPass>
|
||||
drop_shape_invariant_in_cluster_pass(
|
||||
"tf-drop-while-shape-invariant-in-device-cluster",
|
||||
"Drop `shape_invariant` attrbute from While/WhileRegion ops inside "
|
||||
"device cluster.");
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDropWhileShapeInvariantPass() {
|
||||
return std::make_unique<DropWhileShapeInvariantPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateDropWhileShapeInvariantInDeviceClusterPass() {
|
||||
return std::make_unique<DropWhileShapeInvariantInDeviceClusterPass>();
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
@ -63,7 +63,7 @@ YieldOp CreateCall(Operation* op, FuncOp func, Region& caller_region,
|
||||
Block* entry = builder.createBlock(&caller_region);
|
||||
|
||||
if (use_region_args) {
|
||||
entry->addArguments(args.getType());
|
||||
entry->addArguments(func.getType().getInputs());
|
||||
args = entry->getArguments();
|
||||
}
|
||||
llvm::SmallVector<Value, 4> casted_args;
|
||||
|
@ -156,7 +156,7 @@ class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
|
||||
// The fused contraction has the same attributes as the original
|
||||
// contraction, with two additions: the list of ops which have been fused
|
||||
// together; epsilon (only with FusedBatchNorm).
|
||||
std::vector<NamedAttribute> attrs = contraction.getAttrs();
|
||||
std::vector<NamedAttribute> attrs = contraction->getAttrs();
|
||||
ArrayAttr fused_ops_attr = ArrayAttr::get(context, fused_ops);
|
||||
attrs.push_back(
|
||||
NamedAttribute(Identifier::get("fused_ops", context), fused_ops_attr));
|
||||
|
@ -96,7 +96,7 @@ struct ReluToFusedBatchNorm : public OpRewritePattern<ReluOp> {
|
||||
state.addOperands(batch_norm.getOperands());
|
||||
if (side_input) state.operands.push_back(side_input);
|
||||
state.addTypes(batch_norm.getResultTypes());
|
||||
state.addAttributes(batch_norm.getAttrs());
|
||||
state.addAttributes(batch_norm->getAttrs());
|
||||
Operation *op = rewriter.createOperation(state);
|
||||
rewriter.replaceOp(batch_norm, op->getResults());
|
||||
|
||||
|
@ -387,7 +387,7 @@ void MoveTransposeAfter(Operation* op, SmallVector<Operation*, 8>* work_list,
|
||||
// Maybe add Transpose nodes for layout dependent results
|
||||
// (or reuse existing transposes).
|
||||
OpBuilder builder(op);
|
||||
builder.setInsertionPoint(op);
|
||||
builder.setInsertionPointAfter(op);
|
||||
|
||||
for (unsigned idx : layout_dependent_results) {
|
||||
OpResult result = op->getResult(idx);
|
||||
|
@ -1198,18 +1198,22 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Verify that start_index_map and collapsed_slice_dims are both an iota
|
||||
// with the same number of elements as the last dimension of start_indices.
|
||||
// Verify that start_index_map and collapsed_slice_dims contains the same
|
||||
// values.
|
||||
auto start_index_map = gather_op.dimension_numbers().start_index_map();
|
||||
auto collapsed_slice_dims =
|
||||
gather_op.dimension_numbers().collapsed_slice_dims();
|
||||
if (!IsIotaAttr(start_index_map, start_indices_type.getShape().back()) ||
|
||||
!IsIotaAttr(collapsed_slice_dims,
|
||||
start_indices_type.getShape().back())) {
|
||||
// TODO(tberghammer): Transform start_indices to support non-standard
|
||||
// start_index_maps.
|
||||
if (start_index_map.getNumElements() !=
|
||||
collapsed_slice_dims.getNumElements()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
gather_op, "unsupported start index map and/or collapsed slice dims");
|
||||
gather_op,
|
||||
"different size for start index map and collapsed slice dims");
|
||||
}
|
||||
for (auto c : collapsed_slice_dims) {
|
||||
if (llvm::count(start_index_map, c) == 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
gather_op, "collapsed slice dim isn't present in start index map");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that slice_sizes is 1 for the indexed dimensions and the full
|
||||
@ -1217,7 +1221,7 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
|
||||
auto slice_sizes = gather_op.slice_sizes();
|
||||
int64_t index = 0;
|
||||
for (int64_t s : slice_sizes.getValues<int64_t>()) {
|
||||
if (index < start_indices_type.getShape().back()) {
|
||||
if (llvm::count(start_index_map, index)) {
|
||||
if (s != 1) {
|
||||
return rewriter.notifyMatchFailure(gather_op,
|
||||
"unsupported slice sizes");
|
||||
@ -1242,6 +1246,25 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
|
||||
++offset;
|
||||
}
|
||||
|
||||
// Transpose the operand to handle non-iota start index map.
|
||||
llvm::SmallVector<int64_t, 4> transpose_dimensions;
|
||||
llvm::SmallVector<int64_t, 4> transpose_shape;
|
||||
for (auto s : start_index_map) {
|
||||
transpose_dimensions.push_back(s.getZExtValue());
|
||||
transpose_shape.push_back(operand_type.getShape()[s.getZExtValue()]);
|
||||
}
|
||||
for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) {
|
||||
if (llvm::count(start_index_map, i) == 0) {
|
||||
transpose_dimensions.push_back(i);
|
||||
transpose_shape.push_back(operand_type.getShape()[i]);
|
||||
}
|
||||
}
|
||||
operand_type =
|
||||
RankedTensorType::get(transpose_shape, operand_type.getElementType());
|
||||
operand = rewriter.create<mhlo::TransposeOp>(
|
||||
gather_op.getLoc(), operand_type, operand,
|
||||
rewriter.getI64TensorAttr(transpose_dimensions));
|
||||
|
||||
rewriter.replaceOpWithNewOp<TF::GatherNdOp>(gather_op, result_type, operand,
|
||||
start_indices);
|
||||
return success();
|
||||
|
@ -67,7 +67,6 @@ LogicalResult LiftVariablesFromSession(
|
||||
ModuleOp module, Session* session,
|
||||
const SmallSet<StringRef, 4>& resource_names) {
|
||||
OpBuilder builder(module.getBodyRegion());
|
||||
MLIRContext* context = module.getContext();
|
||||
|
||||
if (!session) return module.emitOpError() << "no session provided";
|
||||
|
||||
@ -137,7 +136,7 @@ LogicalResult LiftVariablesFromSession(
|
||||
ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
|
||||
|
||||
builder.create<tf_saved_model::GlobalTensorOp>(
|
||||
NameLoc::get(builder.getIdentifier(name.str()), context),
|
||||
NameLoc::get(builder.getIdentifier(name.str())),
|
||||
builder.getStringAttr(name), tensor_attr,
|
||||
TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ class ResourceAnalyzer {
|
||||
public:
|
||||
explicit ResourceAnalyzer(ModuleOp module) {
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
(void)AnalyzeFunc(func);
|
||||
(void)AnalyzeRegion(func.getRegion());
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,18 +82,18 @@ class ResourceAnalyzer {
|
||||
}
|
||||
|
||||
private:
|
||||
// Analyze the specified func for resource mutating operations, namely
|
||||
// Analyze the specified region for resource mutating operations, namely
|
||||
// TF::AssignVariableOp, if so, set the resource associated as "potentially
|
||||
// written". Do this recursively across the chain of funcs via call or control
|
||||
// flow ops.
|
||||
// written". Do this recursively across the chain of regions via call or
|
||||
// control flow ops.
|
||||
// TODO(ashwinm): Move to iterative traversal.
|
||||
LogicalResult AnalyzeFunc(FuncOp func) {
|
||||
LogicalResult AnalyzeRegion(Region& region) {
|
||||
// Avoid infinite recursion.
|
||||
if (!discovered_.insert(func).second) {
|
||||
if (!discovered_.insert(®ion).second) {
|
||||
return success();
|
||||
}
|
||||
|
||||
func.walk([&](Operation* op) {
|
||||
region.walk([&](Operation* op) {
|
||||
if (isa<TF::ReadVariableOp, ReturnOp>(op)) {
|
||||
return;
|
||||
}
|
||||
@ -103,23 +103,40 @@ class ResourceAnalyzer {
|
||||
}
|
||||
if (auto call = dyn_cast<CallOpInterface>(op)) {
|
||||
if (auto func = dyn_cast<FuncOp>(call.resolveCallable())) {
|
||||
PropagatePotentiallyWrittenUpFromCallee(func, call.getArgOperands());
|
||||
PropagatePotentiallyWrittenUpFromCallee(func.getRegion(),
|
||||
call.getArgOperands());
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
|
||||
for (auto callee : {if_op.then_function(), if_op.else_function()}) {
|
||||
PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input());
|
||||
PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
|
||||
if_op.input());
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (auto if_op = dyn_cast<TF::IfRegionOp>(op)) {
|
||||
PropagatePotentiallyWrittenUpFromCallee(if_op.then_branch(),
|
||||
if_op.getODSOperands(1));
|
||||
PropagatePotentiallyWrittenUpFromCallee(if_op.else_branch(),
|
||||
if_op.getODSOperands(1));
|
||||
return;
|
||||
}
|
||||
if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
|
||||
for (auto callee :
|
||||
{while_op.cond_function(), while_op.body_function()}) {
|
||||
PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input());
|
||||
PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
|
||||
while_op.input());
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
|
||||
PropagatePotentiallyWrittenUpFromCallee(while_op.cond(),
|
||||
while_op.input());
|
||||
PropagatePotentiallyWrittenUpFromCallee(while_op.body(),
|
||||
while_op.input());
|
||||
return;
|
||||
}
|
||||
// For all other ops, we assume it mutates all resources it uses, so
|
||||
// this errs on the side of being conservative. We should improve
|
||||
// this by using either a property or a trait that clearly
|
||||
@ -144,14 +161,14 @@ class ResourceAnalyzer {
|
||||
});
|
||||
}
|
||||
|
||||
// Given a FuncOp associated with the callee and operands from the
|
||||
// Given a Region associated with the callee and operands from the
|
||||
// corresponding callOp, propagate the potentially written decision to the
|
||||
// callOp's operands, if the corresponding func's arguments are potentially
|
||||
// callOp's operands, if the corresponding region's arguments are potentially
|
||||
// written resources.
|
||||
void PropagatePotentiallyWrittenUpFromCallee(
|
||||
FuncOp func, Operation::operand_range propagate_to) {
|
||||
(void)AnalyzeFunc(func);
|
||||
for (auto t : llvm::zip(func.getArguments(), propagate_to)) {
|
||||
Region& region, Operation::operand_range propagate_to) {
|
||||
(void)AnalyzeRegion(region);
|
||||
for (auto t : llvm::zip(region.getArguments(), propagate_to)) {
|
||||
if (!IsResource(std::get<0>(t))) {
|
||||
continue;
|
||||
}
|
||||
@ -172,8 +189,8 @@ class ResourceAnalyzer {
|
||||
// Value: Information we know about that Value.
|
||||
// Note that these Value's are in general in different functions.
|
||||
DenseMap<Value, ResourceInfo> resource_infos_;
|
||||
// The set of func's we already discovered.
|
||||
DenseSet<FuncOp> discovered_;
|
||||
// The set of regions we already discovered.
|
||||
DenseSet<Region*> discovered_;
|
||||
};
|
||||
|
||||
bool IsImmutable(GlobalTensorOp global_tensor,
|
||||
@ -231,7 +248,7 @@ void MarkGlobalTensorsImmutable(
|
||||
auto global_tensor = kv.first;
|
||||
const auto& global_tensor_uses = kv.second;
|
||||
if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
|
||||
global_tensor.removeAttr("is_mutable");
|
||||
global_tensor->removeAttr("is_mutable");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -43,6 +43,11 @@ namespace TF {
|
||||
// ops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDropWhileShapeInvariantPass();
|
||||
|
||||
// Creates a pass that drops `shape_invariant` attribute from While/WhileRegion
|
||||
// ops within device cluster.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateDropWhileShapeInvariantInDeviceClusterPass();
|
||||
|
||||
// Transforms functional control flow operations in the TensorFlow dialect to
|
||||
// MLIR Control Flow Graph (CFG) form.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();
|
||||
@ -203,6 +208,15 @@ std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass();
|
||||
// will replicate the tf.Const op once for each device.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateConstantOpDeviceAssignmentPass();
|
||||
|
||||
// Populates the supplied passmanager with the passes required to export
|
||||
// to TensorFlow Graph.
|
||||
void AddGraphExportLoweringPasses(OpPassManager& pm);
|
||||
|
||||
// Returns pass that verifies whether all functions in module are of single
|
||||
// tf_executor.graph and each tf_executor.island in tf_executor.graph only has a
|
||||
// single op.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateVerifySuitableForExportPass();
|
||||
|
||||
} // namespace TF
|
||||
|
||||
namespace tf_executor {
|
||||
|
@ -931,7 +931,7 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
|
||||
while_op.getLoc(), body.getType().getResults(),
|
||||
FilterRange<Value, OperandRange>(while_op.getOperands(),
|
||||
resource_arg_uses),
|
||||
while_op.getAttrs());
|
||||
while_op->getAttrs());
|
||||
// Prepare for AddLoadsStoresOutsideControlFlowOp().
|
||||
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
|
||||
arg_data_type_and_updated_output_index;
|
||||
@ -1035,7 +1035,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<FuncOp> branches) {
|
||||
FuncOp first_func = branches.front();
|
||||
auto new_op =
|
||||
builder.create<CaseOrIfOp>(op.getLoc(), first_func.getType().getResults(),
|
||||
new_operands, op.getAttrs());
|
||||
new_operands, op->getAttrs());
|
||||
// Prepare for AddLoadsStoresOutsideControlFlowOp()
|
||||
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
|
||||
arg_data_type_and_updated_output_index;
|
||||
@ -1179,7 +1179,7 @@ void UpdatePartitionedCallOpWithNewCallee(
|
||||
FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
|
||||
auto new_call = builder.create<CallOpType>(
|
||||
call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
|
||||
new_operands, call_op.getAttrs());
|
||||
new_operands, call_op->getAttrs());
|
||||
new_call->setAttr(
|
||||
"f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
|
||||
AddLoadsStoresOutsideControlFlowOp(
|
||||
|
@ -172,121 +172,124 @@ RankedTensorType DropFirstDimension(Type type) {
|
||||
bool CanInferTensorListElementType(Value tensorlist,
|
||||
Value initial_element_shape,
|
||||
RankedTensorType* potential_element_type) {
|
||||
DCOMMENT("CanInferTensorListElementType " << tensorlist << " with initial "
|
||||
<< initial_element_shape);
|
||||
// Verifies if the new element type has static shape and matches the potential
|
||||
// type passed from caller. Updates the potential_element_type, if not defined
|
||||
// yet.
|
||||
auto verify_and_update_potential_element_type =
|
||||
[&](RankedTensorType new_element_type) -> bool {
|
||||
DCOMMENT("\t\tConsidering " << new_element_type << " with old "
|
||||
<< *potential_element_type);
|
||||
if (!new_element_type || !new_element_type.hasStaticShape()) return false;
|
||||
if (!*potential_element_type) {
|
||||
DCOMMENT("\t\tUpdating potential_element_type " << new_element_type);
|
||||
*potential_element_type = new_element_type;
|
||||
return true;
|
||||
}
|
||||
return *potential_element_type == new_element_type;
|
||||
};
|
||||
|
||||
// TensorLists are semantically immutable. For example, TensorListSetItem
|
||||
// takes a TensorList as input and produces a TensorList as output. So to
|
||||
// traverse modifications to TensorList and verify that all elements written
|
||||
// to it have the same shape, we need to follow use-def chain of ops that
|
||||
// (conceptually) modify it i.e., ops that take an input TensorList and
|
||||
// produce an output TensorList.
|
||||
for (auto& use : tensorlist.getUses()) {
|
||||
if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
|
||||
auto element_type = push.tensor().getType().dyn_cast<RankedTensorType>();
|
||||
if (!verify_and_update_potential_element_type(element_type)) return false;
|
||||
if (!CanInferTensorListElementType(push.output_handle(),
|
||||
initial_element_shape,
|
||||
potential_element_type))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>(
|
||||
use.getOwner())) {
|
||||
// For scatter op we can get the element shape by dropping the first
|
||||
// dimension of the input tensor.
|
||||
RankedTensorType element_type =
|
||||
DropFirstDimension(scatter.tensor().getType());
|
||||
if (!verify_and_update_potential_element_type(element_type)) return false;
|
||||
if (!CanInferTensorListElementType(scatter.output_handle(),
|
||||
initial_element_shape,
|
||||
potential_element_type))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
|
||||
auto element_type =
|
||||
set_item.item().getType().dyn_cast<RankedTensorType>();
|
||||
if (!verify_and_update_potential_element_type(element_type)) return false;
|
||||
if (!CanInferTensorListElementType(set_item.output_handle(),
|
||||
initial_element_shape,
|
||||
potential_element_type))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) {
|
||||
if (!CanInferTensorListElementType(pop.output_handle(),
|
||||
initial_element_shape,
|
||||
potential_element_type))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) {
|
||||
if (!CanInferTensorListElementType(resize.output_handle(),
|
||||
initial_element_shape,
|
||||
potential_element_type))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
// WhileRegionOp can explicitly capture TensorList value to be used inside
|
||||
// its regions. So we check the uses of corresponding block argument in each
|
||||
// region and the use of TensorList returned using YieldOp.
|
||||
if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) {
|
||||
for (auto branch : while_region.getRegions()) {
|
||||
if (!CanInferTensorListElementType(
|
||||
branch->getArgument(use.getOperandNumber()),
|
||||
initial_element_shape, potential_element_type))
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
|
||||
Operation* parent = yield->getParentOp();
|
||||
if (!CanInferTensorListElementType(
|
||||
parent->getResult(use.getOperandNumber()), initial_element_shape,
|
||||
potential_element_type))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
// Refining the tensor list element type might change the output of
|
||||
// TensorListElementShape which is expected to be the originally assigned
|
||||
// shape to TensorList init ops. So replace it with the original element
|
||||
// shape value.
|
||||
if (auto tl_element_shape =
|
||||
dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
|
||||
// If element types match, we can do a direct replacement.
|
||||
if (getElementTypeOrSelf(tl_element_shape.getResult()) ==
|
||||
getElementTypeOrSelf(initial_element_shape.getType())) {
|
||||
tl_element_shape.replaceAllUsesWith(initial_element_shape);
|
||||
} else {
|
||||
OpBuilder b(use.getOwner());
|
||||
auto cast_op = b.create<TF::CastOp>(
|
||||
use.getOwner()->getLoc(), tl_element_shape.getResult().getType(),
|
||||
initial_element_shape,
|
||||
/*truncate=*/b.getBoolAttr(false));
|
||||
tl_element_shape.replaceAllUsesWith(cast_op.getResult());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// Ignore ops that just consume a TensorList and do not output another
|
||||
// TensorList.
|
||||
if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op,
|
||||
TensorListLengthOp, TensorListGetItemOp>(use.getOwner()))
|
||||
continue;
|
||||
std::stack<Value> worklist;
|
||||
worklist.emplace(tensorlist);
|
||||
|
||||
// For any other unknown users of the TensorList, we are conservative and
|
||||
// stop element shape inference.
|
||||
return false;
|
||||
while (!worklist.empty()) {
|
||||
tensorlist = worklist.top();
|
||||
worklist.pop();
|
||||
|
||||
// TensorLists are semantically immutable. For example, TensorListSetItem
|
||||
// takes a TensorList as input and produces a TensorList as output. So to
|
||||
// traverse modifications to TensorList and verify that all elements written
|
||||
// to it have the same shape, we need to follow use-def chain of ops that
|
||||
// (conceptually) modify it i.e., ops that take an input TensorList and
|
||||
// produce an output TensorList.
|
||||
for (auto& use : tensorlist.getUses()) {
|
||||
if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
|
||||
auto element_type =
|
||||
push.tensor().getType().dyn_cast<RankedTensorType>();
|
||||
if (!verify_and_update_potential_element_type(element_type))
|
||||
return false;
|
||||
worklist.emplace(push.output_handle());
|
||||
continue;
|
||||
}
|
||||
if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>(
|
||||
use.getOwner())) {
|
||||
// For scatter op we can get the element shape by dropping the first
|
||||
// dimension of the input tensor.
|
||||
RankedTensorType element_type =
|
||||
DropFirstDimension(scatter.tensor().getType());
|
||||
if (!verify_and_update_potential_element_type(element_type))
|
||||
return false;
|
||||
worklist.emplace(scatter.output_handle());
|
||||
continue;
|
||||
}
|
||||
if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
|
||||
auto element_type =
|
||||
set_item.item().getType().dyn_cast<RankedTensorType>();
|
||||
DCOMMENT("\tTensorListSetItemOp " << element_type);
|
||||
if (!verify_and_update_potential_element_type(element_type))
|
||||
return false;
|
||||
worklist.emplace(set_item.output_handle());
|
||||
continue;
|
||||
}
|
||||
if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) {
|
||||
worklist.emplace(pop.output_handle());
|
||||
continue;
|
||||
}
|
||||
if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) {
|
||||
worklist.emplace(resize.output_handle());
|
||||
continue;
|
||||
}
|
||||
// WhileRegionOp can explicitly capture TensorList value to be used inside
|
||||
// its regions. So we check the uses of corresponding block argument in
|
||||
// each region and the use of TensorList returned using YieldOp.
|
||||
if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) {
|
||||
DCOMMENT("\tTL WhileRegion");
|
||||
for (auto branch : while_region.getRegions())
|
||||
worklist.emplace(branch->getArgument(use.getOperandNumber()));
|
||||
continue;
|
||||
}
|
||||
if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
|
||||
Operation* parent = yield->getParentOp();
|
||||
worklist.emplace(parent->getResult(use.getOperandNumber()));
|
||||
continue;
|
||||
}
|
||||
// TODO(jpienaar): This can be generalized.
|
||||
if (isa<IdentityOp, IdentityNOp, StopGradientOp>(use.getOwner())) {
|
||||
worklist.emplace(use.getOwner()->getResult(use.getOperandNumber()));
|
||||
continue;
|
||||
}
|
||||
// Refining the tensor list element type might change the output of
|
||||
// TensorListElementShape which is expected to be the originally assigned
|
||||
// shape to TensorList init ops. So replace it with the original element
|
||||
// shape value.
|
||||
if (auto tl_element_shape =
|
||||
dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
|
||||
// If element types match, we can do a direct replacement.
|
||||
if (getElementTypeOrSelf(tl_element_shape.getResult()) ==
|
||||
getElementTypeOrSelf(initial_element_shape.getType())) {
|
||||
tl_element_shape.replaceAllUsesWith(initial_element_shape);
|
||||
} else {
|
||||
OpBuilder b(use.getOwner());
|
||||
auto cast_op = b.create<TF::CastOp>(
|
||||
use.getOwner()->getLoc(), tl_element_shape.getResult().getType(),
|
||||
initial_element_shape,
|
||||
/*truncate=*/b.getBoolAttr(false));
|
||||
tl_element_shape.replaceAllUsesWith(cast_op.getResult());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// Ignore ops that just consume a TensorList and do not output another
|
||||
// TensorList.
|
||||
if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op,
|
||||
TensorListLengthOp, TensorListGetItemOp>(use.getOwner()))
|
||||
continue;
|
||||
|
||||
// For any other unknown users of the TensorList, we are conservative and
|
||||
// stop element shape inference.
|
||||
DCOMMENT("TensorListType infer, unknown op " << *use.getOwner());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -778,8 +781,12 @@ bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
|
||||
if (!element_type || !element_type.hasStaticShape()) return false;
|
||||
}
|
||||
if (!CanInferTensorListElementType(handle, initial_element_shape,
|
||||
&element_type))
|
||||
&element_type)) {
|
||||
DCOMMENT("InferShapeForListInitOps " << op << " could not infer");
|
||||
return false;
|
||||
}
|
||||
DCOMMENT("InferShapeForListInitOps " << op << " could be inferred "
|
||||
<< element_type);
|
||||
if (!element_type || !element_type.hasStaticShape()) return false;
|
||||
auto variant_type = VariantType::get(element_type, op->getContext());
|
||||
auto tensor_type = RankedTensorType::get({}, variant_type);
|
||||
@ -1049,7 +1056,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||
// The shape function of these ops sometimes does not propagate subtypes
|
||||
// (handle shapes) for resource and variant types. We use a simple passthrough
|
||||
// to make sure they are preserved in the output.
|
||||
if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp>(op)) {
|
||||
if (isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp, TF::ZerosLikeOp>(
|
||||
op)) {
|
||||
return RefineTypeForPassThroughOperands(op, op->getOperands(),
|
||||
op->getResults());
|
||||
}
|
||||
|
@ -204,7 +204,7 @@ LogicalResult HandleWhileOp(
|
||||
}
|
||||
auto new_while =
|
||||
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
|
||||
new_while_operands, while_op.getAttrs());
|
||||
new_while_operands, while_op->getAttrs());
|
||||
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
|
||||
if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
|
||||
.isa<TF::ResourceType>()) {
|
||||
@ -257,7 +257,7 @@ LogicalResult HandleIfOp(
|
||||
}
|
||||
auto new_if = OpBuilder(if_op).create<TF::IfOp>(
|
||||
if_op.getLoc(), then_func.getType().getResults(), new_if_operands,
|
||||
if_op.getAttrs());
|
||||
if_op->getAttrs());
|
||||
for (auto result : if_op.getResults()) {
|
||||
if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
|
||||
continue;
|
||||
@ -306,7 +306,7 @@ LogicalResult HandlePartitionedCallOp(
|
||||
OpBuilder builder(call);
|
||||
auto new_call = builder.create<CallOp>(
|
||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||
new_operands, call.getAttrs());
|
||||
new_operands, call->getAttrs());
|
||||
new_call->setAttr(
|
||||
"f", builder.getSymbolRefAttr(
|
||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||
|
@ -625,7 +625,7 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
|
||||
OpBuilder builder(while_op);
|
||||
auto new_while =
|
||||
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
|
||||
operands, while_op.getAttrs());
|
||||
operands, while_op->getAttrs());
|
||||
for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
|
||||
if (ta_arg_buffer_type(i)) {
|
||||
while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
|
||||
@ -692,7 +692,7 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
|
||||
OpBuilder builder(if_op);
|
||||
auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
|
||||
then_branch.getType().getResults(),
|
||||
operands, if_op.getAttrs());
|
||||
operands, if_op->getAttrs());
|
||||
auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
|
||||
auto retval = f.front().getTerminator()->getOperand(ret_ind);
|
||||
auto arg = retval.dyn_cast<BlockArgument>();
|
||||
@ -751,7 +751,7 @@ LogicalResult HandlePartitionedCallOp(
|
||||
OpBuilder builder(call);
|
||||
auto new_call = builder.create<CallOp>(
|
||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||
new_operands, call.getAttrs());
|
||||
new_operands, call->getAttrs());
|
||||
new_call->setAttr(
|
||||
"f", builder.getSymbolRefAttr(
|
||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||
|
@ -208,7 +208,7 @@ LogicalResult HandleWhileOp(
|
||||
}
|
||||
auto new_while =
|
||||
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
|
||||
new_while_operands, while_op.getAttrs());
|
||||
new_while_operands, while_op->getAttrs());
|
||||
for (const auto& entry : output_buffer_to_size) {
|
||||
(*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
|
||||
new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
|
||||
@ -268,7 +268,7 @@ LogicalResult HandleCaseOrIfOp(
|
||||
FuncOp first_branch = branches.front();
|
||||
auto new_op = OpBuilder(op).create<CaseOrIfOp>(
|
||||
op.getLoc(), first_branch.getType().getResults(), new_operands,
|
||||
op.getAttrs());
|
||||
op->getAttrs());
|
||||
for (const auto& entry : output_buffer_to_size) {
|
||||
(*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
|
||||
new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
|
||||
@ -329,7 +329,7 @@ LogicalResult HandleWhileRegionOp(
|
||||
}
|
||||
auto new_while = builder.create<TF::WhileRegionOp>(
|
||||
while_op.getLoc(), body_region.front().getTerminator()->getOperandTypes(),
|
||||
new_while_operands, while_op.getAttrs());
|
||||
new_while_operands, while_op->getAttrs());
|
||||
new_while.body().takeBody(body_region);
|
||||
new_while.cond().takeBody(cond_region);
|
||||
for (const auto& entry : output_buffer_to_size) {
|
||||
@ -369,7 +369,7 @@ LogicalResult HandleIfRegionOp(
|
||||
// Recreate the op.
|
||||
auto new_op = OpBuilder(if_op).create<TF::IfRegionOp>(
|
||||
if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(),
|
||||
if_op.getOperand(), if_op.getAttrs());
|
||||
if_op.getOperand(), if_op->getAttrs());
|
||||
for (const auto& entry : output_buffer_to_size) {
|
||||
(*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
|
||||
new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
|
||||
@ -415,7 +415,7 @@ LogicalResult HandleCaseRegionOp(
|
||||
auto new_op = OpBuilder(case_op).create<TF::CaseRegionOp>(
|
||||
case_op.getLoc(),
|
||||
first_branch->front().getTerminator()->getOperandTypes(),
|
||||
case_op.getOperand(), case_op.getAttrs(), case_op.getNumRegions());
|
||||
case_op.getOperand(), case_op->getAttrs(), case_op.getNumRegions());
|
||||
for (const auto& entry : output_buffer_to_size) {
|
||||
(*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
|
||||
new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
|
||||
@ -456,7 +456,7 @@ LogicalResult HandlePartitionedCallOp(
|
||||
OpBuilder builder(call);
|
||||
auto new_call = builder.create<CallOp>(
|
||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||
new_operands, call.getAttrs());
|
||||
new_operands, call->getAttrs());
|
||||
new_call->setAttr(
|
||||
"f", builder.getSymbolRefAttr(
|
||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||
@ -818,7 +818,7 @@ LogicalResult DecomposeTensorListOpsInternal(
|
||||
decomposed_partitioned_call_callees) {
|
||||
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
|
||||
// TODO(yuanzx): Add a pass to remove identities in device computation.
|
||||
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
|
||||
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp>(&op)) {
|
||||
op.replaceAllUsesWith(op.getOperands());
|
||||
op.erase();
|
||||
} else if (auto list = llvm::dyn_cast<TF::EmptyTensorListOp>(&op)) {
|
||||
|
@@ -761,3 +761,12 @@ func @cluster_oplist(%arg0: tensor<f32>, %arg1: tensor<i32>) -> tensor<i32> {
  ];
}

def VerifySuitableForExportPass : Pass<"tf-verify-for-export", "ModuleOp"> {
  let summary = "Verify module is suitable for export back to TF Graph";
  let description = [{
    Verifies whether all functions in module are of single tf_executor.graph and
    each tf_executor.island in tf_executor.graph only has a single op.
  }];

  let constructor = "TF::CreateVerifySuitableForExportPass()";
}

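The new verifier is registered under the "tf-verify-for-export" argument and constructed via TF::CreateVerifySuitableForExportPass(). A hedged usage sketch of wiring it into a pipeline from C++ (the context and module variables are assumed):

  #include "mlir/Pass/PassManager.h"  // from @llvm-project

  // Run the export verifier before attempting GraphDef export; failure means
  // the module is not in the single-graph, single-op-island form.
  mlir::PassManager pm(&context);
  pm.addPass(mlir::TF::CreateVerifySuitableForExportPass());
  if (mlir::failed(pm.run(module))) {
    // Module is not exportable back to a TF Graph.
  }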
@@ -582,7 +582,7 @@ LogicalResult FormClustersInBlock(
    cluster->setAttrs(
        cluster_metadata->second.getDictionary(cluster.getContext()));
    // Exclude `num_replicas` as cluster should be replicated if necessary.
    cluster.removeAttr(kNumReplicasAttr);
    cluster->removeAttr(kNumReplicasAttr);
  }

  return success();

@@ -225,24 +225,26 @@ void PropagateDevicesToResults(
}

struct TPUDevicePropagation
    : public PassWrapper<TPUDevicePropagation, FunctionPass> {
  void runOnFunction() override;
    : public PassWrapper<TPUDevicePropagation, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

void TPUDevicePropagation::runOnFunction() {
  FuncOp func = getFunction();
  if (!IsSupportedGraph(func)) return;
void TPUDevicePropagation::runOnOperation() {
  ModuleOp m = getOperation();
  m.walk([&](FuncOp func) {
    if (!IsSupportedGraph(func)) return;

  llvm::DenseMap<Value, llvm::StringRef> value_to_device;
  PropagateDevicesFromArguments(func, value_to_device);
  auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
  PropagateDevicesInGraph(graph, value_to_device);
  PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
    llvm::DenseMap<Value, llvm::StringRef> value_to_device;
    PropagateDevicesFromArguments(func, value_to_device);
    auto graph = llvm::cast<tf_executor::GraphOp>(func.front().front());
    PropagateDevicesInGraph(graph, value_to_device);
    PropagateDevicesToResults(func, graph.GetFetch(), value_to_device);
  });
}

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> CreateTPUDevicePropagationPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDevicePropagationPass() {
  return std::make_unique<TPUDevicePropagation>();
}

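Because TPUDevicePropagation now runs on ModuleOp and walks every FuncOp itself, its factory returns an OperationPass<ModuleOp>, and callers add it to the top-level pass manager instead of nesting it per function. A hedged registration sketch (the TFTPU namespace is assumed from the surrounding TPU passes; pass-manager setup not shown):

  // Before this change the pass would typically be nested, e.g.
  //   pm.nest<FuncOp>().addPass(CreateTPUDevicePropagationPass());
  // After this change it is module-scoped:
  pm.addPass(mlir::TFTPU::CreateTPUDevicePropagationPass());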
@ -326,7 +326,7 @@ tf_device::ClusterOp UpdateClusterResults(
|
||||
|
||||
auto new_cluster = builder->create<tf_device::ClusterOp>(
|
||||
cluster.getLoc(), new_cluster_result_types,
|
||||
/*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
|
||||
/*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
|
||||
new_cluster.body().takeBody(cluster.body());
|
||||
|
||||
auto operand_not_in_cluster = [&](OpOperand& operand) {
|
||||
@ -400,7 +400,7 @@ void RemoveClusterAliasedOutputs(OpBuilder* builder,
|
||||
builder->setInsertionPoint(cluster);
|
||||
auto new_cluster = builder->create<tf_device::ClusterOp>(
|
||||
cluster.getLoc(), new_cluster_result_types,
|
||||
/*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
|
||||
/*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
|
||||
new_cluster.body().takeBody(cluster.body());
|
||||
new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);
|
||||
|
||||
|
@ -94,13 +94,13 @@ LogicalResult ReorderReplicateAndPartitionedInputs(
|
||||
for (const auto& operands_per_replica : operands_per_replica_per_core) {
|
||||
auto replicate_op = builder.create<TF::TPUReplicatedInputOp>(
|
||||
replicated_input.getLoc(), replicated_input.getType(),
|
||||
operands_per_replica, replicated_input.getAttrs());
|
||||
operands_per_replica, replicated_input->getAttrs());
|
||||
operands_per_core.push_back(replicate_op);
|
||||
}
|
||||
|
||||
auto pi = builder.create<TF::TPUPartitionedInputOp>(
|
||||
first_partitioned_input.getLoc(), replicated_input.getType(),
|
||||
operands_per_core, first_partitioned_input.getAttrs());
|
||||
operands_per_core, first_partitioned_input->getAttrs());
|
||||
replicated_input.replaceAllUsesWith(pi.output());
|
||||
return success();
|
||||
}
|
||||
|
@ -0,0 +1,43 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
class VerifySuitableForExportPass
|
||||
: public VerifySuitableForExportPassBase<VerifySuitableForExportPass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
if (failed(tensorflow::VerifyExportSuitable(getOperation())))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateVerifySuitableForExportPass() {
|
||||
return std::make_unique<VerifySuitableForExportPass>();
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
@ -46,8 +47,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
|
||||
#include "tensorflow/compiler/mlir/utils/name_utils.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -74,9 +77,6 @@ using stream_executor::port::StatusOr;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kInvalidExecutorGraphMsg[] =
|
||||
"Functions must be of a single Graph with single op Islands: ";
|
||||
|
||||
constexpr char kDeviceAttr[] = "tf.device";
|
||||
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
|
||||
|
||||
@ -91,52 +91,6 @@ class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
|
||||
}
|
||||
};
|
||||
|
||||
// Checks functions in module are of single tf_executor.graph and each
|
||||
// tf_executor.island in tf_executor.graph only has a single op.
|
||||
Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) {
|
||||
Status status = Status::OK();
|
||||
module.walk([&](mlir::FuncOp function) {
|
||||
if (!llvm::hasSingleElement(function)) {
|
||||
status = errors::FailedPrecondition(
|
||||
kInvalidExecutorGraphMsg,
|
||||
"only single block functions are supported.");
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
|
||||
auto block = function.front().without_terminator();
|
||||
auto graph = llvm::dyn_cast<mlir::tf_executor::GraphOp>(block.begin());
|
||||
if (!graph) {
|
||||
status = errors::FailedPrecondition(
|
||||
kInvalidExecutorGraphMsg,
|
||||
"first op in function is not a tf_executor.graph.");
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
|
||||
if (!hasSingleElement(block)) {
|
||||
status = errors::FailedPrecondition(
|
||||
kInvalidExecutorGraphMsg,
|
||||
"function does not only contain a single tf_executor.graph.");
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
|
||||
for (Operation& op : graph.GetBody()) {
|
||||
auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
|
||||
if (!island) continue;
|
||||
|
||||
if (!island.WrapsSingleOp()) {
|
||||
status = errors::FailedPrecondition(
|
||||
kInvalidExecutorGraphMsg,
|
||||
"tf_executor.island must perfectly wrap a single op.");
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
return mlir::WalkResult::advance();
|
||||
});
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Finds first inner op if `op` is a tf_executor.island. Otherwise `op` is
|
||||
// returned.
|
||||
Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) {
|
||||
@ -159,10 +113,10 @@ class Exporter {
|
||||
|
||||
// Converts a given FuncOp to a FunctionDef and adds it to the function
|
||||
// definition library
|
||||
static Status ConvertLibFunction(const GraphExportConfig& configs,
|
||||
const Dialect* tf_dialect,
|
||||
mlir::FuncOp function,
|
||||
FunctionDefLibrary* flib);
|
||||
static Status ConvertLibFunction(
|
||||
const GraphExportConfig& configs, const Dialect* tf_dialect,
|
||||
mlir::FuncOp function, FunctionDefLibrary* flib,
|
||||
llvm::SmallDenseSet<mlir::FuncOp>& visited_functions);
|
||||
// Converts the given FuncOp to a Graph. The arguments and returns of
|
||||
// function are added to the graph with special op names kArgOp and kRetOp.
|
||||
// Later on, this graph can be converted a function definition and added to
|
||||
@ -170,6 +124,7 @@ class Exporter {
|
||||
static StatusOr<std::unique_ptr<Graph>> Convert(
|
||||
const GraphExportConfig& configs, const Dialect* tf_dialect,
|
||||
mlir::FuncOp function, FunctionDefLibrary* flib,
|
||||
llvm::SmallDenseSet<mlir::FuncOp>& visited_functions,
|
||||
absl::flat_hash_set<Node*>* control_ret_nodes);
|
||||
|
||||
private:
|
||||
@ -451,6 +406,7 @@ Status Exporter::GetControlRetNodes(
|
||||
StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
const GraphExportConfig& configs, const Dialect* tf_dialect,
|
||||
mlir::FuncOp function, FunctionDefLibrary* flib,
|
||||
llvm::SmallDenseSet<mlir::FuncOp>& visited_functions,
|
||||
absl::flat_hash_set<Node*>* control_ret_nodes) {
|
||||
mlir::Block& block = function.front();
|
||||
|
||||
@ -550,7 +506,8 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
function->getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
|
||||
name);
|
||||
if (func != nullptr) {
|
||||
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
|
||||
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib,
|
||||
visited_functions));
|
||||
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
|
||||
}
|
||||
return Status::OK();
|
||||
@ -610,22 +567,21 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
return graph;
|
||||
}
|
||||
|
||||
Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
|
||||
const Dialect* tf_dialect,
|
||||
mlir::FuncOp function,
|
||||
FunctionDefLibrary* flib) {
|
||||
// First look for the function in the current function library. If found,
|
||||
// nothing needs to be done.
|
||||
OpRegistry empty_registry;
|
||||
FunctionLibraryDefinition flib_def(&empty_registry, *flib);
|
||||
Status Exporter::ConvertLibFunction(
|
||||
const GraphExportConfig& configs, const Dialect* tf_dialect,
|
||||
mlir::FuncOp function, FunctionDefLibrary* flib,
|
||||
llvm::SmallDenseSet<mlir::FuncOp>& visited_functions) {
|
||||
// Return early if the function has already been exported.
|
||||
bool is_new_function = visited_functions.insert(function).second;
|
||||
if (!is_new_function) return Status::OK();
|
||||
|
||||
auto function_name = function.getName().str();
|
||||
if (flib_def.Find(function_name)) return Status::OK();
|
||||
|
||||
// TODO(fengliuai): use a small flib_def to reduce overhead
|
||||
absl::flat_hash_set<Node*> control_ret_nodes;
|
||||
TF_ASSIGN_OR_RETURN(auto sub_graph,
|
||||
Exporter::Convert(configs, tf_dialect, function, flib,
|
||||
&control_ret_nodes));
|
||||
visited_functions, &control_ret_nodes));
|
||||
const auto control_ret = [&](const Node* n) -> absl::optional<string> {
|
||||
return control_ret_nodes.contains(n)
|
||||
? absl::make_optional<string>(n->name())
|
||||
@ -652,8 +608,8 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
|
||||
auto grad_func =
|
||||
function->getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
|
||||
attr.getValue());
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertLibFunction(configs, tf_dialect, grad_func, flib));
|
||||
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, grad_func, flib,
|
||||
visited_functions));
|
||||
GradientDef grad;
|
||||
grad.set_function_name(function_name);
|
||||
grad.set_gradient_func(grad_func.getName().str());
|
||||
@ -709,26 +665,32 @@ Status Exporter::Convert(mlir::ModuleOp module,
|
||||
mlir::Identifier::get("main", module.getContext());
|
||||
absl::optional<mlir::FuncOp> entry_func;
|
||||
FunctionDefLibrary flib;
|
||||
llvm::SmallDenseSet<mlir::FuncOp> visited_functions;
|
||||
auto tf_dialect = module.getContext()->getLoadedDialect("tf");
|
||||
for (auto function : module.getOps<mlir::FuncOp>()) {
|
||||
if (function.isExternal())
|
||||
return errors::FailedPrecondition("External functions not supported");
|
||||
|
||||
if (function.getName() == entry_func_id) {
|
||||
if (function.getName() == entry_func_id &&
|
||||
!configs.export_entry_func_to_flib) {
|
||||
entry_func.emplace(function);
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertLibFunction(configs, tf_dialect, function, &flib));
|
||||
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, function,
|
||||
&flib, visited_functions));
|
||||
}
|
||||
}
|
||||
|
||||
if (!entry_func.has_value())
|
||||
return errors::FailedPrecondition("entry function `main` must be present");
|
||||
if (!configs.export_entry_func_to_flib) {
|
||||
if (!entry_func.has_value())
|
||||
return errors::FailedPrecondition(
|
||||
"entry function `main` must be present");
|
||||
|
||||
// Updates the graph and the function library definition.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*graph, Exporter::Convert(configs, tf_dialect, entry_func.value(),
|
||||
&flib, visited_functions, control_ret_nodes));
|
||||
}
|
||||
|
||||
// Updates the graph and the function library definition.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib,
|
||||
control_ret_nodes));
|
||||
for (auto& func_def : flib.function()) {
|
||||
TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
|
||||
}
|
||||
@ -744,8 +706,10 @@ Status ConvertMlirToGraph(mlir::ModuleOp module,
|
||||
std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def,
|
||||
absl::flat_hash_set<Node*>* control_ret_nodes) {
|
||||
TF_RETURN_IF_ERROR(HasSingleGraphSingleOpIslandsFunctions(module));
|
||||
return Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes);
|
||||
mlir::StatusScopedDiagnosticHandler sh(module.getContext());
|
||||
if (failed(VerifyExportSuitable(module))) return sh.ConsumeStatus();
|
||||
return sh.Combine(
|
||||
Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes));
|
||||
}
|
||||
|
||||
Status ConvertMlirToGraph(mlir::ModuleOp module,
|
||||
@ -761,8 +725,16 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
|
||||
mlir::ModuleOp module, const GraphExportConfig& configs) {
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(),
|
||||
FunctionDefLibrary());
|
||||
auto graph = absl::make_unique<Graph>(flib_def);
|
||||
std::unique_ptr<Graph> graph;
|
||||
TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));
|
||||
|
||||
// If the entry function is exported to flib, then no graph is constructed.
|
||||
// Construct one in that case.
|
||||
if (configs.export_entry_func_to_flib) {
|
||||
graph = std::make_unique<Graph>(OpRegistry::Global());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib_def.ToProto()));
|
||||
|
||||
auto graphdef = absl::make_unique<GraphDef>();
|
||||
graph->ToGraphDef(graphdef.get());
|
||||
if (!configs.export_library) graphdef->clear_library();
|
||||
@ -784,8 +756,9 @@ stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
|
||||
FunctionDef* function_def) {
|
||||
Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
|
||||
FunctionDefLibrary flib;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib));
|
||||
llvm::SmallDenseSet<mlir::FuncOp> visited_functions;
|
||||
TF_RETURN_IF_ERROR(Exporter::ConvertLibFunction(configs, tf_dialect, func,
|
||||
&flib, visited_functions));
|
||||
for (auto& func_def : flib.function()) {
|
||||
if (func_def.signature().name() == func.getName()) {
|
||||
*function_def = func_def;
|
||||
|
@ -113,6 +113,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/protobuf/saver.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
||||
@@ -1291,17 +1292,23 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
    if (name_and_value.first == "_input_shapes") {
      auto& list = name_and_value.second.list();
      auto& signature = func_def->signature();
      if (list.shape_size() != signature.input_arg_size()) {
      // Some models have "_input_shapes" attribute, but with its value empty
      if (list.shape_size() > 0 &&
          list.shape_size() != signature.input_arg_size()) {
        return errors::FailedPrecondition(
            "Number of input arguments must be equal to the length of "
            "_input_shapes attribute in function '",
            StringRefToView(func_name), "'.");
      }
      for (int i = 0; i < list.shape_size(); i++) {
      for (int i = 0; i < signature.input_arg_size(); i++) {
        auto& input_arg = signature.input_arg(i);
        auto& array_info = specs.inputs[input_arg.name()];
        array_info.imported_dtype = input_arg.type();
        array_info.shape = list.shape(i);
        // set to unranked for empty "_input_shapes" attribute
        if (list.shape_size() > 0)
          array_info.shape = list.shape(i);
        else
          array_info.shape.set_unknown_rank(true);
      }
    }
  }
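With this change an empty "_input_shapes" attribute no longer trips the arity check; each argument simply imports with an unranked shape. A hedged illustration of what that fallback produces (standalone snippet, not importer code):

  #include "tensorflow/core/framework/tensor_shape.pb.h"

  // An unranked TensorShapeProto corresponds to tensor<*xT> on the MLIR side.
  tensorflow::TensorShapeProto shape;
  shape.set_unknown_rank(true);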
@ -1661,7 +1668,7 @@ mlir::Location ImporterBase::GetLocation(const Node& node) {
|
||||
|
||||
// If there are no locations in the stack trace, fall back to just a
|
||||
// NameLoc with no child.
|
||||
if (locations.empty()) return mlir::NameLoc::get(name_loc_id, context_);
|
||||
if (locations.empty()) return mlir::NameLoc::get(name_loc_id);
|
||||
|
||||
// Use the front FileLineColLoc to generate a NameLoc.
|
||||
mlir::Location node_name_loc =
|
||||
@@ -1984,8 +1991,28 @@ Status ImporterBase::ConvertNode(const Node& node) {
  }

  const auto& node_def = node.def();
  // NodeDef can contain partial TF device names. In such cases, canonicalize
  // it. Note that in current TF, placer will place full device name to each
  // node.
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(node_def.device(), &parsed_name)) {
    return errors::InvalidArgument(
        "Op ", op_name, " has invalid device name: ", node_def.device());
  }
  // Keep the parsed name untouched if the device name is empty.
  if (!node_def.device().empty()) {
    if (!parsed_name.has_type) {
      parsed_name.type = "CPU";
      parsed_name.has_type = true;
    }
    if (!parsed_name.has_id) {
      parsed_name.id = 0;
      parsed_name.has_id = true;
    }
  }
  result.attributes.push_back(builder_.getNamedAttr(
      "device", builder_.getStringAttr(std::string(node_def.device()))));
      "device", builder_.getStringAttr(
                    DeviceNameUtils::ParsedNameToString(parsed_name))));

  // Map user function calls to LegacyCall ops and add the user function name
  // as an attribute.
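A hedged example of the canonicalization performed above, using an illustrative partial device string (the input is not taken from the change itself):

  #include "tensorflow/core/util/device_name_utils.h"

  tensorflow::DeviceNameUtils::ParsedName parsed;
  // Partial names parse successfully; missing fields are reported via has_*.
  if (tensorflow::DeviceNameUtils::ParseFullName(
          "/job:localhost/replica:0/task:0", &parsed)) {
    if (!parsed.has_type) { parsed.type = "CPU"; parsed.has_type = true; }
    if (!parsed.has_id) { parsed.id = 0; parsed.has_id = true; }
    // Yields "/job:localhost/replica:0/task:0/device:CPU:0".
    std::string full = tensorflow::DeviceNameUtils::ParsedNameToString(parsed);
  }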
@ -2165,7 +2192,8 @@ class GraphDefImporter : public ImporterBase {
|
||||
mlir::MLIRContext* context, const Graph& graph,
|
||||
const GraphDebugInfo& debug_info,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
|
||||
llvm::StringRef func_name);
|
||||
llvm::StringRef func_name,
|
||||
std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);
|
||||
|
||||
private:
|
||||
explicit GraphDefImporter(
|
||||
@ -2206,11 +2234,11 @@ class GraphDefImporter : public ImporterBase {
|
||||
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
mlir::MLIRContext* context, const Graph& graph,
|
||||
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
|
||||
const GraphImportConfig& specs, llvm::StringRef func_name) {
|
||||
const GraphImportConfig& specs, llvm::StringRef func_name,
|
||||
std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
|
||||
LoadImporterDialects(*context);
|
||||
mlir::OwningModuleRef module =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||
NameUniquifier function_name_uniquifier(flib_def);
|
||||
|
||||
GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
|
||||
@ -3499,12 +3527,14 @@ class SavedModelSignatureDefImporterLite {
|
||||
const std::string& name,
|
||||
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
|
||||
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
|
||||
const std::vector<std::string> control_outputs);
|
||||
const std::vector<std::string> control_outputs,
|
||||
std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);
|
||||
|
||||
// Moves the functions in `sub_module` to `module_` and skips the duplicate
|
||||
// functions.
|
||||
Status MoveConvertedFunctionsToModule(absl::string_view name,
|
||||
mlir::ModuleOp sub_module);
|
||||
Status MoveConvertedFunctionsToModule(
|
||||
absl::string_view name, mlir::ModuleOp sub_module,
|
||||
const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name);
|
||||
|
||||
GraphImportConfig::InputArrays ParseInputArrays(
|
||||
llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);
|
||||
@ -3545,15 +3575,25 @@ SavedModelSignatureDefImporterLite::ConvertAssets() {
|
||||
}
|
||||
|
||||
Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
|
||||
absl::string_view name, mlir::ModuleOp sub_module) {
|
||||
absl::string_view name, mlir::ModuleOp sub_module,
|
||||
const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
|
||||
mlir::Builder builder(sub_module.getContext());
|
||||
mlir::SymbolTable sub_module_symbol_table(sub_module);
|
||||
|
||||
// Functions originally from graphdef library might have a different name
|
||||
// after conversion, we build the set of the converted names
|
||||
absl::flat_hash_set<std::string> original_func_mlir_names;
|
||||
for (const auto& kv : tf_name_to_mlir_name)
|
||||
original_func_mlir_names.insert(kv.second);
|
||||
|
||||
// Prefix private functions with the unique signature name, so that it cannot
|
||||
// collide with private functions used in the other signatures.
|
||||
for (auto func : sub_module.getOps<mlir::FuncOp>()) {
|
||||
if (mlir::tf_saved_model::IsExported(func)) continue;
|
||||
|
||||
// Skip the original functions from graphdef library
|
||||
if (original_func_mlir_names.count(func.sym_name().str())) continue;
|
||||
|
||||
std::string new_sym_name = absl::StrCat(name, "/", func.sym_name().str());
|
||||
if (mlir::failed(sub_module_symbol_table.replaceAllSymbolUses(
|
||||
func, new_sym_name, sub_module)))
|
||||
@ -3567,7 +3607,7 @@ Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
|
||||
|
||||
// Copy all functions used by this signature to the final MLIR module.
|
||||
for (auto func : sub_module.getOps<mlir::FuncOp>()) {
|
||||
DCHECK(symbol_table_.lookup(func.sym_name()) == nullptr);
|
||||
// The insert here is a NO-OP if the function already exists.
|
||||
symbol_table_.insert(func.clone());
|
||||
}
|
||||
|
||||
@ -3586,13 +3626,15 @@ Status SavedModelSignatureDefImporterLite::ConvertInitializer(
|
||||
inputs.push_back({asset.tensor_name, tensor_info});
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto sub_module, ConvertGraph(target_node_name, inputs,
|
||||
{}, {target_node_name}));
|
||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||
TF_ASSIGN_OR_RETURN(auto sub_module,
|
||||
ConvertGraph(target_node_name, inputs, {},
|
||||
{target_node_name}, tf_name_to_mlir_name));
|
||||
|
||||
mlir::SymbolTable sub_symbol_table(*sub_module);
|
||||
|
||||
auto init_func_op = sub_symbol_table.lookup<mlir::FuncOp>(target_node_name);
|
||||
init_func_op.removeAttr("tf.entry_function");
|
||||
init_func_op->removeAttr("tf.entry_function");
|
||||
|
||||
mlir::OpBuilder builder(module_->getBodyRegion());
|
||||
|
||||
@ -3612,7 +3654,8 @@ Status SavedModelSignatureDefImporterLite::ConvertInitializer(
|
||||
"__tf_saved_model_session_initializer_", target_node_name)}));
|
||||
|
||||
// Move the converted functions to top level MLIR module.
|
||||
return MoveConvertedFunctionsToModule(target_node_name, *sub_module);
|
||||
return MoveConvertedFunctionsToModule(target_node_name, *sub_module,
|
||||
tf_name_to_mlir_name);
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef>
|
||||
@ -3620,7 +3663,8 @@ SavedModelSignatureDefImporterLite::ConvertGraph(
|
||||
const std::string& name,
|
||||
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
|
||||
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
|
||||
const std::vector<std::string> control_outputs) {
|
||||
const std::vector<std::string> control_outputs,
|
||||
std::unordered_map<std::string, std::string>& tf_name_to_mlir_name) {
|
||||
VLOG(1) << "Importing Signature: " << name;
|
||||
|
||||
GraphImportConfig specs;
|
||||
@ -3634,7 +3678,7 @@ SavedModelSignatureDefImporterLite::ConvertGraph(
|
||||
// Convert sub-graph to MLIR module.
|
||||
return GraphDefImporter::Convert(module_->getContext(), *subgraph,
|
||||
input_.debug_info(), subgraph->flib_def(),
|
||||
specs, name);
|
||||
specs, name, tf_name_to_mlir_name);
|
||||
}
|
||||
|
||||
Status SavedModelSignatureDefImporterLite::ConvertSignature(
|
||||
@ -3654,9 +3698,12 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature(
|
||||
return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
|
||||
});
|
||||
|
||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||
|
||||
// Convert sub-graph to MLIR module.
|
||||
TF_ASSIGN_OR_RETURN(auto sub_module,
|
||||
ConvertGraph(sig_def_key, inputs, outputs, {}));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sub_module,
|
||||
ConvertGraph(sig_def_key, inputs, outputs, {}, tf_name_to_mlir_name));
|
||||
mlir::OpBuilder builder(sub_module->getBodyRegion());
|
||||
|
||||
// Find the FuncOp which corresponds to current SignatureDef.
|
||||
@ -3682,7 +3729,8 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature(
|
||||
}
|
||||
|
||||
// Move the converted functions to top level MLIR module.
|
||||
return MoveConvertedFunctionsToModule(sig_def_key, *sub_module);
|
||||
return MoveConvertedFunctionsToModule(sig_def_key, *sub_module,
|
||||
tf_name_to_mlir_name);
|
||||
}
|
||||
|
||||
GraphImportConfig::InputArrays
|
||||
@ -3820,7 +3868,7 @@ class SavedModelSignatureDefImporter {
|
||||
(*module)->setAttr("tf_saved_model.under_construction",
|
||||
builder.getUnitAttr());
|
||||
TF_RETURN_IF_ERROR(LiftVariables(bundle, *module));
|
||||
module->removeAttr("tf_saved_model.under_construction");
|
||||
(*module)->removeAttr("tf_saved_model.under_construction");
|
||||
|
||||
return module;
|
||||
}
|
||||
@ -3895,8 +3943,9 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
|
||||
const_cast<FunctionLibraryDefinition*>(&flib_def),
|
||||
specs.restrict_functionalization_to_tpu_nodes));
|
||||
}
|
||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||
return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
|
||||
/*func_name=*/"main");
|
||||
/*func_name=*/"main", tf_name_to_mlir_name);
|
||||
}
|
||||
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
|
||||
@ -3908,9 +3957,10 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
|
||||
specs.graph_as_function = true;
|
||||
for (const auto* control_ret_node : fbody->control_ret_nodes)
|
||||
specs.control_outputs.push_back(control_ret_node->name());
|
||||
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
|
||||
flib_def, specs,
|
||||
fbody->fdef.signature().name());
|
||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||
return GraphDefImporter::Convert(
|
||||
context, *fbody->graph, dummy_debug_info, flib_def, specs,
|
||||
fbody->fdef.signature().name(), tf_name_to_mlir_name);
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
|
@@ -75,6 +75,9 @@ struct GraphExportConfig {
  bool export_library = true;
  // Whether to export debug original node name in the GraphDef.
  bool export_debug_info = true;
  // Whether to export the entry function to function library instead of the
  // graph.
  bool export_entry_func_to_flib = false;
};

// Parses the command line flag strings to the specification of nodes in

@@ -21,6 +21,7 @@ limitations under the License.

using llvm::cl::opt;

// Import options.
// NOLINTNEXTLINE
opt<std::string> input_arrays(
    "tf-input-arrays", llvm::cl::desc("Input tensor names, separated by ','"),
@@ -115,3 +116,11 @@ opt<bool> enable_shape_inference(
    "tf-enable-shape-inference-on-import",
    llvm::cl::desc("Enable shape inference on import (temporary)"),
    llvm::cl::init(false));

// Export options.
// NOLINTNEXTLINE
opt<bool> export_entry_func_to_flib(
    "tf-export-entry-func-to-flib",
    llvm::cl::desc(
        "Export entry function to function library instead of graph"),
    llvm::cl::init(false));

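A hedged usage sketch of the new export option from C++ (the module variable is assumed; the translate tool below sets the field from the command-line flag in the same way):

  tensorflow::GraphExportConfig confs;
  // Export the entry function into the FunctionDefLibrary instead of lowering
  // it to the top-level graph.
  confs.export_entry_func_to_flib = true;
  auto graphdef_or = tensorflow::ConvertMlirToGraphdef(module, confs);

On the command line the same behavior is exposed through the tf-export-entry-func-to-flib flag defined above.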
@ -26,6 +26,7 @@ limitations under the License.
|
||||
|
||||
// Please see the implementation file for documentation of these options.
|
||||
|
||||
// Import options.
|
||||
extern llvm::cl::opt<std::string> input_arrays;
|
||||
extern llvm::cl::opt<std::string> input_dtypes;
|
||||
extern llvm::cl::opt<std::string> input_shapes;
|
||||
@ -42,4 +43,7 @@ extern llvm::cl::opt<bool> upgrade_legacy;
|
||||
// TODO(jpienaar): Temporary flag, flip default and remove.
|
||||
extern llvm::cl::opt<bool> enable_shape_inference;
|
||||
|
||||
// Export options.
|
||||
extern llvm::cl::opt<bool> export_entry_func_to_flib;
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_CL_H_
|
||||
|
@ -75,6 +75,7 @@ static LogicalResult MlirToGraphdefTranslateFunction(
|
||||
|
||||
// TODO(fengliuai): Add exporter flags.
|
||||
tensorflow::GraphExportConfig confs;
|
||||
confs.export_entry_func_to_flib = export_entry_func_to_flib;
|
||||
StatusOr<std::unique_ptr<tensorflow::GraphDef>> graphdef_or(
|
||||
tensorflow::ConvertMlirToGraphdef(module, confs));
|
||||
if (!graphdef_or.status().ok()) {
|
||||
|
@@ -100,15 +100,15 @@ struct WritableFileRawStream : public llvm::raw_ostream {

struct CrashReproducerStream : public mlir::PassManager::ReproducerStream {
  CrashReproducerStream(llvm::StringRef name,
                        std::unique_ptr<WritableFile> file)
                        std::unique_ptr<llvm::raw_ostream> file)
      : name(name), ostream(std::move(file)) {}

  llvm::StringRef description() override { return name; }
  raw_ostream& os() override { return ostream; }
  raw_ostream& os() override { return *ostream; }

 private:
  std::string name;
  WritableFileRawStream ostream;
  std::unique_ptr<llvm::raw_ostream> ostream;
};
}  // namespace

@ -225,25 +225,32 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) {
|
||||
}
|
||||
}
|
||||
|
||||
auto* env = tensorflow::Env::Default();
|
||||
auto status = env->RecursivelyCreateDir(path);
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING) << "cannot create directory '" + path +
|
||||
"': " + status.error_message();
|
||||
return;
|
||||
}
|
||||
if (path != "-") {
|
||||
auto* env = tensorflow::Env::Default();
|
||||
auto status = env->RecursivelyCreateDir(path);
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING) << "cannot create directory '" + path +
|
||||
"': " + status.error_message();
|
||||
return;
|
||||
}
|
||||
|
||||
path += "/mlir_reproducer_";
|
||||
path += "/mlir_reproducer_";
|
||||
|
||||
if (!tensorflow::Env::Default()->CreateUniqueFileName(&path, ".mlir")) {
|
||||
LOG(WARNING)
|
||||
<< "cannot create unique filename, won't enable MLIR crash reproducer.";
|
||||
return;
|
||||
if (!tensorflow::Env::Default()->CreateUniqueFileName(&path, ".mlir")) {
|
||||
LOG(WARNING) << "cannot create unique filename, won't enable MLIR crash "
|
||||
"reproducer.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
mlir::PassManager::ReproducerStreamFactory factory =
|
||||
[path](std::string& error)
|
||||
-> std::unique_ptr<mlir::PassManager::ReproducerStream> {
|
||||
// Use the stderr stream.
|
||||
if (path == "-")
|
||||
return std::make_unique<CrashReproducerStream>(
|
||||
"(stderr)", std::make_unique<LogInfoRawStream>());
|
||||
|
||||
// Try to open the file and generate a raw_ostream.
|
||||
std::unique_ptr<WritableFile> file;
|
||||
Status status = tensorflow::Env::Default()->NewWritableFile(path, &file);
|
||||
@ -252,7 +259,8 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) {
|
||||
"': ", status.error_message());
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_unique<CrashReproducerStream>(path, std::move(file));
|
||||
return std::make_unique<CrashReproducerStream>(
|
||||
path, std::make_unique<WritableFileRawStream>(std::move(file)));
|
||||
};
|
||||
pm.enableCrashReproducerGeneration(factory, /*genLocalReproducer=*/false);
|
||||
}
|
||||
|
@ -177,7 +177,7 @@ LogicalResult InferReturnTypeComponentsForTFOp(
|
||||
OpResultAsShapeFn op_result_as_shape_fn,
|
||||
ResultElementTypeFn result_element_type_fn,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferred_return_shapes) {
|
||||
assert(op->getName().getDialect() ==
|
||||
assert(op->getName().getDialectNamespace() ==
|
||||
TensorFlowDialect::getDialectNamespace());
|
||||
|
||||
auto op_name_or =
|
||||
|
@ -0,0 +1,68 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
|
||||
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr char kInvalidExecutorGraphMsg[] =
|
||||
"functions must be of a single Graph with single op Islands: ";
|
||||
|
||||
} // namespace
|
||||
|
||||
mlir::LogicalResult VerifyExportSuitable(mlir::ModuleOp module) {
|
||||
mlir::WalkResult result = module.walk([&](mlir::FuncOp function) {
|
||||
if (!llvm::hasSingleElement(function)) {
|
||||
function.emitError(kInvalidExecutorGraphMsg)
|
||||
<< "only single block functions are supported";
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
|
||||
auto block = function.front().without_terminator();
|
||||
auto graph = llvm::dyn_cast<mlir::tf_executor::GraphOp>(block.begin());
|
||||
if (!graph) {
|
||||
block.begin()->emitError(kInvalidExecutorGraphMsg)
|
||||
<< "first op in function is not a tf_executor.graph";
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
|
||||
if (!hasSingleElement(block)) {
|
||||
function.emitError(kInvalidExecutorGraphMsg)
|
||||
<< "function does not only contain a single tf_executor.graph";
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
|
||||
for (mlir::Operation& op : graph.GetBody()) {
|
||||
auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
|
||||
if (!island) continue;
|
||||
|
||||
if (!island.WrapsSingleOp()) {
|
||||
island.emitError(kInvalidExecutorGraphMsg)
|
||||
<< "tf_executor.island must perfectly wrap a single op";
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
return mlir::WalkResult::advance();
|
||||
});
|
||||
|
||||
return mlir::failure(result.wasInterrupted());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Returns whether all functions in module are of single tf_executor.graph and
|
||||
// each tf_executor.island in tf_executor.graph only has a single op.
|
||||
mlir::LogicalResult VerifyExportSuitable(mlir::ModuleOp module);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_
|
@ -22,10 +22,10 @@ from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
|
||||
from absl import app
|
||||
from tensorflow.compiler.mlir.tfr.python.composite import Composite
|
||||
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
|
||||
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
@ -22,13 +22,13 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
from absl import app
|
||||
|
||||
from tensorflow.compiler.mlir.tfr.python import composite
|
||||
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
|
||||
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import gen_array_ops as array_ops
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags
|
||||
|
||||
|
||||
|
@ -22,6 +22,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
from absl import app
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@ -31,7 +32,6 @@ from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags
|
||||
|
||||
Composite = composite.Composite
|
||||
|
@ -23,13 +23,13 @@ from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
|
||||
from absl import app
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.compiler.mlir.tfr.python import composite
|
||||
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
|
||||
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags
|
||||
|
||||
Composite = composite.Composite
|
||||
|
@@ -441,7 +441,7 @@ def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">,
    // non-derived ones.
    llvm::StringSet<> getDefinedAttributeNames() {
      llvm::StringSet<> all_attrs;
      for (auto& attr : getAttrs()) {
      for (auto& attr : (*this)->getAttrs()) {
        all_attrs.insert(attr.first.strref());
      }
      for (const auto& operand : llvm::enumerate(getType().getInputs())) {

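This is the same OpState-to-Operation accessor migration applied throughout the commit: attribute access moves from op.getAttrs() to op->getAttrs() (and, inside an op's own extraClassDeclaration, to (*this)->getAttrs()). A hedged sketch of the pattern with a hypothetical helper and source/destination pair:

  #include "mlir/IR/Operation.h"  // from @llvm-project

  // Copy attributes between ops using the Operation accessor; the OpState form
  // `src.getAttrs()` is what this commit removes from call sites.
  void CopyAttributes(mlir::Operation* src, mlir::Operation* dst) {
    for (mlir::NamedAttribute attr : src->getAttrs())
      dst->setAttr(attr.first, attr.second);
  }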