Merge branch 'master' into interface_16x8
commit fe806513ad

.bazelrc (29 lines changed)
@@ -356,7 +356,15 @@ build:rbe_cpu_linux --extra_execution_platforms="@org_tensorflow//third_party/to
 build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
 build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"

-build:rbe_linux_cuda_nvcc --config=rbe_linux
+build:rbe_linux_cuda_base --config=rbe_linux
+build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
+build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10
+build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7
+build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
+build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
+test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+
+build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
 build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
 build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
 build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
@@ -365,13 +373,20 @@ build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda1
 build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
 build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
 build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
-build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
-build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
-build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
-build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1
-build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_CUDA=1
 build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
-test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
+
+build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
+build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
+build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
+build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
+build:rbe_linux_cuda_clang --define=using_cuda_clang=true
+test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base

 common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
WORKSPACE
@@ -138,3 +138,7 @@ load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")

 bazel_version_repository(name = "bazel_version")
+
+load("//third_party/googleapis:repository_rules.bzl", "config_googleapis")
+
+config_googleapis()
configure.py
@@ -1155,7 +1155,7 @@ def set_trisycl_include_dir(environ_cp):
   write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)


-def system_specific_test_config(env):
+def system_specific_test_config(environ_cp):
   """Add default build and test flags required for TF tests to bazelrc."""
   write_to_bazelrc('test --flaky_test_attempts=3')
   write_to_bazelrc('test --test_size_filters=small,medium')
@@ -1171,14 +1171,14 @@ def system_specific_test_config(env):
   test_only_filters = ['-oss_serial']
   if is_windows():
     test_and_build_filters.append('-no_windows')
-    if env.get('TF_NEED_CUDA', None) == '1':
+    if environ_cp.get('TF_NEED_CUDA', None) == '1':
      test_and_build_filters += ['-no_windows_gpu', '-no_gpu']
    else:
      test_and_build_filters.append('-gpu')
  elif is_macos():
    test_and_build_filters += ['-gpu', '-nomac', '-no_mac']
  elif is_linux():
-    if env.get('TF_NEED_CUDA', None) == '1':
+    if environ_cp.get('TF_NEED_CUDA', None) == '1':
      test_and_build_filters.append('-no_gpu')
      write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
    else:
@@ -1522,7 +1522,7 @@ def main():
   create_android_ndk_rule(environ_cp)
   create_android_sdk_rule(environ_cp)

-  system_specific_test_config(os.environ)
+  system_specific_test_config(environ_cp)

   set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
   if environ_cp.get('TF_CONFIGURE_IOS') == '1':
tensorflow/BUILD
@@ -702,6 +702,7 @@ tf_cc_shared_object(
         "//tensorflow/c:exported_symbols.lds",
         "//tensorflow/c:version_script.lds",
         "//tensorflow/c/eager:c_api",
+        "//tensorflow/c/eager:c_api_experimental",
         "//tensorflow/core:tensorflow",
         "//tensorflow/core/distributed_runtime/rpc:grpc_session",
     ],
tensorflow/c/eager/c_api.cc
@@ -651,16 +651,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       grpc_server->worker_env()->session_mgr->UpdateSession(
           session_name, server_def, base_request.cluster_device_attributes(),
           true));
-  TF_RETURN_IF_ERROR(
-      grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
-          session_name, &worker_session));
-  tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
-      tensorflow::eager::CreateClusterFLR(context_id, context,
-                                          worker_session.get());
   LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
       grpc_server->worker_env(), std::move(remote_eager_workers),
-      added_workers, removed_workers, context_id, r, device_mgr,
-      keep_alive_secs, cluster_flr));
+      added_workers, removed_workers, context_id, r));
 }
 #undef LOG_AND_RETURN_IF_ERROR
tensorflow/c/eager/dlpack.cc
@@ -15,7 +15,7 @@ limitations under the License.

 #include "tensorflow/c/eager/dlpack.h"

-#include "include/dlpack/dlpack.h"  // TF:dlpack
+#include "include/dlpack/dlpack.h"  // from @dlpack
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/framework/tensor.h"
|
@ -24,9 +24,14 @@ cc_library(
|
||||
srcs = [
|
||||
"modular_filesystem.cc",
|
||||
"modular_filesystem_registration.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"modular_filesystem.h",
|
||||
"modular_filesystem_registration.h",
|
||||
],
|
||||
hdrs = ["modular_filesystem.h"],
|
||||
# TODO(mihaimaruseac): Visibility should be more restrictive once we
|
||||
# convert to modular filesystems everywhere
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":filesystem_interface",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
|
tensorflow/c/experimental/filesystem/modular_filesystem.cc
@@ -440,7 +440,25 @@ Status ModularWritableFile::Tell(int64* position) {
 }

 Status RegisterFilesystemPlugin(const std::string& dso_path) {
-  return filesystem_registration::RegisterFilesystemPluginImpl(dso_path);
+  // Step 1: Load plugin
+  Env* env = Env::Default();
+  void* dso_handle;
+  TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
+
+  // Step 2: Load symbol for `TF_InitPlugin`
+  void* dso_symbol;
+  TF_RETURN_IF_ERROR(
+      env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
+
+  // Step 3: Call `TF_InitPlugin`
+  TF_FilesystemPluginInfo info;
+  memset(&info, 0, sizeof(info));
+  auto TF_InitPlugin =
+      reinterpret_cast<int (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
+  TF_InitPlugin(&info);
+
+  // Step 4: Do the actual registration
+  return filesystem_registration::RegisterFilesystemPluginImpl(&info);
 }

 }  // namespace tensorflow
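For context, a filesystem plugin loaded through this path defines the `TF_InitPlugin` symbol that Step 2 looks up. A minimal plugin-side sketch follows; only the `TF_FilesystemPluginInfo` members visible in this diff (num_schemes, ops, plugin_memory_allocate, plugin_memory_free) are assumed, the struct name `TF_FilesystemPluginOps` is inferred from the `ops[i].*` accesses, the "myfs" scheme is hypothetical, and the per-scheme ops tables are left unfilled:

    #include <cstdlib>
    #include <cstring>

    #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"

    // Sketch only: a real plugin fills in the ops tables below.
    extern "C" int TF_InitPlugin(TF_FilesystemPluginInfo* info) {
      // Core releases plugin-provided memory through plugin_memory_free, so
      // the plugin must advertise the matching allocator pair.
      info->plugin_memory_allocate = ::malloc;
      info->plugin_memory_free = ::free;

      info->num_schemes = 1;
      info->ops = static_cast<TF_FilesystemPluginOps*>(
          ::calloc(info->num_schemes, sizeof(TF_FilesystemPluginOps)));
      info->ops[0].scheme = ::strdup("myfs");  // hypothetical scheme name
      // ... populate filesystem_ops, random_access_file_ops,
      // writable_file_ops, and read_only_memory_region_ops before returning.
      return 0;
    }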
tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc
@@ -14,7 +14,6 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"

 #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
 #include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
 #include "tensorflow/c/tf_status_internal.h"
 #include "tensorflow/core/platform/env.h"
@@ -304,40 +303,22 @@ static Status ValidatePluginMemoryRoutines(

 namespace filesystem_registration {

-Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
-  // Step 1: Load plugin
-  Env* env = Env::Default();
-  void* dso_handle;
-  TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
-
-  // Step 2: Load symbol for `TF_InitPlugin`
-  void* dso_symbol;
-  TF_RETURN_IF_ERROR(
-      env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
-
-  // Step 3: Call `TF_InitPlugin`
-  TF_FilesystemPluginInfo info;
-  memset(&info, 0, sizeof(info));
-  auto TF_InitPlugin =
-      reinterpret_cast<int (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
-  TF_InitPlugin(&info);
-
-  // Step 4: Ensure plugin provides the memory management functions.
-  TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(&info));
+Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info) {
+  TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(info));

-  // Step 5: Validate and register all filesystems
+  // Validate and register all filesystems
   // Try to register as many filesystems as possible.
   // Free memory once we no longer need it
   Status status;
-  for (int i = 0; i < info.num_schemes; i++) {
-    status.Update(ValidateAndRegisterFilesystems(&info, i));
-    info.plugin_memory_free(info.ops[i].scheme);
-    info.plugin_memory_free(info.ops[i].filesystem_ops);
-    info.plugin_memory_free(info.ops[i].random_access_file_ops);
-    info.plugin_memory_free(info.ops[i].writable_file_ops);
-    info.plugin_memory_free(info.ops[i].read_only_memory_region_ops);
+  for (int i = 0; i < info->num_schemes; i++) {
+    status.Update(ValidateAndRegisterFilesystems(info, i));
+    info->plugin_memory_free(info->ops[i].scheme);
+    info->plugin_memory_free(info->ops[i].filesystem_ops);
+    info->plugin_memory_free(info->ops[i].random_access_file_ops);
+    info->plugin_memory_free(info->ops[i].writable_file_ops);
+    info->plugin_memory_free(info->ops[i].read_only_memory_region_ops);
   }
-  info.plugin_memory_free(info.ops);
+  info->plugin_memory_free(info->ops);
   return status;
 }
tensorflow/c/experimental/filesystem/modular_filesystem_registration.h
@@ -15,12 +15,17 @@ limitations under the License.
 #ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
 #define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_

+#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
 #include "tensorflow/core/platform/status.h"

 namespace tensorflow {
 namespace filesystem_registration {

-Status RegisterFilesystemPluginImpl(const std::string& dso_path);
+// Implementation for filesystem registration
+//
+// Don't call this directly. Instead call `RegisterFilesystemPlugin`.
+// Exposed only for static registration of local filesystems.
+Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info);

 }  // namespace filesystem_registration
 }  // namespace tensorflow
tensorflow/c/experimental/filesystem/plugins/posix/BUILD
@@ -19,6 +19,7 @@ tf_cc_shared_object(
 cc_library(
     name = "posix_filesystem_impl",
     srcs = ["posix_filesystem.cc"],
+    hdrs = ["posix_filesystem.h"],
     deps = [
         ":posix_filesystem_helper",
         "//tensorflow/c:tf_status",
@@ -26,6 +27,20 @@ cc_library(
     ],
 )

+# Since building pip package and API tests require a filesystem, we provide a
+# static registration target that they should link against.
+cc_library(
+    name = "posix_filesystem_static",
+    srcs = ["posix_filesystem_static.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":posix_filesystem_impl",
+        "//tensorflow/c/experimental/filesystem:filesystem_interface",
+        "//tensorflow/c/experimental/filesystem:modular_filesystem",
+    ],
+    alwayslink = 1,
+)
+
 # Library implementing helper functionality, so that the above only contains
 # the API implementation for modular filesystems.
 cc_library(
tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.cc
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"
+
 #include <dirent.h>
 #include <errno.h>
 #include <fcntl.h>
@@ -24,7 +26,6 @@ limitations under the License.
 #include <sys/stat.h>
 #include <unistd.h>

-#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
 #include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
 #include "tensorflow/c/tf_status.h"
tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h (new file)
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
+#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
+
+#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
+
+// Initialize the POSIX filesystem.
+//
+// In general, the `TF_InitPlugin` symbol doesn't need to be exposed in a header
+// file, since the plugin registration will look for the symbol in the DSO file
+// that provides the filesystem functionality. However, the POSIX filesystem
+// needs to be statically registered in some tests and utilities for building
+// the API files at the time of creating the pip package. Hence, we need to
+// expose this function so that this filesystem can be statically registered
+// when needed.
+void TF_InitPlugin(TF_FilesystemPluginInfo* info);
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_static.cc (new file)
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
+#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
+#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"
+
+namespace tensorflow {
+
+// Register the POSIX filesystems statically.
+// Return value will be unused
+bool StaticallyRegisterLocalFilesystems() {
+  TF_FilesystemPluginInfo info;
+  TF_InitPlugin(&info);
+  Status status = filesystem_registration::RegisterFilesystemPluginImpl(&info);
+  if (!status.ok()) {
+    VLOG(0) << "Static POSIX filesystem could not be registered: " << status;
+    return false;
+  }
+  return true;
+}
+
+// Perform the actual registration
+static bool unused = StaticallyRegisterLocalFilesystems();
+
+}  // namespace tensorflow
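The new file above leans on a standard C++ idiom: a namespace-scope initializer whose side effect performs the registration, which is also why the posix_filesystem_static target sets alwayslink = 1 (the linker must keep the object file even though nothing references its symbols). Stripped of TensorFlow specifics, the idiom is just:

    #include <iostream>

    namespace {

    bool RegisterSomething() {
      std::cout << "registered during static initialization\n";
      return true;
    }

    // Never read; the initializer's side effect is the point.
    const bool unused = RegisterSomething();

    }  // namespace

    int main() { return 0; }  // the message prints before main() begins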
@@ -632,6 +632,7 @@ tf_gen_op_wrappers_cc(
         "tpu_configuration_ops",
         "tpu_cross_replica_ops",
         "tpu_embedding_ops",
+        "tpu_embedding_load_retrieve_ops",
         "tpu_functional_ops",
         "tpu_heartbeat_ops",
         "tpu_host_compute_ops",
tensorflow/cc/framework/gradients.cc
@@ -521,15 +521,15 @@ Status SymbolicGradientBuilder::AddGradients() {
     // gradient function to the src node/output to which it should be
     // backpropped. Maybe grad functions can return a vector of Output pairs to
     // make this association explicit.
-    size_t dx_index = 0;
     for (const Edge* e : n->in_edges()) {
       if (e->IsControlEdge()) continue;
-      if (dx_index == dx.size()) {
+      int dx_index = e->dst_input();
+      if (dx_index >= dx.size()) {
         return errors::Internal(
             "Invalid gradient output index: ", dx_index, " size: ", dx.size());
       }
       TF_RETURN_IF_ERROR(
-          BackpropAlongEdge(dx[dx_index++], {e->src(), e->src_output()}));
+          BackpropAlongEdge(dx[dx_index], {e->src(), e->src_output()}));
     }
   }
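The fix stops assuming that n->in_edges() is ordered by input slot: each entry of dx belongs to a specific input index, so it is now looked up via the edge's dst_input() rather than consumed with a running counter. A self-contained sketch of the distinction (plain C++, not the TensorFlow API):

    #include <cstdio>
    #include <vector>

    // Each in-edge records which input slot (dst_input) of the node it feeds.
    struct Edge {
      int dst_input;
      int src_node;
    };

    int main() {
      // dx[i] is the gradient flowing into input slot i.
      const std::vector<double> dx = {0.1, 0.2, 0.3};
      // Edges can be iterated in any order, e.g. slot 2 first; a sequential
      // dx[counter++] would mispair gradients here.
      const std::vector<Edge> in_edges = {{2, 7}, {0, 5}, {1, 6}};
      for (const Edge& e : in_edges) {
        std::printf("node %d receives gradient %g\n", e.src_node,
                    dx[e.dst_input]);
      }
      return 0;
    }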
tensorflow/cc/framework/gradients_test.cc
@@ -503,6 +503,42 @@ TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) {
   EXPECT_EQ(grad_result[0].flat<float>()(0), 17610.0f);
 }

+TEST_F(GradientsTest, AddSymbolicGradientsTest) {
+  Scope scope = Scope::NewRootScope();
+  for (int cnt = 0; cnt < 100; ++cnt) {
+    int N = 5 + rand() % 10;
+    // Construct forward graph.
+    OutputList inputs;
+    for (int i = 0; i < N; ++i) {
+      auto a = Const(scope, i, {1});
+      inputs.push_back(a);
+    }
+
+    auto pack = Stack(scope, inputs);
+    TF_ASSERT_OK(scope.status());
+
+    // Construct grad inputs.
+    OutputList output_grads;
+    Tensor ts(DT_INT32, {N, 1});
+    auto v = ts.matrix<int32>();
+    for (int i = 0; i < N; ++i) {
+      v(i, 0) = i;
+    }
+    auto dy = Const(scope, ts);
+    output_grads.push_back(dy);
+    // Call AddSymbolicGradients.
+    std::vector<Output> grad_outputs;
+    TF_ASSERT_OK(AddSymbolicGradients(scope, {pack.output}, inputs,
+                                      output_grads, &grad_outputs));
+    ClientSession session((scope));
+    std::vector<Tensor> in_grad;
+    TF_ASSERT_OK(session.Run(grad_outputs, &in_grad));
+    for (int i = 0; i < N; ++i) {
+      test::ExpectTensorEqual<int>(in_grad[i], test::AsTensor<int>({i}, {1}));
+    }
+  }
+}
+
 // StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
 // 'NoGradient' (induced by StopGradient op) returned along multiple edges from
 // a single nodes output.
tensorflow/compiler/aot/codegen.cc
@@ -170,7 +170,9 @@ Status GenArgMethods(const tf2xla::Config& config,
                      const xla::ProgramShapeProto& ps,
                      const CompileResult& compile_result, string* methods) {
   size_t num_args = ps.parameters_size();
-  if (config.feed_size() + config.variable_size() != num_args) {
+  // feed_size() + variable_size() is the maximum number of args as an
+  // implementation may not create an argument for an unused variable.
+  if (config.feed_size() + config.variable_size() < num_args) {
     return errors::InvalidArgument(
         "mismatch between feed_size(", config.feed_size(), ")+variable_size(",
         config.variable_size(), ") and num_args(", num_args, ")");
tensorflow/compiler/aot/codegen_test.cc
@@ -154,7 +154,7 @@ static void CompareWithGoldenFile(
   // To update the golden file, flip update_golden to true and run the
   // following:
   // bazel test --test_strategy=local \
-  //   third_party/tensorflow/compiler/aot:codegen_test
+  //   "third_party/tensorflow/compiler/aot:codegen_test"
   const bool update_golden = false;
   string golden_file_name =
       GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -157,6 +157,7 @@ def tftop_k(_):

 def tfvariable_readonly(_):
   x = variables.Variable(1000.0, name='x')
+  unused_y = variables.Variable(1000.0, name='y')
   old_x = x.value()
   with ops.control_dependencies([old_x]):
     new_value = math_ops.add(old_x, 42.0)
@@ -10,3 +10,11 @@ variable {
   type: DT_FLOAT
   readonly: true
 }
+
+variable {
+  node_name: "y"
+  shape {
+  }
+  type: DT_FLOAT
+  readonly: true
+}
tensorflow/compiler/jit/BUILD
@@ -338,6 +338,7 @@ cc_library(
     deps = [
         ":xla_activity_listener",
         ":xla_activity_proto_cc",
+        "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:statusor",
tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -299,7 +299,11 @@ Status ComputeIncompatibleResourceOperationPairs(
       result->push_back({incoming_op.first, n->id()});
     }

-    resource_op_set->Add({n->id(), *op_kind});
+    // Some graphs might have a lot of 'kRead' kinds, but they are always safe
+    // for incoming ops, so not storing them might save a lot of memory.
+    if (op_kind != XlaResourceOpKind::kRead) {
+      resource_op_set->Add({n->id(), *op_kind});
+    }
   }

   if (vlog) {
tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "absl/strings/str_join.h"
 #include "tensorflow/compiler/jit/xla_activity.pb.h"
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -33,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/graph_optimizer.h"
 #include "tensorflow/core/common_runtime/metrics.h"
+#include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph_constructor.h"
@@ -40,6 +42,7 @@ limitations under the License.
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/dump_graph.h"

@@ -273,8 +276,30 @@ Status XlaCompilationCache::CompileSingleOp(

     const NodeDef& node_def = ctx->op_kernel().def();
     TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
-    return compiler->CompileGraph(compile_options, node_def.name(),
-                                  std::move(graph), args, result);
+
+    bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) {
+      return arg.kind == XlaCompiler::Argument::kParameter;
+    });
+    const ConfigProto* config = ctx->function_library()->config_proto();
+    bool use_mlir = config && config->experimental().enable_mlir_bridge();
+    // Use MLIR bridge if all the arguments are parameters.
+    // TODO(hinsu): Support other argument types instead of silently falling
+    // back to the XLA compiler.
+    if (!are_params || !use_mlir) {
+      return compiler->CompileGraph(compile_options, node_def.name(),
+                                    std::move(graph), args, result);
+    }
+
+    absl::InlinedVector<TensorShape, 4> arg_shapes;
+    arg_shapes.reserve(args.size());
+    for (const XlaCompiler::Argument& arg : args) {
+      arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
+    }
+    GraphDebugInfo debug_info;
+    return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
+                                compile_options.use_tuple_arg,
+                                *options.flib_def, debug_info,
+                                options.shape_representation_fn, result);
   };
   return CompileImpl(options, name, args, compile_op,
                      /*compile_threshold=*/absl::nullopt,
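The use_mlir branch above is driven by a flag in the session's ConfigProto. A sketch of how a client might opt in; the accessors follow the config->experimental().enable_mlir_bridge() read in the diff, while the include path is an assumption:

    #include "tensorflow/core/protobuf/config.pb.h"

    tensorflow::ConfigProto MakeMlirBridgeConfig() {
      tensorflow::ConfigProto config;
      // Mirrors the experimental().enable_mlir_bridge() check in the diff.
      config.mutable_experimental()->set_enable_mlir_bridge(true);
      return config;
    }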
tensorflow/compiler/jit/xla_compilation_cache.h
@@ -78,7 +78,9 @@ class XlaCompilationCache : public ResourceBase {
                  xla::LocalExecutable** out_executable);

   // As above, but calls XlaCompiler::CompileSingleOp instead of
-  // XlaCompiler::CompileFunction.
+  // XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto
+  // in OpKernelContext, then uses MLIR bridge for compilation instead of
+  // XlaCompiler, if possible.
   Status CompileSingleOp(
       const XlaCompiler::Options& options,
       absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
@@ -71,6 +71,7 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
        "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
        "//tensorflow/compiler/mlir/xla:lhlo",
@@ -88,7 +89,7 @@ cc_library(
         "//tensorflow/compiler/mlir/xla:xla_lower",
         "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
         "//tensorflow/compiler/mlir/xla:xla_test_passes",
-        "@llvm-project//mlir:AffineOps",
+        "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:QuantOps",
     ],
 )
tensorflow/compiler/mlir/lite/BUILD
@@ -26,6 +26,7 @@ package_group(
 filegroup(
     name = "tensorflow_lite_ops_td_files",
     srcs = [
+        "experimental/tfl_hardware_interfaces.td",
         "ir/tfl_op_interfaces.td",
         "ir/tfl_ops.td",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
@@ -204,6 +205,8 @@ cc_library(
 cc_library(
     name = "tensorflow_lite",
     srcs = [
+        "experimental/estimators/estimator.h",
+        "experimental/estimators/gpu_estimator.h.inc",
         "ir/tfl_ops.cc",
         "ir/tfl_ops.cc.inc",
         "ir/tfl_ops.h.inc",
@@ -213,6 +216,7 @@ cc_library(
         "utils/attribute_utils.cc",
     ],
     hdrs = [
+        "experimental/estimators/hardware.h",
         "ir/tfl_ops.h",
         "transforms/passes.h",
         "utils/attribute_utils.h",
@@ -222,7 +226,6 @@ cc_library(
     deps = [
         ":tensorflow_lite_ops_inc_gen",
         ":validators",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/lite/schema:schema_fbs",
        "@llvm-project//llvm:support",
@@ -419,7 +422,9 @@ cc_library(
     ],
     deps = [
         ":tensorflow_lite",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
@@ -439,6 +444,7 @@ genrule(
     srcs = [
         "ir/tfl_ops.td",
         "ir/tfl_op_interfaces.td",
+        "experimental/tfl_hardware_interfaces.td",
         "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
     ],
@@ -551,14 +557,14 @@ cc_library(
 cc_library(
     name = "flatbuffer_translate_lib",
     srcs = [
+        "flatbuffer_export.cc",
         "flatbuffer_import.cc",
-        "flatbuffer_translate.cc",
         "utils/convert_type.cc",
     ],
     hdrs = [
+        "flatbuffer_export.h",
+        "flatbuffer_export_flags.h",
         "flatbuffer_import.h",
-        "flatbuffer_translate.h",
-        "flatbuffer_translate_flags.h",
         "utils/convert_type.h",
     ],
     deps = [
@@ -575,9 +581,10 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:logging",
         "//tensorflow/core/platform:status",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite:string_util",
@@ -598,15 +605,37 @@ cc_library(
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Translation",
     ],
 )

+cc_library(
+    name = "flatbuffer_translate_registeration",
+    srcs = [
+        "flatbuffer_translate.cc",
+    ],
+    deps = [
+        ":flatbuffer_translate_lib",
+        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LoopOpsTransforms",
+        "@llvm-project//mlir:MlirTranslateMain",
+        "@llvm-project//mlir:QuantOps",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Translation",
+    ],
+    alwayslink = 1,
+)
+
 tf_cc_binary(
     name = "flatbuffer_translate",
     deps = [
         ":flatbuffer_translate_lib",
-        "@llvm-project//mlir:LoopOpsTransforms",
-        "@llvm-project//mlir:MlirTranslateMain",
+        ":flatbuffer_translate_registeration",
     ],
 )

@@ -644,10 +673,13 @@ filegroup(

 tf_cc_binary(
     name = "tf_tfl_translate",
-    srcs = [":tf_tfl_translate_main"],
+    srcs = [
+        ":tf_tfl_translate_main",
+    ],
     deps = [
         ":common",
         ":flatbuffer_translate_lib",
+        ":flatbuffer_translate_registeration",
         ":tensorflow_lite",
         ":tf_tfl_passes",
         ":tf_tfl_translate_cl_options",
@@ -669,15 +701,18 @@ tf_cc_binary(

 tf_cc_binary(
     name = "mlir-tflite-runner",
-    srcs = ["mlir_tflite_runner.cc"],
+    srcs = [
+        "mlir_tflite_runner.cc",
+    ],
     deps = [
         ":flatbuffer_translate_lib",
+        ":flatbuffer_translate_registeration",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
         "//tensorflow/core:lib",
         "//tensorflow/core/platform:logging",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/delegates/flex:delegate",
         "//tensorflow/lite/kernels:builtin_ops",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
tensorflow/compiler/mlir/lite/converter_gen.cc
@@ -27,10 +27,10 @@ limitations under the License.
 #include "llvm/TableGen/Main.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
-#include "mlir/TableGen/Attribute.h"  // TF:llvm-project
-#include "mlir/TableGen/Format.h"  // TF:llvm-project
-#include "mlir/TableGen/Operator.h"  // TF:llvm-project
-#include "mlir/TableGen/Predicate.h"  // TF:llvm-project
+#include "mlir/TableGen/Attribute.h"  // from @llvm-project
+#include "mlir/TableGen/Format.h"  // from @llvm-project
+#include "mlir/TableGen/Operator.h"  // from @llvm-project
+#include "mlir/TableGen/Predicate.h"  // from @llvm-project

 using llvm::DefInit;
 using llvm::dyn_cast;
@@ -442,28 +442,26 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
   verify_ctx.withOp("top");

   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
     auto &value = op.getOperand(i);
     // Skip from from first variadic operands for now. Else getOperand index
     // used below doesn't match.
     if (value.isVariadic()) break;
     if (!value.name.empty())
       verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
   }
   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
     auto &value = op.getResult(i);
     // Skip from from first variadic results for now. Else getResult index
     // used below doesn't match.
     if (value.isVariadic()) break;
     if (!value.name.empty())
       verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
   }
   GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
                            "operand");
   GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
                            "result");
-  os << "  return mlir::success();\n}\n";
+  os << "  return top.verify();\n}\n";
 }

 return false;
tensorflow/compiler/mlir/lite/emit_error_reporter.h
@@ -18,7 +18,7 @@ limitations under the License.

 #include <cstdarg>

-#include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
 #include "tensorflow/lite/core/api/error_reporter.h"

 namespace tflite {
tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h (new file)
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_
+
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
+
+template <typename Op, typename TargetHardware>
+class TFLiteCostEstimator {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  static bool IsSupported(mlir::Operation* op) {
+    llvm::errs() << "No defined support for op: "
+                 << op->getName().getStringRef().str();
+    return false;
+  }
+};
+
+// All ops on CPU are supported.
+// TODO(karimnosseir): Only allow TFL ops in the "TFL_OP" param.
+template <typename TFL_OP>
+class TFLiteCostEstimator<TFL_OP, hardware::CPU> {
+ public:
+  // TODO(karimnosseir): Update and use table based method and lookup
+  // cost from a loadable table ?
+  static double GetCost(mlir::Operation* op) { return 0.0; }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_
tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc (new file)
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_
+
+template <>
+class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
+ public:
+  static double GetCost(mlir::Operation* op) {
+    llvm::errs() << "No defined cost function for op: "
+                 << op->getName().getStringRef().str();
+    return 0.0;
+  }
+
+  static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_
tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
-#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
-
-#include "tensorflow/core/framework/dataset.h"
-
-#endif  // TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_HARDWARE_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_HARDWARE_H_
+
+namespace hardware {
+// Empty classes that represents hardware types.
+class CPU {};
+class GPU {};
+}  // namespace hardware
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_HARDWARE_H_
tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td (new file)
@@ -0,0 +1,76 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// WARNING: This Interface is experimental, DO NOT USE.
+
+// This is the Target Hardware operation interfacea definition file
+// for TensorFlow Lite.
+
+#ifndef TFL_TARGET_HARDWARE_OP_INTERFACES
+#define TFL_TARGET_HARDWARE_OP_INTERFACES
+
+def TFL_CpuTargetOp : OpInterface<"CpuOpTargetInterface"> {
+  let description = [{
+    Interface for ops to run on CPU.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{Returns the cost of running this op on CPU.}],
+      // TODO(karimnosseir): Change to return Cost object instead.
+      "double", "GetOpCost", (ins "mlir::Operation*":$op_to_check), [{
+        // TODO(karimnosseir): Consider changing to another way that doesn't
+        // rely on template param name.
+        return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::CPU>::GetCost(op_to_check);
+      }]
+    >,
+    InterfaceMethod<
+      [{Returns whether this op can be run on CPU.}],
+      "bool", "IsSupported", (ins "mlir::Operation*":$op_to_check), [{
+        // TODO(karimnosseir): Consider changing to another way that doesn't
+        // rely on template param name.
+        return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::CPU>::IsSupported(op_to_check);
+      }]
+    >,
+  ];
+}
+
+def TFL_GpuTargetOp : OpInterface<"GpuOpTargetInterface"> {
+  let description = [{
+    Interface for ops to run on GPU.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{Returns the cost of running this op on GPU.}],
+      // TODO(karimnosseir): Change to return Cost object instead.
+      "double", "GetOpCost", (ins "Operation*":$op_to_check), [{
+        // TODO(karimnosseir): Consider changing to another way that doesn't
+        // rely on template param name.
+        return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::GPU>::GetCost(op_to_check);
+      }]
+    >,
+    InterfaceMethod<
+      [{Returns whether this op can be run on GPU.}],
+      "bool", "IsSupported", (ins "Operation*":$op_to_check), [{
+        // TODO(karimnosseir): Consider changing to another way that doesn't
+        // rely on template param name.
+        return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::GPU>::IsSupported(op_to_check);
+      }]
+    >,
+  ];
+}
+
+#endif // TFL_TARGET_HARDWARE_OP_INTERFACES
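As a rough sketch, a pass might query the generated interfaces declared above like this; the exact casting API varies across MLIR revisions, so treat the dyn_cast form as illustrative rather than exact:

    // Returns the estimated GPU cost for `op`, or a negative sentinel when
    // the op does not implement the interface or is unsupported on GPU.
    double EstimateGpuCost(mlir::Operation* op) {
      if (auto gpu_op = llvm::dyn_cast<GpuOpTargetInterface>(op)) {
        if (gpu_op.IsSupported(op)) return gpu_op.GetOpCost(op);
      }
      return -1.0;
    }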
tensorflow/compiler/mlir/lite/flatbuffer_export.cc (new file, 1454 lines)
File diff suppressed because it is too large
tensorflow/compiler/mlir/lite/flatbuffer_export.h (new file)
@@ -0,0 +1,43 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
+
+#include <string>
+
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
+
+namespace tflite {
+
+// Translates the given MLIR `module` into a FlatBuffer and stores the
+// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to
+// convert location of the op to name in flatbuffer. Returns true if translation
+// fails, otherwise returns false.
+bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
+                                       std::string* serialized_flatbuffer,
+                                       bool emit_builtin_tflite_ops,
+                                       bool emit_select_tf_ops,
+                                       bool emit_custom_ops);
+
+// Same as the above but with a custom op name mapper.
+bool MlirToFlatBufferTranslateFunction(
+    mlir::ModuleOp module, std::string* serialized_flatbuffer,
+    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
+    tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
+}  // namespace tflite
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
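A minimal usage sketch of the first overload declared above (the caller-side wiring around it is assumed):

    #include <string>

    #include "mlir/IR/Module.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"

    // Serializes `module` into a TFLite flatbuffer held in *out. Note the
    // inverted convention: the function returns true on failure.
    bool SerializeModule(mlir::ModuleOp module, std::string* out) {
      return tflite::MlirToFlatBufferTranslateFunction(
          module, out, /*emit_builtin_tflite_ops=*/true,
          /*emit_select_tf_ops=*/false, /*emit_custom_ops=*/false);
    }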
tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h (new file)
@@ -0,0 +1,31 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_
+
+#include <string>
+
+// These flags are used to control the emission or not of different kinds of ops
+// during the flatbuffer translation.
+extern bool emit_builtin_tflite_ops;
+extern bool emit_select_tf_ops;
+extern bool emit_custom_ops;
+// The flag to control whether to lower tensorlist ops into TF ops.
+extern bool lower_tensor_list_ops;
+// The flag to control whether debug info gets stripped on export.
+extern bool strip_debug_info;
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_
tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -44,39 +44,35 @@ limitations under the License.
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
-#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
-#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
-#include "mlir/IR/Attributes.h"  // TF:llvm-project
-#include "mlir/IR/Builders.h"  // TF:llvm-project
-#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
-#include "mlir/IR/Function.h"  // TF:llvm-project
-#include "mlir/IR/Location.h"  // TF:llvm-project
-#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
-#include "mlir/IR/Module.h"  // TF:llvm-project
-#include "mlir/IR/Operation.h"  // TF:llvm-project
-#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
-#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
-#include "mlir/IR/Types.h"  // TF:llvm-project
-#include "mlir/IR/Value.h"  // TF:llvm-project
-#include "mlir/Support/Functional.h"  // TF:llvm-project
-#include "mlir/Support/LLVM.h"  // TF:llvm-project
-#include "mlir/Translation.h"  // TF:llvm-project
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/Functional.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Translation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/schema/schema_generated.h"

@@ -100,45 +96,6 @@ using xla::StatusOr;
 namespace errors = tensorflow::errors;
 namespace tfl = mlir::TFL;

-using llvm::cl::opt;
-
-// Commandline flag to enable the control of flatbuffer import.
-bool use_external_constant;
-
-// Commandline flag to enable graph pruning.
-bool experimental_prune_unreachable_nodes_unconditionally;
-
-// NOLINTNEXTLINE
-static opt<bool, true> use_external_constant_flag(
-    "use-external-constant",
-    llvm::cl::desc("Use external constant during flatbuffer import"),
-    llvm::cl::location(use_external_constant), llvm::cl::init(false));
-
-// TODO(b/147111261): After the importer supports generic custom ops, we should
-// change the flag to a more lightwise flag, e.g.
-// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune
-// the operations.
-// NOLINTNEXTLINE
-static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
-    "experimental-prune-unreachable-nodes-unconditionally",
-    llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
-    llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
-    llvm::cl::init(false));
-
-// NOLINTNEXTLINE
-static opt<std::string> input_arrays_flag(
-    "input-arrays",
-    llvm::cl::desc(
-        "List of input tensors, if different from the default inputs"),
-    llvm::cl::init(""));
-
-// NOLINTNEXTLINE
-static opt<std::string> output_arrays_flag(
-    "output-arrays",
-    llvm::cl::desc(
-        "List of output tensors, if different from the default outputs"),
-    llvm::cl::init(""));
-
 namespace {
 bool IsScalar(const TensorT& tensor) {
   // TODO(b/138222071) We can't distinguish scalars and unranked tensors

@@ -1063,42 +1020,3 @@ OwningModuleRef tflite::FlatBufferToMlir(

   return OwningModuleRef(module);
 }
-
-static OwningModuleRef FlatBufferFileToMlirTrans(
-    llvm::SourceMgr* source_mgr, MLIRContext* context,
-    bool use_external_constant,
-    bool experimental_prune_unreachable_nodes_unconditionally) {
-  const llvm::MemoryBuffer* input =
-      source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
-  std::string error;
-  auto loc =
-      mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
-
-  // Parses input/output names from command line options.
-  std::vector<std::string> inputs;
-  std::vector<std::string> outputs;
-  // Use output parser since we only have tensor names.
-  if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
-    return emitError(loc, "parsing input array info failed ")
-               << input_arrays_flag,
-           nullptr;
-  }
-  if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
-    return emitError(loc, "parsing output array info failed ")
-               << output_arrays_flag,
-           nullptr;
-  }
-
-  return tflite::FlatBufferToMlir(
-      absl::string_view(input->getBufferStart(), input->getBufferSize()),
-      context, loc, use_external_constant, inputs, outputs,
-      experimental_prune_unreachable_nodes_unconditionally);
-}
-
-static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
-    "tflite-flatbuffer-to-mlir",
-    [](llvm::SourceMgr& source_mgr, MLIRContext* context) {
-      return FlatBufferFileToMlirTrans(
-          &source_mgr, context, use_external_constant,
-          experimental_prune_unreachable_nodes_unconditionally);
-    });
tensorflow/compiler/mlir/lite/flatbuffer_import.h
@@ -17,9 +17,9 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_

 #include "absl/strings/string_view.h"
-#include "mlir/IR/Location.h"  // TF:llvm-project
-#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
-#include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project

 namespace tflite {
 // Converts a TFLite flatbuffer stored in `buffer` to a MLIR module
tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
@@ -18,12 +18,12 @@ limitations under the License.
 #include <vector>

 #include "absl/strings/str_cat.h"
-#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSwitch.h"
-#include "mlir/IR/Attributes.h"  // TF:llvm-project
-#include "mlir/IR/Builders.h"  // TF:llvm-project
-#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/xla/statusor.h"
@ -23,12 +23,12 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
|
@ -23,8 +23,8 @@ limitations under the License.
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
|
||||
#include "flatbuffers/minireflect.h" // TF:flatbuffers
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "flatbuffers/minireflect.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/schema/reflection/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -18,7 +18,7 @@ limitations under the License.

#include <string>

#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"

namespace tflite {

@ -19,6 +19,7 @@ limitations under the License.
#define TFL_OP_INTERFACES

include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td"

//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.

@ -26,19 +26,19 @@ limitations under the License.
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/OpImplementation.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "mlir/Transforms/FoldUtils.h"  // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h"  // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
@ -804,10 +804,10 @@ struct RemoveAdjacentReshape : public RewritePattern {
  RemoveAdjacentReshape(MLIRContext *context)
      : RewritePattern(ReshapeOp::getOperationName(), 1, context) {}

  PatternMatchResult match(Operation *op) const override {
  LogicalResult match(Operation *op) const override {
    auto thisOp = cast<ReshapeOp>(op);
    auto prevOp = thisOp.getOperand(0).getDefiningOp();
    return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
    return isa_and_nonnull<ReshapeOp>(prevOp) ? success() : failure();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
@ -884,28 +884,27 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
  explicit RemoveRedundantUnpackPack(MLIRContext *context)
      : RewritePattern(PackOp::getOperationName(), 2, context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    TFL::PackOp pack_op = cast<TFL::PackOp>(op);
    Operation *first_input = pack_op.getOperand(0).getDefiningOp();
    if (!first_input) return matchFailure();
    if (!first_input) return failure();
    auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
    if (!input_unpack_op) return matchFailure();
    if (!input_unpack_op) return failure();

    // The unpack & pack should have the same axis & num inputs/outputs.
    if (pack_op.axis() != input_unpack_op.axis() ||
        pack_op.values_count() != input_unpack_op.num())
      return matchFailure();
      return failure();

    const int total_pack_inputs = pack_op.getNumOperands();
    if (total_pack_inputs != input_unpack_op.getNumResults())
      return matchFailure();
    if (total_pack_inputs != input_unpack_op.getNumResults()) return failure();
    for (auto input_output :
         llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
      Value pack_input = std::get<0>(input_output);
      Value unpack_output = std::get<1>(input_output);
      // Make sure the ordering is the same for the pack op & unpack op.
      if (pack_input != unpack_output) return matchFailure();
      if (pack_input != unpack_output) return failure();
    }

    // Replace the pack's output with the unpack's input.
@ -913,7 +912,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
    // At this point, we don't manually remove the redundant pack op & unpack op
    // (we cannot actually), but trust the PatternRewriter to garbage collect
    // these two ops.
    return matchSuccess();
    return success();
  }
};
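
All of the pattern hunks in this file track the same MLIR API migration: `PatternMatchResult` with `matchSuccess()`/`matchFailure()` is replaced by `LogicalResult` with `success()`/`failure()`. A minimal sketch of the new form (the `MyOp` name is hypothetical, for illustration only):

  struct ExamplePattern : public RewritePattern {
    explicit ExamplePattern(MLIRContext *context)
        : RewritePattern(MyOp::getOperationName(), /*benefit=*/1, context) {}

    LogicalResult matchAndRewrite(Operation *op,
                                  PatternRewriter &rewriter) const override {
      // Where the old API returned matchFailure(), return failure().
      if (op->getNumOperands() != 1) return failure();
      // Forward the single operand; the rewriter erases the matched op.
      rewriter.replaceOp(op, op->getOperand(0));
      // Where the old API returned matchSuccess(), return success().
      return success();
    }
  };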

@ -1050,17 +1049,17 @@ struct DropFakeQuant : public RewritePattern {
  explicit DropFakeQuant(MLIRContext *context)
      : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}

  PatternMatchResult match(Operation *op) const override {
  LogicalResult match(Operation *op) const override {
    // We only match the op with valid "minmax" attribute.
    if (!HasValidMinMaxAttribute(op)) return matchFailure();
    if (!HasValidMinMaxAttribute(op)) return failure();

    // If all the users of this op have valid "minmax" attributes, it is matched
    // and can be removed.
    auto fakeQuantOp = cast<FakeQuantOp>(op);
    for (auto *operand : fakeQuantOp.getResult().getUsers())
      if (!HasValidMinMaxAttribute(operand)) return matchFailure();
      if (!HasValidMinMaxAttribute(operand)) return failure();

    return matchSuccess();
    return success();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
@ -1789,8 +1788,8 @@ struct WhileResultOperandsMatchAndImplicitCapture
    : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(WhileOp while_op,
                                     PatternRewriter &rewriter) const override {
  LogicalResult matchAndRewrite(WhileOp while_op,
                                PatternRewriter &rewriter) const override {
    // Replace values simply passed through the body with extern values. The
    // block arguments of body and while match and so the corresponding cond
    // argument can be easily found.
@ -1843,7 +1842,7 @@ struct WhileResultOperandsMatchAndImplicitCapture
    }

    // Done if no values removed from blocks and operands & results match.
    if (unchanged) return matchFailure();
    if (unchanged) return failure();

    // Replace with new While with matching operands and results.
    Operation *op = while_op.getOperation();
@ -1866,7 +1865,7 @@ struct WhileResultOperandsMatchAndImplicitCapture
    rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
                                         new_body_yield);

    return matchSuccess();
    return success();
  }
};

@ -18,18 +18,18 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_

#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/Traits.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Dialect.h"  // TF:llvm-project
#include "mlir/IR/OpImplementation.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // TF:llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h"  // TF:llvm-project
#include "mlir/Interfaces/SideEffects.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffects.h"  // from @llvm-project
#include "mlir/Support/Functional.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

@ -49,9 +49,12 @@ class TensorFlowLiteDialect : public Dialect {
                               Location loc) override;
};

#include "tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
// Include all specialized estimators below this line
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc"

}  // end namespace TFL
}  // end namespace mlir

@ -285,7 +285,10 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<

class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<TFL_Dialect, mnemonic, !listconcat(traits,
       [DeclareOpInterfaceMethods<TFL_RuntimeVerification>])> {
       [DeclareOpInterfaceMethods<TFL_RuntimeVerification>,
        // All TFL ops are supported on CPU.
        DeclareOpInterfaceMethods<TFL_CpuTargetOp>
       ])> {
  // FlatBuffer generation specific information.
  // -------------------------------------------
  // When generating the FlatBuffer output some operations have
@ -477,7 +480,10 @@ Note this is a custom op that is not supported in the standard runtime.
}

def TFL_AveragePool2DOp:
    TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> {
    TFL_Op<"average_pool_2d",
           [NoSideEffect,
            SameOperandsAndResultsScale,
            TFL_GpuTargetOp]> {
  let summary = "Average_pool_2d operator";

  let description = [{
@ -690,7 +696,7 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {

def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
  let extraClassDeclaration = [{
    // StatefulOpInterface:
    // ChannelDimIndexInterface:
    int GetChannelDimIndex() { return 0; }
  }];
}
@ -715,7 +721,7 @@ def TFL_DepthwiseConv2DOp :
  let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));

  let extraClassDeclaration = [{
    // StatefulOpInterface:
    // ChannelDimIndexInterface:
    int GetChannelDimIndex() { return 3; }
  }];
}
@ -1083,7 +1089,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}

def TFL_DivOp : TFL_Op<"div", [ResultsBroadcastableShape, NoSideEffect]> {
def TFL_DivOp : TFL_Op<"div", [
    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
  let summary = "Division operator";

  let description = [{

@ -30,12 +30,12 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Parser.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/delegate.h"

@ -19,11 +19,11 @@ limitations under the License.
#include <utility>

#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/FileUtilities.h"  // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"

@ -17,11 +17,11 @@ limitations under the License.
#include <utility>

#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/FileUtilities.h"  // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"

@ -18,11 +18,11 @@ limitations under the License.
#include <utility>

#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/FileUtilities.h"  // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

@ -18,7 +18,7 @@ limitations under the License.
#include <ostream>
#include <utility>

#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/lite/toco/model_flags.pb.h"

@ -23,18 +23,18 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/AffineExpr.h"  // TF:llvm-project
#include "mlir/IR/AffineMap.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/AffineExpr.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/Functional.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"

@ -17,14 +17,14 @@ limitations under the License.

#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"

#include "llvm/Support/Casting.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {

@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_

#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project

namespace mlir {
namespace TFL {

@ -20,7 +20,7 @@ limitations under the License.
#define TF_Quantization

include "mlir/IR/OpBase.td"
include "mlir/Dialect/QuantOps/QuantOpsBase.td"
include "mlir/Dialect/Quant/QuantOpsBase.td"

//===----------------------------------------------------------------------===//
// QuantizedType definitions.

@ -24,18 +24,18 @@ limitations under the License.
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_

#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project

namespace mlir {
namespace quant {

@ -18,8 +18,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_

#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project

namespace mlir {
namespace OpTrait {

@ -22,15 +22,15 @@ limitations under the License.

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantizeUtils.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantizeUtils.h"  // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project

namespace mlir {
namespace quant {

@ -23,18 +23,18 @@ limitations under the License.

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

namespace mlir {
@ -82,17 +82,17 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
        narrow_range(narrow_range),
        is_signed(is_signed) {}

  PatternMatchResult matchAndRewrite(quant::StatisticsOp op,
                                     PatternRewriter& rewriter) const override {
  LogicalResult matchAndRewrite(quant::StatisticsOp op,
                                PatternRewriter& rewriter) const override {
    Type expressed = op.getType().cast<ShapedType>().getElementType();
    quant::QuantizedType quant_type;
    SmallVector<double, 4> mins, maxs;

    if (op.axisStats().hasValue()) {
      int stats_num = op.axisStats()->getNumElements();
      if (stats_num == 0 || stats_num % 2 != 0) return this->matchFailure();
      if (stats_num == 0 || stats_num % 2 != 0) return failure();
      auto stats = op.axisStats()->dyn_cast<DenseFPElementsAttr>();
      if (!stats) return this->matchFailure();
      if (!stats) return failure();

      for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
        mins.push_back(FloatAttr::getValueAsDouble(*it++));
@ -108,7 +108,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
          quant::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
                                      narrow_range, expressed, is_signed);
    } else {
      return this->matchFailure();
      return failure();
    }

    rewriter.setInsertionPointAfter(op);
@ -119,7 +119,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
    q.getOperation()->replaceUsesOfWith(dq, op.arg());
    op.erase();

    return this->matchSuccess();
    return success();
  }

 private:
@ -156,16 +156,16 @@ struct QuantizationPattern : public RewritePattern {
        error_tolerance(error_tolerance),
        single_layer_verify(single_layer_verify) {}

  PatternMatchResult matchAndRewrite(Operation* op,
                                     PatternRewriter& rewriter) const override {
  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    if (op->getNumResults() != 1) {
      return matchFailure();
      return failure();
    }
    Value quantized_value = op->getResult(0);
    for (Operation* quantized_op : quantized_value.getUsers()) {
      // If it is requantize op, we shouldn't rewrite this op.
      if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
        return matchFailure();
        return failure();
      }

      // If it is terminator or not quantizable or any ops from the mlir quant
@ -174,7 +174,7 @@ struct QuantizationPattern : public RewritePattern {
          quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
          llvm::isa<quant::QuantizeCastOp>(quantized_op) ||
          llvm::isa<quant::DequantizeCastOp>(quantized_op)) {
        return matchFailure();
        return failure();
      }

      // Collect all the quantized inputs and "clone" the matched op by these
@ -198,7 +198,7 @@ struct QuantizationPattern : public RewritePattern {
        } else if (static_cast<const ConcretTy*>(this)->AllowHybridOperand()) {
          inputs.push_back(operand);
        } else {
          return matchFailure();
          return failure();
        }
      }

@ -234,7 +234,7 @@ struct QuantizationPattern : public RewritePattern {
          outputs_replaced.insert({result, enumerated_result.index()});
          output_types.push_back(result.getType());
        } else {
          return matchFailure();
          return failure();
        }
      }

@ -299,7 +299,7 @@ struct QuantizationPattern : public RewritePattern {
        }
      }
    }
    return matchSuccess();
    return success();
  }

  bool enable_verify;
@ -317,11 +317,11 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
  explicit ConvertUnsignedToSigned(MLIRContext* context)
      : OpRewritePattern<Q>(context, 1) {}

  PatternMatchResult matchAndRewrite(Q op,
                                     PatternRewriter& rewriter) const override {
  LogicalResult matchAndRewrite(Q op,
                                PatternRewriter& rewriter) const override {
    Type output_type = op.getResult().getType();
    auto qtype = QType::getQuantizedElementType(output_type);
    if (!qtype || qtype.isSigned()) return this->matchFailure();
    if (!qtype || qtype.isSigned()) return failure();

    int num_bits = qtype.getStorageTypeIntegralWidth();
    // This is a positive value, and will be applied on zero points and fixed
@ -352,14 +352,14 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
          aqtype.getStorageTypeMin() - offset,
          aqtype.getStorageTypeMax() - offset, op.getLoc());
    } else {
      return this->matchFailure();
      return failure();
    }

    if (!new_qtype) return this->matchFailure();
    if (!new_qtype) return failure();
    Type new_output_type = new_qtype.castFromExpressedType(
        QType::castToExpressedType(output_type));
    rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
    return this->matchSuccess();
    return success();
  }
};
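
For reference, ConvertUnsignedToSigned above shifts storage values by `offset = 2^(num_bits - 1)`; an unsigned 8-bit type with storage range [0, 255] and zero point 128, for example, becomes a signed type with range [-128, 127] and zero point 0. A sketch of the arithmetic (variable names are illustrative):

  // offset is 2^(num_bits - 1): 128 for 8-bit types, 32768 for 16-bit.
  int64_t offset = int64_t(1) << (num_bits - 1);
  int64_t new_zero_point = old_zero_point - offset;    // e.g. 128 -> 0
  int64_t new_storage_min = old_storage_min - offset;  // e.g. 0   -> -128
  int64_t new_storage_max = old_storage_max - offset;  // e.g. 255 -> 127
  // The scale is unchanged: real_value = scale * (storage_value - zero_point).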

@ -18,8 +18,8 @@ limitations under the License.

#include <memory>

#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project

namespace mlir {
namespace TF {

@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

@ -73,12 +73,12 @@ struct InsertQuantOpsAfterTFFakeQuantOp
                                   MLIRContext *ctx)
      : OpRewritePattern<TFFakeQuantOp>(ctx) {}

  PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
                                     PatternRewriter &rewriter) const override {
  LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
                                PatternRewriter &rewriter) const override {
    // We don't want to insert quantize/dequantize if the quantize op exists.
    auto res = tf_op.outputs();
    if (!res.hasOneUse() || isa<quant::QuantizeCastOp>(*res.user_begin()))
      return this->matchFailure();
      return failure();

    // Extract the min/max constant values from the operands. We also consider
    // a special case that there are tf.Identity ops between the min/max
@ -95,8 +95,8 @@ struct InsertQuantOpsAfterTFFakeQuantOp
      max = tf_op.max();
      rewriter.eraseOp(id2);
    }
    if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
    if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
    if (!matchPattern(min, m_Constant(&min_value))) return failure();
    if (!matchPattern(max, m_Constant(&max_value))) return failure();

    int quant_dim = -1;
    if (PerAxis) {
@ -114,7 +114,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp
    TypeAttr qtype = quant::GetQuantizedTypeAttr(
        rewriter, res_type, min_value, max_value, quant_dim, num_bits,
        narrow_range, /*is_signed=*/true);
    if (!qtype) this->matchFailure();
    if (!qtype) return failure();

    // Finally, use the quantization parameter to create the quantize and
    // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
@ -127,7 +127,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp
    value.replaceAllUsesWith(dequantize);
    quantize.getOperation()->replaceUsesOfWith(dequantize, value);

    return this->matchSuccess();
    return success();
  }
};

@ -20,7 +20,7 @@ limitations under the License.
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Operator.h"  // TF:llvm-project
#include "mlir/TableGen/Operator.h"  // from @llvm-project

using llvm::LessRecord;
using llvm::raw_ostream;

@ -1,3 +1,8 @@
load(
    "//third_party/mlir:tblgen.bzl",
    "gentbl",
)

package(
    default_visibility = [
        ":friends",
@ -18,6 +23,8 @@ package_group(
cc_library(
    name = "hlo_xla_quantization_passes",
    srcs = [
        "cpu_kernel_fusion.cc",
        "generated_cpu_kernel_fusion.inc",
        "materialize.cc",
        "op_quant_spec.inc",
        "propagate.cc",
@ -36,6 +43,8 @@ cc_library(
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
    ],
    alwayslink = 1,
)
@ -52,7 +61,6 @@ cc_library(
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
        "//tensorflow/compiler/tf2xla",
        "//tensorflow/compiler/tf2xla:mlir_tf2xla",
        "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
@ -62,6 +70,24 @@ cc_library(
        "//tensorflow/core/platform:status",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Transforms",
    ],
)

gentbl(
    name = "cpu_kernel_fusion_inc_gen",
    tbl_outs = [
        (
            "-gen-rewriters",
            "generated_cpu_kernel_fusion.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "cpu_kernel_fusion.td",
    td_srcs = [
        "@llvm-project//mlir:StdOpsTdFiles",
        "//tensorflow/compiler/mlir/xla:hlo_ops_td_files",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
    ],
)

@ -0,0 +1,252 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"

#define DEBUG_TYPE "quant-kernel-fusion"

constexpr int kFakeQuantOperandsNum = 5;
constexpr int kFakeQuantPerChannelOperandsNum = 6;

namespace mlir {
namespace xla_hlo {

namespace {

TypeAttr GetQuantSpec(Operation* op) {
  auto fake_quant = llvm::dyn_cast_or_null<CustomCallOp>(op);
  if (!fake_quant || fake_quant.getNumOperands() < kFakeQuantOperandsNum ||
      fake_quant.getNumOperands() > kFakeQuantPerChannelOperandsNum ||
      fake_quant.call_target_name() != "fake_quant_with_min_max_vars")
    return {};

  DenseFPElementsAttr min, max;
  DenseIntElementsAttr bit_width, narrow_range, quant_dim;
  if (!matchPattern(fake_quant.getOperand(1), m_Constant(&min)) ||
      !matchPattern(fake_quant.getOperand(2), m_Constant(&max)) ||
      !matchPattern(fake_quant.getOperand(3), m_Constant(&bit_width)) ||
      !matchPattern(fake_quant.getOperand(4), m_Constant(&narrow_range)))
    return {};

  auto bit_width_val = (*bit_width.attr_value_begin()).cast<IntegerAttr>();
  auto narrow_range_val = (*narrow_range.int_value_begin()).getSExtValue();
  int quant_dim_val = -1;
  if (fake_quant.getNumOperands() == kFakeQuantPerChannelOperandsNum &&
      matchPattern(fake_quant.getOperand(kFakeQuantPerChannelOperandsNum - 1),
                   m_Constant(&quant_dim))) {
    quant_dim_val = (*quant_dim.int_value_begin()).getSExtValue();
  }

  OpBuilder builder(op);
  Type input_type =
      fake_quant.getOperand(0).getType().cast<ShapedType>().getElementType();
  return quant::GetQuantizedTypeAttr(
      builder, input_type, min, max, quant_dim_val, bit_width_val,
      builder.getBoolAttr(narrow_range_val), /*is_signed=*/true);
}

// Collects input values from outside for 'ops'.
void CollectInputs(llvm::ArrayRef<Operation*> ops,
                   llvm::SmallVectorImpl<Value>* inputs,
                   llvm::SmallVectorImpl<Attribute>* input_specs) {
  for (Operation* op : ops) {
    for (Value operand : op->getOperands()) {
      if (std::find(inputs->begin(), inputs->end(), operand) != inputs->end()) {
        continue;
      }
      if (Operation* def_op = operand.getDefiningOp()) {
        if (std::find(ops.begin(), ops.end(), def_op) == ops.end()) {
          inputs->push_back(operand);
        }
      } else {  // argument value
        inputs->push_back(operand);
      }
    }
  }

  for (Value input : *inputs) {
    ShapedType input_type = input.getType().cast<ShapedType>();
    if (TypeAttr spec = GetQuantSpec(input.getDefiningOp())) {
      input_specs->push_back(spec);
    } else {
      input_specs->push_back(TypeAttr::get(input_type.getElementType()));
    }
  }
}

// Collects values that are produced by 'ops' and have use outside of 'ops'.
// TODO(fengliuai): if it is a single user and QDQ, write that to the specs.
void CollectRets(llvm::ArrayRef<Operation*> ops,
                 llvm::SmallVectorImpl<Value>* rets,
                 llvm::SmallVectorImpl<Type>* ret_types,
                 llvm::SmallVectorImpl<Attribute>* ret_specs) {
  for (Operation* op : ops) {
    for (Value result : op->getResults()) {
      for (Operation* user : result.getUsers()) {
        // If there are any users outside of 'ops'
        if (std::find(ops.begin(), ops.end(), user) == ops.end()) {
          ShapedType ret_type = result.getType().cast<ShapedType>();
          rets->push_back(result);
          ret_types->push_back(ret_type);
          if (TypeAttr spec = GetQuantSpec(user)) {
            ret_specs->push_back(spec);
          } else {
            ret_specs->push_back(TypeAttr::get(ret_type.getElementType()));
          }
          break;
        }
      }
    }
  }
}

llvm::SmallVector<Value, 0> fuseOps(PatternRewriter* rewriter,
                                    const std::initializer_list<Value>& results,
                                    StringRef kernel) {
  // Collect all the operations to be fused.
  llvm::SmallVector<Operation*, 4> fused;
  llvm::SmallVector<Location, 4> locs;
  fused.reserve(results.size());
  locs.reserve(results.size());
  for (auto value : results) {
    Operation* op = value.getDefiningOp();
    fused.push_back(op);
    locs.push_back(op->getLoc());
  }

  // Collect inputs from outside to 'ops'.
  llvm::SmallVector<Value, 4> inputs;
  llvm::SmallVector<Attribute, 4> input_specs;
  CollectInputs(fused, &inputs, &input_specs);

  // Collect outputs from 'ops' to outside.
  llvm::SmallVector<Value, 4> rets;
  llvm::SmallVector<Type, 4> ret_types;
  llvm::SmallVector<Attribute, 4> ret_specs;
  CollectRets(fused, &rets, &ret_types, &ret_specs);

  // Create the region op with the return.
  auto region = rewriter->create<quant::QuantizeRegionOp>(
      rewriter->getFusedLoc(locs), ret_types, inputs,
      rewriter->getArrayAttr(input_specs), rewriter->getArrayAttr(ret_specs),
      kernel);
  auto* body = new Block();
  region.body().push_back(body);

  OpBuilder builder(body);
  BlockAndValueMapping mapping;

  // Make block arguments and add it to the block value mapping.
  for (Value input : inputs) {
    mapping.map(input, body->addArgument(input.getType()));
  }

  // Clone the operations 'ops' to the region.
  for (Operation* op : fused) {
    builder.clone(*op, mapping);
  }

  llvm::SmallVector<Value, 4> new_rets;
  new_rets.reserve(rets.size());
  for (auto ret : llvm::enumerate(rets)) {
    Value new_ret = mapping.lookupOrNull(ret.value());
    assert(new_ret && "couldn't find return value.");
    new_rets.push_back(new_ret);
    ret.value().replaceAllUsesWith(region.getResult(ret.index()));
  }
  builder.create<quant::ReturnOp>(builder.getUnknownLoc(), new_rets);

  LLVM_DEBUG({
    assert(region.verify().Success && "failed to create quant region.");
    llvm::dbgs() << "\ncreated region: ";
    region.print(llvm::dbgs());
    llvm::dbgs() << "\n\n\n";
  });

  SmallVector<Value, 0> new_values(fused.back()->getNumResults());
  return new_values;
}

struct CpuKernelFusionPass : public FunctionPass<CpuKernelFusionPass> {
  explicit CpuKernelFusionPass() = default;
  CpuKernelFusionPass(const CpuKernelFusionPass&) {}

  void runOnFunction() override;

 private:
  LogicalResult fuseCpuKernels(Operation* op);
};

#include "tensorflow/compiler/mlir/lite/quantization/xla/generated_cpu_kernel_fusion.inc"

LogicalResult CpuKernelFusionPass::fuseCpuKernels(Operation* op) {
  MLIRContext* ctx = op->getContext();
  OwningRewritePatternList patterns;
  populateWithGenerated(ctx, &patterns);

  ConversionTarget target(*ctx);
  target.addLegalDialect<quant::QuantizationDialect>();
  target.addLegalOp<CallOp, ModuleOp, FuncOp, ModuleTerminatorOp,
                    ::mlir::ReturnOp>();
  return applyPartialConversion(op, target, patterns);
}

void CpuKernelFusionPass::runOnFunction() {
  if (failed(fuseCpuKernels(getFunction()))) signalPassFailure();
}

}  // namespace

// Creates an instance of the xla_hlo cpu kernel fusion pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateCpuKernelFusionPass() {
  return std::make_unique<CpuKernelFusionPass>();
}

static PassRegistration<CpuKernelFusionPass> pass(
    "xla-hlo-cpu-fusion", "Fuse xla hlo ops into cpu kernels");

}  // namespace xla_hlo
}  // namespace mlir
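
The test at the end of this change drives this pass through `tf-opt -xla-hlo-cpu-fusion`, the name registered above. A rough sketch of running it programmatically (assumes the pass header is included; the parsing setup is illustrative, not part of this change):

  // Illustrative driver, assuming `source_mgr` holds the module text.
  mlir::MLIRContext context;
  mlir::OwningModuleRef module = mlir::parseSourceFile(source_mgr, &context);
  mlir::PassManager pm(&context);
  pm.addPass(mlir::xla_hlo::CreateCpuKernelFusionPass());
  if (!module || mlir::failed(pm.run(*module))) {
    // Parsing or fusion failed; diagnostics were already emitted.
  }
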
@ -0,0 +1,24 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"

class Fused2Ops<string kernel> : NativeCodeCall<
  "fuseOps(&$_builder, {$0, $1}, \"" # kernel # "\")">;
class Fused3Ops<string kernel> : NativeCodeCall<
  "fuseOps(&$_builder, {$0, $1, $2}, \"" # kernel # "\")">;

def : Pat<(HLO_AddOp:$add (HLO_MulOp:$mul $_, $_, $_), $_, $_),
          (Fused2Ops<"generic.mul_add"> $mul, $add)>;
@ -25,14 +25,14 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
@ -58,15 +58,15 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
|
||||
explicit RewriteDequantize(int64_t size, MLIRContext *context)
|
||||
: OpRewritePattern<quant::DequantizeCastOp>(context), size_(size) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(quant::DequantizeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(quant::DequantizeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// quant.dcast
|
||||
// xla_hlo dequantize only takes min/max, so let's recover them from
|
||||
// the quantization parameters.
|
||||
Value dcast = op.arg();
|
||||
auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType());
|
||||
if (!type || !type.isa<quant::UniformQuantizedType>()) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
auto qtype = type.cast<quant::UniformQuantizedType>();
|
||||
double scale = qtype.getScale();
|
||||
@ -77,7 +77,7 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
|
||||
// quant.qcast
|
||||
auto qcast =
|
||||
llvm::dyn_cast_or_null<quant::QuantizeCastOp>(dcast.getDefiningOp());
|
||||
if (!qcast) return matchFailure();
|
||||
if (!qcast) return failure();
|
||||
|
||||
// constant
|
||||
DenseFPElementsAttr attr;
|
||||
@ -88,7 +88,7 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
|
||||
attr.getNumElements() <= size_ ||
|
||||
attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) {
|
||||
op.getResult().replaceAllUsesWith(qcast.arg());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
// TODO(fengliuai): implement transpose if it has high dimension.
|
||||
|
||||
@ -96,7 +96,7 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
|
||||
auto quantized_result =
|
||||
quant::Quantize(attr, qtype).dyn_cast_or_null<DenseIntElementsAttr>();
|
||||
if (!quantized_result) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Pack the uint8 bits to uint32. The shape is changed from from
|
||||
@ -133,7 +133,7 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
|
||||
// Convert bf16 output back to f32
|
||||
rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getResult().getType(),
|
||||
dequantize);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
@ -21,10 +21,10 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
|
||||
|
@ -14,26 +14,38 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"

#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "mlir/Transforms/Passes.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

namespace mlir {
namespace xla_hlo {

static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
return true;
}();
(void)init_once;
}

// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
computation->Snapshot());

RegisterDialects();
MLIRContext context;
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
auto status = xla::ConvertHloToMlirHlo(
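For orientation, the XlaQuantize entry point declared above would be invoked roughly as follows; where the XlaComputation comes from (for example, an earlier tf2xla conversion) is an assumption for illustration, not something this hunk shows:

// Hypothetical call site; BuildComputation() is a placeholder for
// whatever produces the xla::XlaComputation being quantized.
tensorflow::tf2xla::Config config;  // feeds/fetches assumed filled in
xla::XlaComputation computation = BuildComputation();
tensorflow::Status status =
    mlir::xla_hlo::XlaQuantize(config, &computation);
if (!status.ok()) {
  LOG(ERROR) << "XLA quantization failed: " << status.error_message();
}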
@ -0,0 +1,46 @@
// RUN: tf-opt -xla-hlo-cpu-fusion %s | FileCheck %s

// CHECK-LABEL: @mul_add_source
func @mul_add_source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>

// CHECK: %[[region:.*]] = "quant.region"(%arg0, %arg1, %arg2) ( {
// CHECK: ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
// CHECK: "quant.return"(%[[add]]) : (tensor<4xf32>) -> ()
// CHECK: }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[region]] : tensor<4xf32>
}

// CHECK-LABEL: @mul_add_annotated
func @mul_add_annotated(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) {
%cst = constant dense<0.0> : tensor<f32>
%cst_0 = constant dense<255.0> : tensor<f32>
%cst_1 = constant dense<8> : tensor<i32>
%cst_2 = constant dense<false> : tensor<i1>
%qin = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.1"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
%qw = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.2"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
%0 = "xla_hlo.multiply"(%qin, %qw) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
%1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
%r = "xla_hlo.custom_call"(%1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.3"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
return %r : tensor<2x4xf32>

// CHECK: %[[region:.*]] = "quant.region"
// CHECK: ^bb0(%arg3: tensor<2x4xf32>, %arg4: tensor<2x4xf32>, %arg5: tensor<2x4xf32>): // no predecessors
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
// CHECK: "quant.return"(%[[add]]) : (tensor<2x4xf32>) -> ()
// CHECK: }) {input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32],
// CHECK-SAME: logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]} :
// CHECK-SAME: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[r:.*]] = "xla_hlo.custom_call"(%[[region]]
// CHECK: return %[[r]] : tensor<2x4xf32>
}

@ -6,49 +6,49 @@ func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32,
// CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16>
// CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[cast]] : tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[cast]] : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>

%w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32>
%q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
return %mul: tensor<2x4xf32>
}

// CHECK-LABEL: func @quantize_small
func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<1x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<1x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<1x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<1x4xf32>

%w = constant dense<1.0> : tensor<1x4xf32>
%q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<1x4xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<1x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<1x4xf32>
return %mul: tensor<1x4xf32>
}

// CHECK-LABEL: func @quantize_non_cst
func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %arg0 : tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %arg0 : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>

%q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
return %mul: tensor<2x4xf32>
}

// CHECK-LABEL: func @quantize_non_4x
func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<2x5xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<2x5xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x5xf32>

%w = constant dense<1.0> : tensor<2x5xf32>
%q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x5xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<2x5xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x5xf32>
return %mul: tensor<2x5xf32>
}

@ -5,10 +5,10 @@ func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x2xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[dq]] : tensor<2x2xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[dq]] : tensor<2x2xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x2xf32>
%w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32>
%mul = xla_hlo.mul %arg0, %w : tensor<2x2xf32>
%mul = xla_hlo.multiply %arg0, %w : tensor<2x2xf32>
return %mul: tensor<2x2xf32>
}

@ -17,14 +17,14 @@ limitations under the License.

#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/types.pb.h"
@ -0,0 +1,15 @@
// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s --dump-input-on-failure

// CHECK: error: 'tf.MyCustomOp' op is neither a custom op nor a flex op
// CHECK: error: failed while converting: 'main'
// CHECK: Ops that need custom implementation (enabled via setting the -emit-custom-ops flag):
// CHECK: tf.MyCustomOp {name = "MyCustomOp"}

func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
%0 = "tfl.pseudo_const" () {name = "Const", value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
%1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE", name = "mul"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%2 = "tf.MyCustomOp"(%1, %0) {name = "MyCustomOp"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%3 = "tfl.exp"(%2) {name = "exp"} : (tensor<4xf32>) -> tensor<4xf32>
return %3 : tensor<4xf32>
}
@ -1,8 +1,9 @@
// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s
// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s --dump-input-on-failure

// CHECK: error: 'tf.Div' op is neither a custom op nor a flex op
// CHECK: error: failed while converting: 'main'
// CHECK: Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): Div.
// CHECK: Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag):
// CHECK: tf.Div {name = "div"}

func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):

@ -15,11 +15,11 @@ limitations under the License.

#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"

#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "mlir/Transforms/Passes.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_

#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"

namespace tensorflow {

@ -20,16 +20,16 @@ limitations under the License.
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/FileUtilities.h"  // TF:llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

@ -19,13 +19,13 @@ limitations under the License.
#include <unordered_set>

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Parser.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/FileUtilities.h"  // TF:llvm-project
#include "mlir/Transforms/Passes.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_

#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/stream_executor/lib/statusor.h"

@ -25,9 +25,9 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

@ -16,9 +16,9 @@ limitations under the License.
// This transformation pass converts dense tensors to sparse format.

#include "absl/memory/memory.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -21,12 +21,12 @@ limitations under the License.
|
||||
#include <cstdint>
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
@ -74,31 +74,31 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
|
||||
PatternRewriter& rewriter) const;
|
||||
|
||||
public:
|
||||
PatternMatchResult matchAndRewrite(Conv2dOpTy op,
|
||||
PatternRewriter& rewriter) const override;
|
||||
LogicalResult matchAndRewrite(Conv2dOpTy op,
|
||||
PatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
template <typename Conv2dOpTy>
|
||||
PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
Conv2dOpTy op, PatternRewriter& rewriter) const {
|
||||
// Make sure Conv2D has 'VALID' padding.
|
||||
if (op.template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
|
||||
return Pattern::matchFailure();
|
||||
return failure();
|
||||
}
|
||||
// Make sure dilations are all ones if set.
|
||||
const ArrayAttr& dilations =
|
||||
op.template getAttrOfType<ArrayAttr>("dilations");
|
||||
if (dilations && !TFIntListIsAllOnes(dilations)) {
|
||||
return Pattern::matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check if the ConvOp is preceded by an `Expand` op and succeeded by a
// `Squeeze` op.
Operation* prev_op = op.getOperation()->getPrevNode();
if (!prev_op) return Pattern::matchFailure();
if (!prev_op) return failure();

Operation* next_op = op.getOperation()->getNextNode();
if (!next_op) return Pattern::matchFailure();
if (!next_op) return failure();

TF::ExpandDimsOp expand_op;
TF::SqueezeOp squeeze_op;
@ -107,7 +107,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
// Expand/Squeeze op must come in pair.
return Pattern::matchFailure();
return failure();
}
expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
@ -119,24 +119,24 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
(*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
.getSExtValue();
} else {
return Pattern::matchFailure();
return failure();
}
// Make sure that the `squeeze_dims` is equal to `expand_axis`.
auto squeeze_dims = squeeze_op.squeeze_dims();
if (squeeze_dims.size() != 1 ||
squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) {
return Pattern::matchFailure();
return failure();
}

// Update previous/next op pointer.
prev_op = prev_op->getPrevNode();
if (!prev_op) return Pattern::matchFailure();
if (!prev_op) return failure();
next_op = next_op->getNextNode();
if (!next_op) return Pattern::matchFailure();
if (!next_op) return failure();
}

// SpaceToBatchND op.
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure();
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return failure();
// TODO(b/149936532): Check `padding` input, currently ignored.
TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);

@ -148,7 +148,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
if (llvm::isa<TF::PadOp>(next_op)) {
pad_op = llvm::cast<TF::PadOp>(next_op);
next_op = next_op->getNextNode();
if (!next_op) return Pattern::matchFailure();
if (!next_op) return failure();
}

// BatchToSpaceND + BiasAdd.
@ -160,8 +160,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
// Must be BiasAdd + BatchToSpaceND.
biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
next_op = next_op->getNextNode();
if (!next_op || !llvm::isa<TF::BatchToSpaceNDOp>(next_op))
return Pattern::matchFailure();
if (!next_op || !llvm::isa<TF::BatchToSpaceNDOp>(next_op)) return failure();
bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
} else if (llvm::isa<TF::BatchToSpaceNDOp>(next_op)) {
// BatchToSpaceND + (optional) BiasAdd.
@ -172,12 +171,12 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
final_op_is_bts = false;
}
} else {
return Pattern::matchFailure();
return failure();
}

llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
stb_op.block_shape(), bts_op.block_shape(), rewriter);
if (!dilations_attr.hasValue()) return Pattern::matchFailure();
if (!dilations_attr.hasValue()) return failure();
op.setAttr("dilations", dilations_attr.getValue());

// Padding is set to 'SAME' when `stb_op` has non-zero paddings.
@ -228,7 +227,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
}

stb_op.getResult().dropAllUses();
return Pattern::matchSuccess();
return success();
}

template <typename Conv2dOpTy>
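The class template above is instantiated once per convolution kind. A sketch of how such patterns are typically registered in a pass of this vintage; the pass name is a placeholder, and the applyPatternsGreedily entry point is assumed from the surrounding MLIR API rather than shown in this diff:

// Hypothetical registration; only ConvertTFDilatedConvOp comes from this
// change, the enclosing pass is illustrative.
void ExampleDilatedConvPass::runOnFunction() {
  OwningRewritePatternList patterns;
  patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
                  ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(
      &getContext());
  applyPatternsGreedily(getFunction(), patterns);
}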
@ -21,26 +21,26 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Block.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/Functional.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"

@ -15,23 +15,23 @@ limitations under the License.

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Block.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {

@ -28,17 +28,17 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/Functional.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -98,12 +98,12 @@ bool HasSameStaticShapes(Operation* op) {

#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"

#define DECL_CONVERT_OP(tf_op) \
struct ConvertTF##tf_op##Op : public RewritePattern { \
explicit ConvertTF##tf_op##Op(MLIRContext* context) \
: RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \
PatternMatchResult matchAndRewrite( \
Operation* op, PatternRewriter& rewriter) const override; \
#define DECL_CONVERT_OP(tf_op) \
struct ConvertTF##tf_op##Op : public RewritePattern { \
explicit ConvertTF##tf_op##Op(MLIRContext* context) \
: RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \
LogicalResult matchAndRewrite(Operation* op, \
PatternRewriter& rewriter) const override; \
}

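For reference, DECL_CONVERT_OP(Concat) now expands to the following declaration; this is a mechanical preprocessor expansion of the new macro body above, nothing beyond it:

struct ConvertTFConcatOp : public RewritePattern {
  explicit ConvertTFConcatOp(MLIRContext* context)
      : RewritePattern(TF::ConcatOp::getOperationName(), 1, context) {}
  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override;
};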
// TODO(antiagainst): Define this pattern in a table-driven manner once variadic
@ -127,14 +127,14 @@ DECL_CONVERT_OP(BroadcastTo);

#undef DECL_CONVERT_OP

PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite(
LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto random_uniform_op = cast<TF::RandomUniformOp>(op);
if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) {
return matchFailure();
return failure();
}
if (!random_uniform_op.dtype().isF32()) {
return matchFailure();
return failure();
}
typedef tensorflow::random::UniformDistribution<
tensorflow::random::PhiloxRandom, float>
@ -149,7 +149,7 @@ PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite(
random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) {
if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) {
if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) {
return matchFailure();
return failure();
}
num_elements = output_type.getNumElements();
size_t offset = 0;
@ -165,13 +165,13 @@ PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite(
}
auto output_data = DenseFPElementsAttr::get(output_type, data);
rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data);
return matchSuccess();
return success();
}
}
return matchFailure();
return failure();
}

PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
LogicalResult ConvertTFConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatOp>(op);

@ -180,17 +180,17 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
// Extract axis attribute from constant concat_dims tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
return matchFailure();
return failure();

StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis),
fused_activation_function);
return matchSuccess();
return success();
}

PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatV2Op>(op);

@ -198,15 +198,14 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
auto output_type = tf_concat_op.output().getType();
// Extract axis attribute from constant axis tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis)))
return matchFailure();
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();

StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<ConcatenationOp>(
op, output_type, values, ExtractSingleElementAsInteger(axis),
fused_activation_function);
return matchSuccess();
return success();
}

// The following is effectively:
@ -215,11 +214,11 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
//          ConstBoolAttrTrue:$transpose_b),
//   (TFL_FullyConnectedOp:$__0 $a, $b,
//     NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>;
PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
LogicalResult ConvertTFMatMulOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_matmul_op = cast<TF::MatMulOp>(op);
if (tf_matmul_op.transpose_a()) return matchFailure();
if (!tf_matmul_op.transpose_b()) return matchFailure();
if (tf_matmul_op.transpose_a()) return failure();
if (!tf_matmul_op.transpose_b()) return failure();

Type output_type = tf_matmul_op.getResult().getType();
// TODO(jpienaar): Follow up post shuffle discussion.
@ -230,10 +229,10 @@ PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
op->getOperand(1), no_input, rewriter.getStringAttr("NONE"),
rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false));
rewriter.replaceOp(op, {fc_op.getResult(0)});
return matchSuccess();
return success();
}

PatternMatchResult ConvertTFPackOp::matchAndRewrite(
LogicalResult ConvertTFPackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_pack_op = cast<TF::PackOp>(op);

@ -245,10 +244,10 @@ PatternMatchResult ConvertTFPackOp::matchAndRewrite(

rewriter.replaceOpWithNewOp<PackOp>(op, output_type, values, values_count,
axis);
return matchSuccess();
return success();
}

PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
LogicalResult ConvertTFReshapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_reshape_op = cast<TF::ReshapeOp>(op);

@ -269,10 +268,10 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
}
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
input, shape);
return matchSuccess();
return success();
}

PatternMatchResult ConvertTFSplitOp::matchAndRewrite(
LogicalResult ConvertTFSplitOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_split_op = cast<TF::SplitOp>(op);

@ -284,10 +283,10 @@ PatternMatchResult ConvertTFSplitOp::matchAndRewrite(
rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, output_types,
tf_split_op.split_dim(),
tf_split_op.value(), num_split);
return matchSuccess();
return success();
}

PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
LogicalResult ConvertTFSplitVOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_splitv_op = cast<TF::SplitVOp>(op);

@ -299,7 +298,7 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
rewriter.replaceOpWithNewOp<TFL::SplitVOp>(
op, output_types, tf_splitv_op.value(), tf_splitv_op.size_splits(),
tf_splitv_op.split_dim(), num_split);
return matchSuccess();
return success();
}

Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
@ -330,7 +329,7 @@ Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
return rewriter.create<ConstantOp>(op->getLoc(), type, attr);
}

PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
auto ranked_input_type =
@ -352,7 +351,7 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
tf_strided_slice_op.new_axis_mask().getSExtValue()),
rewriter.getI32IntegerAttr(
tf_strided_slice_op.shrink_axis_mask().getSExtValue()));
return matchSuccess();
return success();
}

int num_input_dims = ranked_input_type.getRank();
@ -382,10 +381,10 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
tf_strided_slice_op.new_axis_mask().getSExtValue()),
rewriter.getI32IntegerAttr(
tf_strided_slice_op.shrink_axis_mask().getSExtValue()));
return matchSuccess();
return success();
}

PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
LogicalResult ConvertTFUnpackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_unpack_op = cast<TF::UnpackOp>(op);

@ -397,7 +396,7 @@ PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue());

rewriter.replaceOpWithNewOp<UnpackOp>(op, output_types, input, num, axis);
return matchSuccess();
return success();
}

// MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. This attribute
@ -449,25 +448,25 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
return true;
}

PatternMatchResult ConvertTFMatrixDiagV2Op::matchAndRewrite(
LogicalResult ConvertTFMatrixDiagV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV2Op>(op, &rewriter))
return matchSuccess();
return matchFailure();
return success();
return failure();
}

PatternMatchResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
LogicalResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV3Op>(op, &rewriter))
return matchSuccess();
return matchFailure();
return success();
return failure();
}

// TF Lite doesn't support Assert, we just drop the assert from the graph.
PatternMatchResult ConvertTFAssertOp::matchAndRewrite(
LogicalResult ConvertTFAssertOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
rewriter.eraseOp(op);
return matchSuccess();
return success();
}

StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
@ -545,7 +544,7 @@ StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
}

PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_reciprocal_op = cast<TF::ReciprocalOp>(op);

@ -553,7 +552,7 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
&rewriter, op->getLoc(),
tf_reciprocal_op.x().getType().cast<ShapedType>(), 1);
if (!status_or_const_op.ok()) {
return matchFailure();
return failure();
}

StringAttr fused_activation_function =
@ -562,10 +561,10 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
rewriter.replaceOpWithNewOp<TFL::DivOp>(op, status_or_const_op.ValueOrDie(),
tf_reciprocal_op.x(),
fused_activation_function);
return matchSuccess();
return success();
}

PatternMatchResult ConvertTFBroadcastToOp::matchAndRewrite(
LogicalResult ConvertTFBroadcastToOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
auto element_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
@ -574,7 +573,7 @@ PatternMatchResult ConvertTFBroadcastToOp::matchAndRewrite(
auto status_or_const_op =
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1);
if (!status_or_const_op.ok()) {
return matchFailure();
return failure();
}

auto tfl_fill_op = rewriter.create<TFL::FillOp>(
@ -587,7 +586,7 @@ PatternMatchResult ConvertTFBroadcastToOp::matchAndRewrite(
rewriter.replaceOpWithNewOp<TFL::MulOp>(
op, output_type, tf_broadcast_to_op.input(), tfl_fill_op,
fused_activation_function);
return matchSuccess();
return success();
}

// Legalize unidirectional sequence lstm.
@ -595,11 +594,11 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
: RewritePattern(kUnidirectionalSequenceLstm, 1, context) {}

PatternMatchResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
auto tflite_indices_attr =
op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
if (!tflite_indices_attr) return matchFailure();
if (!tflite_indices_attr) return failure();

SmallVector<int64_t, 20> tflite_indices;
for (auto index_attr : tflite_indices_attr.getValue()) {
@ -654,7 +653,7 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
// Rewire the output.
op->getResult(2).replaceAllUsesWith(lstm_op.getResult());
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};

@ -663,24 +662,24 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context)
: RewritePattern(kUnidirectionalSequenceRnn, 1, context) {}

PatternMatchResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
auto tflite_indices_attr =
op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
if (!tflite_indices_attr) return matchFailure();
if (!tflite_indices_attr) return failure();

if (op->getNumOperands() != 5) {
op->emitError()
<< "We're expecting 5 inputs for UnidirectionalSequenceRNN, only "
<< op->getNumOperands() << " provided";
return matchFailure();
return failure();
}

if (op->getNumResults() != 2) {
op->emitError()
<< "We're expecting 2 inputs for UnidirectionalSequenceRNN, only "
|
||||
<< op->getNumResults() << " found";
return matchFailure();
return failure();
}

// Populate inputs.
@ -714,7 +713,7 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
op->getResult(1).replaceAllUsesWith(rnn_op.getResult());
rewriter.eraseOp(op);

return matchSuccess();
return success();
}
};

@ -15,12 +15,12 @@ limitations under the License.

// Converts TF While to TFL While with single call in body and cond.

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

@ -19,11 +19,11 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

@ -31,28 +31,28 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Block.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project
#include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/Functional.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -175,33 +175,33 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
using OpConversionPattern::OpConversionPattern;

PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
TF::ConstOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Verify that the opaque elements attribute contains a tensor of type
// variant and scalar shape. The variant type should hold a TensorList.
auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
if (!opaque_attr) return matchFailure();
if (!opaque_attr) return failure();
tensorflow::Tensor tensor;
if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
return matchFailure();
if (tensor.dtype() != tensorflow::DT_VARIANT) return matchFailure();
return failure();
if (tensor.dtype() != tensorflow::DT_VARIANT) return failure();
if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
return matchFailure();
return failure();

const tensorflow::TensorList *list =
tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
if (!list) return matchFailure();
if (!list) return failure();

// Verify output type is variant and contains exactly one ranked subtype.
auto variant_ty =
getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
if (!variant_ty) return matchFailure();
if (!variant_ty) return failure();
ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
if (subtypes.size() != 1) return matchFailure();
if (subtypes.size() != 1) return failure();
RankedTensorType list_element_ty =
subtypes.front().dyn_cast<RankedTensorType>();
if (!list_element_ty) return matchFailure();
if (!list_element_ty) return failure();

// Extract tensor elements for the TensorList and construct result type
// based on the number of elements and element shape.
@ -225,9 +225,9 @@ struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
tensorflow::Tensor tensor(list->element_dtype,
tensorflow::TensorShape(tf_shape));
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
if (!attr_or.ok()) return matchFailure();
if (!attr_or.ok()) return failure();
rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
return matchSuccess();
return success();
}

// Extract individual tensor list elements and combine them using the tf.Pack
@ -237,14 +237,14 @@ struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
values.reserve(tensors.size());
for (const tensorflow::Tensor &tensor : tensors) {
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
if (!attr_or.ok()) return matchFailure();
if (!attr_or.ok()) return failure();

auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
values.push_back(value);
}
rewriter.replaceOpWithNewOp<TF::PackOp>(
op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
return matchSuccess();
return success();
}
};

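Unlike the greedy patterns earlier in this change, these OpConversionPattern subclasses run under the dialect-conversion driver. A sketch of how a pass of this era might wire them up; the pass name and the legality setup are assumptions for illustration, while ConvertConst and ConvertTensorListSetItem come from this diff:

// Hypothetical driver; ConversionTarget and applyPartialConversion are
// the standard MLIR dialect-conversion entry points of this period.
void ExampleLowerTensorListPass::runOnFunction() {
  ConversionTarget target(getContext());
  target.addLegalDialect<TFL::TensorFlowLiteDialect>();
  OwningRewritePatternList patterns;
  patterns.insert<ConvertConst, ConvertTensorListSetItem>(&getContext());
  if (failed(applyPartialConversion(getFunction(), target, patterns)))
    signalPassFailure();
}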
@ -264,7 +264,7 @@ struct ConvertTensorListSetItem
|
||||
// (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
|
||||
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
|
||||
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
|
||||
PatternMatchResult matchAndRewrite(
|
||||
LogicalResult matchAndRewrite(
|
||||
TF::TensorListSetItemOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
@ -311,7 +311,7 @@ struct ConvertTensorListSetItem
|
||||
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
|
||||
op, input.getType(), scalar_zero,
|
||||
ArrayRef<Value>({slice1, expanded_item, slice2}));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
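The DRR-style comment above is the whole lowering: the list element at $index is replaced by concatenating the prefix slice, the expanded item, and the suffix slice along axis 0. The same semantics on a plain vector of rows, as a standalone illustration of the decomposition (not the MLIR code itself):

#include <cassert>
#include <vector>

// result = Concat(Slice(input, 0, index), ExpandDims(item),
//                 Slice(input, index + 1, end)), all along axis 0.
using Row = std::vector<float>;

std::vector<Row> SetItem(const std::vector<Row> &input, int index,
                         const Row &item) {
  assert(index >= 0 && index < static_cast<int>(input.size()));
  std::vector<Row> result(input.begin(), input.begin() + index);  // prefix slice
  result.push_back(item);                                         // expanded item
  result.insert(result.end(), input.begin() + index + 1,
                input.end());                                     // suffix slice
  return result;
}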

@@ -330,7 +330,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
  // Rewrites the original op into `tf.fill`. The result tensor shape is
  // [num_element, element_shape]. All the values in the result tensor will be
  // initialized to 0.
-  PatternMatchResult matchAndRewrite(
+  LogicalResult matchAndRewrite(
      OpT op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Type dtype = op.element_dtype();
@@ -342,7 +342,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
          "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
          "integer or 16-bit/32-bit/64-bit float type during TF Lite "
          "transformation pass");
-      return ConversionPattern::matchFailure();
+      return failure();
    }

    Value element_shape = operands[0];
@@ -354,7 +354,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
        op.emitError(
            "requires element_shape to be 1D tensor during TF Lite "
            "transformation pass");
-        return ConversionPattern::matchFailure();
+        return failure();
      }
    }

@@ -434,7 +434,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
    auto zero = rewriter.create<ConstantOp>(loc, zero_type, zero_attr);

    rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
-    return Pattern::matchSuccess();
+    return success();
  }
};
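Concretely, the tf.fill that replaces a tensor-list init op handled by this template gets a shape that prepends the element count to the element shape, as the comment above states. A small sketch of just that shape computation:

#include <cstdint>
#include <vector>

// Shape of the tf.Fill produced for a tensor-list init op:
// [num_elements] + element_shape, every value initialized to zero.
std::vector<int64_t> ListFillShape(int64_t num_elements,
                                   const std::vector<int64_t> &element_shape) {
  std::vector<int64_t> shape;
  shape.reserve(element_shape.size() + 1);
  shape.push_back(num_elements);  // leading dimension indexes list items
  shape.insert(shape.end(), element_shape.begin(), element_shape.end());
  return shape;  // e.g. num_elements=4, element_shape={2,3} -> {4,2,3}
}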

@@ -472,7 +472,7 @@ struct ConvertTensorListPushBack
    : public OpConversionPattern<TF::TensorListPushBackOp> {
  using OpConversionPattern::OpConversionPattern;

-  PatternMatchResult matchAndRewrite(
+  LogicalResult matchAndRewrite(
      TF::TensorListPushBackOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input_handle = operands[0];
@@ -498,7 +498,7 @@ struct ConvertTensorListPushBack
    rewriter.replaceOpWithNewOp<TF::ConcatOp>(
        op, result_type, scalar_zero,
        ArrayRef<Value>({input_handle, expanded_item}));
-    return matchSuccess();
+    return success();
  }
};
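PushBack follows the same recipe: since a lowered list is just a tensor stacked along axis 0, appending is Concat(input, ExpandDims(item, 0)), which is exactly the replacement op built above. Emulated standalone:

#include <vector>

// PushBack on the lowered representation: append one row along axis 0.
using Row = std::vector<float>;

std::vector<Row> PushBack(std::vector<Row> input, const Row &item) {
  input.push_back(item);  // ExpandDims + Concat collapse to an append
  return input;
}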

@@ -516,7 +516,7 @@ struct ConvertTensorListResize
    : public OpConversionPattern<TF::TensorListResizeOp> {
  using OpConversionPattern::OpConversionPattern;

-  PatternMatchResult matchAndRewrite(
+  LogicalResult matchAndRewrite(
      TF::TensorListResizeOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input_handle = operands[0];
@@ -582,7 +582,7 @@ struct ConvertTensorListResize
        /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
        /*output_shapes=*/rewriter.getStrArrayAttr({"{}"}),
        /*is_stateless=*/rewriter.getBoolAttr(true));
-    return matchSuccess();
+    return success();
  }

 private:
@@ -660,14 +660,14 @@ struct ConvertTensorListGetItem
    : public OpConversionPattern<TF::TensorListGetItemOp> {
  using OpConversionPattern::OpConversionPattern;

-  PatternMatchResult matchAndRewrite(
+  LogicalResult matchAndRewrite(
      TF::TensorListGetItemOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input = operands[0];
    Value index = operands[1];
    rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
                                              rewriter.getBoolAttr(true));
-    return matchSuccess();
+    return success();
  }
};

@@ -675,7 +675,7 @@ struct ConvertTensorListLength
    : public OpConversionPattern<TF::TensorListLengthOp> {
  using OpConversionPattern::OpConversionPattern;

-  PatternMatchResult matchAndRewrite(
+  LogicalResult matchAndRewrite(
      TF::TensorListLengthOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
@@ -687,7 +687,7 @@ struct ConvertTensorListLength
    rewriter.replaceOpWithNewOp<TF::GatherOp>(
        op, op.getType(), shape, CreateI32SplatConst(loc, &rewriter, {}, 0),
        /*validate_indices=*/true_attr);
-    return matchSuccess();
+    return success();
  }
};

@@ -695,7 +695,7 @@ struct ConvertTensorListStack
    : public OpConversionPattern<TF::TensorListStackOp> {
  using OpConversionPattern::OpConversionPattern;

-  PatternMatchResult matchAndRewrite(
+  LogicalResult matchAndRewrite(
      TF::TensorListStackOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
@@ -713,7 +713,7 @@ struct ConvertTensorListStack
        !matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
      // If no constant is spotted, just forward the operand.
      rewriter.replaceOp(op, {input});
-      return matchSuccess();
+      return success();
    }

    RankedTensorType shape_type =
@@ -726,20 +726,20 @@ struct ConvertTensorListStack
        RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
    rewriter.replaceOpWithNewOp<TF::ReshapeOp>(op, result_type, input,
                                               new_shape);
-    return matchSuccess();
+    return success();
  }
};

struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
  using OpConversionPattern::OpConversionPattern;

-  PatternMatchResult matchAndRewrite(
+  LogicalResult matchAndRewrite(
      TF::IdentityOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input = operands[0];
    rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
                                                op.getAttrs());
-    return matchSuccess();
+    return success();
  }
};

@@ -804,7 +804,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
  using OpConversionPattern::OpConversionPattern;

-  PatternMatchResult matchAndRewrite(
+  LogicalResult matchAndRewrite(
      TF::WhileOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type, 8> result_types;
@@ -828,7 +828,7 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
    UpdateFunctionTypes(cloned);

    rewriter.replaceOp(op, cloned.getResults());
-    return matchSuccess();
+    return success();
  }
};
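The hunk header above shows the retyping helper itself became `static LogicalResult UpdateFunctionTypes(TF::WhileOp op)`; its body is elided from this diff. The idea is that once the while op's variant-list operands are lowered to plain tensors, the cond/body functions it references must be re-typed to match. A rough sketch of that kind of retyping, assuming MLIR APIs of roughly this vintage (RetypeBranch is a hypothetical helper, not code from this commit):

#include "llvm/ADT/STLExtras.h"  // llvm::enumerate
#include "mlir/IR/Function.h"    // from @llvm-project

// Hypothetical sketch: give `func` new input types after its caller's operand
// types changed, keeping the declared type and entry-block arguments in sync.
static void RetypeBranch(mlir::FuncOp func,
                         llvm::ArrayRef<mlir::Type> new_inputs) {
  func.setType(mlir::FunctionType::get(
      new_inputs, func.getType().getResults(), func.getContext()));
  mlir::Block &entry = func.getBody().front();
  for (auto it : llvm::enumerate(new_inputs))
    entry.getArgument(it.index()).setType(it.value());
}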

@@ -31,14 +31,14 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
-#include "mlir/IR/Attributes.h"  // TF:llvm-project
-#include "mlir/IR/Matchers.h"  // TF:llvm-project
-#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
-#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
-#include "mlir/Pass/Pass.h"  // TF:llvm-project
-#include "mlir/Support/Functional.h"  // TF:llvm-project
-#include "mlir/Support/LLVM.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/Functional.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@@ -211,18 +211,17 @@ DenseElementsAttr GetShape(Value output_val) {
struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
  using OpRewritePattern<TFL::AddOp>::OpRewritePattern;

-  PatternMatchResult matchAndRewrite(TFL::AddOp add_op,
-                                     PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(TFL::AddOp add_op,
+                                PatternRewriter &rewriter) const override {
    // Match Add.
    DenseElementsAttr added_value;
    Value constant_val = add_op.rhs();
-    if (!matchPattern(constant_val, m_Constant(&added_value)))
-      return matchFailure();
+    if (!matchPattern(constant_val, m_Constant(&added_value))) return failure();

    // Match Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
-    if (!fc_op) return matchFailure();
+    if (!fc_op) return failure();

    // Check if the constant RHS is either 0D (scalar), or a 1D with
    // `{num_channels}` shape.
@@ -236,17 +235,17 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
    if (constant_val_type.getRank() == 0) {
      is_scalar_rhs = true;
    } else if (constant_val_type.getRank() != 1) {
-      return matchFailure();
+      return failure();
    }

    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr bias_value;
    const bool is_none_bias = bias.getType().isa<NoneType>();
-    if (fc_op.fused_activation_function() != "NONE") return matchFailure();
+    if (fc_op.fused_activation_function() != "NONE") return failure();

    if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
-      return matchFailure();
+      return failure();

    // Rewrite
    Location loc = fc_op.getLoc();
@@ -261,7 +260,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
    // Filter must be a `2D` tensor with `{num_channels, num_features}`
    // shape. The following check is rejecting unknown rank (-1).
    if (filter_type.getRank() != 2) {
-      return matchFailure();
+      return failure();
    }
    int num_channels = filter_type.getShape()[0];

@@ -297,7 +296,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));

-    return matchSuccess();
+    return success();
  }
};
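The fusion is sound because FullyConnected applies its bias per output channel after the matmul, so a trailing per-channel (or scalar) Add can be folded straight into the bias vector that the rebuilt op receives. The arithmetic, as a standalone sketch:

#include <vector>

// Folding y = FC(x; W, b) + a  into  y = FC(x; W, b + a):
// the add is applied per output channel after the matmul.
std::vector<float> FoldAddIntoBias(const std::vector<float> &bias,
                                   const std::vector<float> &added,
                                   bool is_scalar_rhs) {
  std::vector<float> new_bias(bias.size());
  for (size_t c = 0; c < bias.size(); ++c)
    new_bias[c] = bias[c] + (is_scalar_rhs ? added[0] : added[c]);
  return new_bias;  // e.g. b={1,2}, scalar a={10} -> {11,12}
}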

@@ -305,13 +304,13 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
  using OpRewritePattern<TFL::ReluOp>::OpRewritePattern;

-  PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op,
-                                     PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(TFL::ReluOp relu_op,
+                                PatternRewriter &rewriter) const override {
    Operation *input = relu_op.getOperand().getDefiningOp();
-    if (!isa_and_nonnull<FullyConnectedOp>(input)) return matchFailure();
+    if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
    auto fully_connected_op = cast<FullyConnectedOp>(input);
    if (fully_connected_op.fused_activation_function() != "NONE")
-      return matchFailure();
+      return failure();

    auto new_activation_func = rewriter.getStringAttr("RELU");
    auto new_weights_format =
@@ -323,7 +322,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
        fully_connected_op.filter(), fully_connected_op.bias(),
        new_activation_func, new_weights_format, new_keep_num_dims);

-    return matchSuccess();
+    return success();
  }
};

@@ -332,25 +331,25 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
  using OpRewritePattern<TFL::MulOp>::OpRewritePattern;

-  PatternMatchResult matchAndRewrite(TFL::MulOp mul_op,
-                                     PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(TFL::MulOp mul_op,
+                                PatternRewriter &rewriter) const override {
    // Mul.
    DenseElementsAttr cst;
    Value constant_val = mul_op.rhs();
-    if (!matchPattern(constant_val, m_Constant(&cst))) return matchFailure();
+    if (!matchPattern(constant_val, m_Constant(&cst))) return failure();

    // Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
-    if (!fc_op) return matchFailure();
+    if (!fc_op) return failure();
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr cst_tmp;
-    if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
+    if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&cst_tmp)))
-      return matchFailure();
-    if (fc_op.fused_activation_function() != "NONE") return matchFailure();
+      return failure();
+    if (fc_op.fused_activation_function() != "NONE") return failure();

    // Broadcast the constant operand of Mul if it isn't compatible with the
    // filter input. We only support broadcasting the operand along the depth
@@ -365,7 +364,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
        normalized_shape, cst.getType().getElementType()));
    Type new_type = new_cst.getType();
    if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
-      return matchFailure();
+      return failure();
    }
    auto new_op =
        rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
@@ -393,7 +392,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));

-    return matchSuccess();
+    return success();
  }
};
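Same idea for Mul: scaling the FC output channel-wise is equivalent to scaling that channel's filter row and bias, which is why the pass broadcasts the constant along the depth dimension before rebuilding the op. The arithmetic, sketched standalone:

#include <vector>

// Folding y = FC(x; W, b) * g  into  y = FC(x; W', b') with
// W'[c][j] = W[c][j] * g[c] and b'[c] = b[c] * g[c].
void FoldMulIntoFilter(std::vector<std::vector<float>> &filter,
                       std::vector<float> &bias,
                       const std::vector<float> &gamma) {
  for (size_t c = 0; c < filter.size(); ++c) {
    for (float &w : filter[c]) w *= gamma[c];
    bias[c] *= gamma[c];
  }
}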

@@ -425,36 +424,36 @@ template <typename AffineOpType>
struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
  using OpRewritePattern<TFL::MulOp>::OpRewritePattern;

-  PatternMatchResult matchAndRewrite(TFL::MulOp mul_op,
-                                     PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(TFL::MulOp mul_op,
+                                PatternRewriter &rewriter) const override {
    // Mul. A 1-D rhs is required for batch normalization.
    DenseElementsAttr gamma_cst;
    Value gamma = mul_op.rhs();
-    if (!matchPattern(gamma, m_Constant(&gamma_cst))) return matchFailure();
-    if (gamma_cst.getType().getRank() != 1) return matchFailure();
+    if (!matchPattern(gamma, m_Constant(&gamma_cst))) return failure();
+    if (gamma_cst.getType().getRank() != 1) return failure();

    // Affine op
    Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
    auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
-    if (!fc_op) return matchFailure();
+    if (!fc_op) return failure();
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();

    // QDQs
    auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
-    if (!dq_op) return matchFailure();
+    if (!dq_op) return failure();
    auto q_op =
        dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
-    if (!q_op) return matchFailure();
+    if (!q_op) return failure();
    filter = q_op.input();

    // weight constant
    ElementsAttr cst_tmp;
-    if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
+    if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&cst_tmp)))
-      return matchFailure();
-    if (fc_op.fused_activation_function() != "NONE") return matchFailure();
+      return failure();
+    if (fc_op.fused_activation_function() != "NONE") return failure();

    // Broadcast the constant operand of Mul if it isn't compatible with the
    // filter input. We only support broadcasting the operand along the depth
@@ -469,7 +468,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
      auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
      broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
    } else {
-      return matchFailure();
+      return failure();
    }

    // Rewrite filter constant. Since the folder of TFL::MulOp couldn't
@@ -478,7 +477,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
        rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
    // Update the scale in the quantize op.
    auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
-    if (!new_qtype) return matchFailure();
+    if (!new_qtype) return failure();
    rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
                                                 new_filter, new_qtype);

@@ -491,7 +490,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {

    // Remove the trailing mul op.
    mul_op.replaceAllUsesWith(fc_op.getResult());
-    return matchSuccess();
+    return success();
  }
};
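When the filter sits behind a Quantize/Dequantize pair, multiplying the stored weights by gamma also multiplies the real-valued range they encode, so the quantized type's scales must be updated as well; that is what RescaleQtype does to q_op.qtype(). A sketch of just the scale arithmetic, assuming per-channel uniform quantization (an illustration, not the body of RescaleQtype):

#include <cmath>
#include <vector>

// real_value = scale * (q - zero_point), so scaling real values by g scales
// `scale` by |g| (any sign change can be absorbed into the stored integers).
std::vector<double> RescaleScales(const std::vector<double> &scales,
                                  const std::vector<float> &gamma) {
  std::vector<double> rescaled(scales.size());
  for (size_t c = 0; c < scales.size(); ++c)
    rescaled[c] = scales[c] * std::fabs(static_cast<double>(gamma[c]));
  return rescaled;
}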

@@ -504,20 +503,19 @@ template <typename AffineOpType>
struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
  using OpRewritePattern<AffineOpType>::OpRewritePattern;

-  PatternMatchResult matchAndRewrite(AffineOpType fc_op,
-                                     PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(AffineOpType fc_op,
+                                PatternRewriter &rewriter) const override {
    // Binary op.
    Operation *binary_op = fc_op.input().getDefiningOp();
-    if (!binary_op || binary_op->getNumOperands() != 2)
-      return this->matchFailure();
+    if (!binary_op || binary_op->getNumOperands() != 2) return failure();
    // We only handle the cases where the RHS is a scalar.
    // TODO(fengliuai): Currently the canonicalizer pass couldn't guarantee that
    // the constant operands are on the RHS; we may need to consider the LHS
    // constant operand if necessary.
    DenseFPElementsAttr cst;
    if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst)))
-      return this->matchFailure();
-    if (cst.getNumElements() != 1) return this->matchFailure();
+      return failure();
+    if (cst.getNumElements() != 1) return failure();
    APFloat cst_value = *cst.float_value_begin();

    // Affine op.
@@ -527,21 +525,21 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
    if (!matchPattern(filter, m_Constant(&filter_cst))) {
      // The filter may be quantized; in that case, fetch the real constant
      // behind the quantize/dequantize pair.
      auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
-      if (!dq) return this->matchFailure();
+      if (!dq) return failure();
      auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
      if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
-        return this->matchFailure();
+        return failure();
      }
      filter = q.input();
    }
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&bias_cst)))
-      return this->matchFailure();
+      return failure();
    ShapedType filter_type = filter_cst.getType();

    if (llvm::isa<AddOp>(binary_op) || llvm::isa<SubOp>(binary_op)) {
      auto padding = fc_op.template getAttrOfType<StringAttr>("padding");
-      if (padding && padding.getValue() != "VALID") return this->matchFailure();
+      if (padding && padding.getValue() != "VALID") return failure();

      // The fusion of add/sub is actually applying the following
      // transformation:
@@ -568,7 +566,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
                                     bias_cst.float_value_begin(),
                                     bias_cst.float_value_end());
    } else {
-      return this->matchFailure();
+      return failure();
    }

    int64_t flatten_index = 0;
@@ -610,9 +608,9 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
        fc_op.setOperand(1, new_filter_op);
      }
    } else {
-      return this->matchFailure();
+      return failure();
    }
-    return this->matchSuccess();
+    return success();
  }
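The transformation named in the comment above folds a preceding scalar add/sub into the affine op's bias: for constant weights, (x + c) conv/matmul W + b equals x conv/matmul W + (b + c * rowsum(W)), which is what the flatten-index loop accumulates over the filter. A worked sketch for a 2-D fully-connected filter (for Sub, negate c; this is the math, not the pass's exact code):

#include <vector>

// Bias update for fusing (x + c) into a following affine op:
//   (x + c) @ W^T + b  ==  x @ W^T + (b + c * rowsum(W)).
// Assumes a {num_outputs, num_inputs} filter.
std::vector<float> BiasAfterFusingPrecedingAdd(
    const std::vector<std::vector<float>> &filter,
    const std::vector<float> &bias, float c) {
  std::vector<float> new_bias(bias);
  for (size_t o = 0; o < filter.size(); ++o) {
    float row_sum = 0.f;
    for (float w : filter[o]) row_sum += w;
    new_bias[o] += c * row_sum;
  }
  return new_bias;
}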

 private:
@@ -638,18 +636,17 @@ struct ConvertTrivialTransposeOpToReshapeOp
    : public OpRewritePattern<TFL::TransposeOp> {
  using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;

-  PatternMatchResult matchAndRewrite(TFL::TransposeOp transpose_op,
-                                     PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op,
+                                PatternRewriter &rewriter) const override {
    auto input_type = transpose_op.x().getType().cast<ShapedType>();
    auto output_type = transpose_op.y().getType().cast<ShapedType>();
    // It's possible to know whether the transformation is safe only if the
    // input & output shapes are fully known and the permutation is a constant.
    if (!input_type.hasStaticShape() || !output_type.hasStaticShape())
-      return matchFailure();
+      return failure();
    Value perm = transpose_op.perm();
    DenseElementsAttr perm_values_attr;
-    if (!matchPattern(perm, m_Constant(&perm_values_attr)))
-      return matchFailure();
+    if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure();

    auto input_shape = input_type.getShape();
    SmallVector<int64_t, 8> perm_values;
@@ -674,7 +671,7 @@ struct ConvertTrivialTransposeOpToReshapeOp
      }
    }
    if (old_major_index_ordering != new_major_index_ordering) {
-      return matchFailure();
+      return failure();
    }

    // Rewrite.
@@ -693,7 +690,7 @@ struct ConvertTrivialTransposeOpToReshapeOp
    rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(
        transpose_op, transpose_op.y().getType(), transpose_op.x(), new_shape);

-    return matchSuccess();
+    return success();
  }
};
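The ordering check above encodes when a transpose is reshape-equivalent: if only unit dimensions move, the relative order of the non-unit dimensions is preserved and no data movement is needed. A standalone sketch of that check, with a worked example:

#include <cstdint>
#include <vector>

// A transpose is reshape-equivalent iff the non-unit dims keep their relative
// order. E.g. shape {1,8,1,16} with perm {0,2,1,3}: the non-unit dims stay in
// the order (8, 16), so only a unit dim moves and the op is a pure reshape.
bool IsTrivialTranspose(const std::vector<int64_t> &input_shape,
                        const std::vector<int64_t> &perm) {
  std::vector<int64_t> old_order, new_order;
  for (size_t d = 0; d < input_shape.size(); ++d)
    if (input_shape[d] != 1) old_order.push_back(d);
  for (int64_t p : perm)
    if (input_shape[p] != 1) new_order.push_back(p);
  return old_order == new_order;
}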

@@ -17,15 +17,15 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
-#include "mlir/IR/Attributes.h"  // TF:llvm-project
-#include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
-#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
-#include "mlir/IR/Module.h"  // TF:llvm-project
-#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
-#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
-#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
-#include "mlir/Pass/Pass.h"  // TF:llvm-project
-#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
@@ -75,14 +75,14 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
  explicit FoldIfOp(MLIRContext* context, FuncSet* inlined_funcs)
      : OpRewritePattern<TF::IfOp>(context), inlined_funcs_(inlined_funcs) {}

-  PatternMatchResult matchAndRewrite(TF::IfOp op,
-                                     PatternRewriter& rewriter) const override {
+  LogicalResult matchAndRewrite(TF::IfOp op,
+                                PatternRewriter& rewriter) const override {
    // This pattern is restricted to if ops in functions with exactly one block
    // and therefore one terminator op, so that the function return type can be
    // updated if operand shapes change after inlining. Without this
    // restriction, it would require tensor cast ops.
    FuncOp parent_op = op.getParentOfType<FuncOp>();
-    if (parent_op.getBlocks().size() != 1) return matchFailure();
+    if (parent_op.getBlocks().size() != 1) return failure();

    // Find the then and else branch functions.
    SymbolTable table(op.getParentOfType<ModuleOp>());
@@ -98,18 +98,18 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
      inlined_funcs_->insert(then_branch);
      inlined_funcs_->insert(else_branch);
      rewriter.eraseOp(op.getOperation());
-      return matchSuccess();
+      return success();
    }

    // Extract the constant cond value.
    DenseElementsAttr cond;
-    if (!matchPattern(op.cond(), m_Constant(&cond))) return matchFailure();
+    if (!matchPattern(op.cond(), m_Constant(&cond))) return failure();

    // TODO(hinsu): Handle constants that are not scalar booleans.
    auto cond_type = cond.getType().dyn_cast<RankedTensorType>();
    if (!cond_type || !cond_type.getShape().equals({}) ||
        !cond_type.getElementType().isInteger(/*width=*/1))
-      return matchFailure();
+      return failure();

    // Identify the branch to inline.
    bool cond_value = (*cond.int_value_begin()).getSExtValue();
@@ -118,7 +118,7 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
    // Make sure that the function has exactly one block to simplify inlining.
    // TFLite doesn't use control flow with blocks, so functions with more than
    // one block are not encountered in practice.
-    if (func.getBody().getBlocks().size() != 1) return matchFailure();
+    if (func.getBody().getBlocks().size() != 1) return failure();

    BlockAndValueMapping mapper;
    for (int i = 0, e = func.getNumArguments(); i != e; ++i)
@@ -149,7 +149,7 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
    // of the function.
    inlined_funcs_->insert(then_branch);
    inlined_funcs_->insert(else_branch);
-    return matchSuccess();
+    return success();
  }

 private:
@@ -16,8 +16,8 @@ limitations under the License.
// This transformation pass applies some clean up steps after quantization.

#include "llvm/Support/Casting.h"
-#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
-#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

@@ -23,21 +23,21 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
-#include "mlir/IR/Attributes.h"  // TF:llvm-project
-#include "mlir/IR/Builders.h"  // TF:llvm-project
-#include "mlir/IR/Function.h"  // TF:llvm-project
-#include "mlir/IR/Identifier.h"  // TF:llvm-project
-#include "mlir/IR/Location.h"  // TF:llvm-project
-#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
-#include "mlir/IR/Module.h"  // TF:llvm-project
-#include "mlir/IR/Operation.h"  // TF:llvm-project
-#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
-#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
-#include "mlir/Interfaces/CallInterfaces.h"  // TF:llvm-project
-#include "mlir/Pass/Pass.h"  // TF:llvm-project
-#include "mlir/Support/LLVM.h"  // TF:llvm-project
-#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/Identifier.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"

@@ -22,11 +22,11 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
-#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
-#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
-#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
-#include "mlir/IR/Value.h"  // TF:llvm-project
-#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"

@@ -38,17 +38,17 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
-#include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
-#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
-#include "mlir/Dialect/QuantOps/UniformSupport.h"  // TF:llvm-project
-#include "mlir/IR/Attributes.h"  // TF:llvm-project
-#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
-#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
-#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
-#include "mlir/Pass/Pass.h"  // TF:llvm-project
-#include "mlir/Support/Functional.h"  // TF:llvm-project
-#include "mlir/Support/LLVM.h"  // TF:llvm-project
-#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
+#include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/Functional.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
@@ -121,12 +121,12 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
                                     MLIRContext *ctx)
      : OpRewritePattern<TFFakeQuantOp>(ctx) {}

-  PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
-                                     PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
+                                PatternRewriter &rewriter) const override {
    // We don't want to insert quantize/dequantize ops if a quantize op already
    // exists.
    auto res = tf_op.outputs();
    if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin()))
-      return this->matchFailure();
+      return failure();

    // Extract the min/max constant values from the operands. We also consider
    // a special case that there are tf.Identity ops between the min/max
@@ -137,8 +137,8 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
      min = id1.input();
    if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
      max = id2.input();
-    if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
-    if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
+    if (!matchPattern(min, m_Constant(&min_value))) return failure();
+    if (!matchPattern(max, m_Constant(&max_value))) return failure();

    int quant_dim = -1;
    if (PerAxis) {
@@ -155,7 +155,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
    TypeAttr qtype = quant::GetQuantizedTypeAttr(
        rewriter, res_type, min_value, max_value, quant_dim, num_bits,
        narrow_range, /*is_signed=*/false);
-    if (!qtype) this->matchFailure();
+    if (!qtype) return failure();

    // Finally, use the quantization parameter to create the quantize and
    // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
@@ -168,7 +168,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
    value.replaceAllUsesWith(dequantize);
    quantize.getOperation()->replaceUsesOfWith(dequantize, value);

-    return this->matchSuccess();
+    return success();
  }
};
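GetQuantizedTypeAttr turns the fake-quant min/max into a uniform quantized type. For the unsigned, non-narrow-range case used here, the core arithmetic is scale = (max - min) / (2^bits - 1) and zero_point = round(-min / scale). A sketch of that arithmetic (the real quantizer additionally nudges the range and handles per-axis and narrow-range cases):

#include <cmath>
#include <cstdint>

struct QuantParams {
  double scale;
  int64_t zero_point;
};

// Asymmetric uniform quantization over [min, max] with `bits` bits, unsigned,
// non-narrow range: real = scale * (q - zero_point).
QuantParams ComputeQuantParams(double min, double max, int bits) {
  const double qmax = (1u << bits) - 1;  // e.g. 255 for 8 bits
  QuantParams p;
  p.scale = (max - min) / qmax;
  p.zero_point = static_cast<int64_t>(std::lround(-min / p.scale));
  return p;  // min=-1, max=1, bits=8 -> scale ~0.00784, zero_point=128
}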

@@ -208,8 +208,8 @@ struct ConvertTFConvOp : public RewritePattern {
      : RewritePattern(TFConvOpType::getOperationName(), 1, context),
        intAttrOne(Builder(context).getI32IntegerAttr(1)) {}

-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
    // Assumes the TensorFlow convolution op has already been verified to be
    // in valid form.

@@ -223,10 +223,10 @@ struct ConvertTFConvOp : public RewritePattern {
    TFConvOpType tf_op = cast<TFConvOpType>(op);

    if (!TFTypeIsFloatTensor(tf_op.input()) || !TFDataFormatIsNHWC(op))
-      return matchFailure();
+      return failure();

    IntegerAttr height, width;
-    if (!TFIntListIs1XY1(op, "strides", &height, &width)) return matchFailure();
+    if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure();

    ConvertTFConvOpMatchState state;
    state.stride_height = height;
@@ -242,14 +242,14 @@ struct ConvertTFConvOp : public RewritePattern {
      state.dilation_width_factor = intAttrOne;
    }

-    if (!TFPaddingIsSameOrValid(op, &state.padding)) return matchFailure();
+    if (!TFPaddingIsSameOrValid(op, &state.padding)) return failure();

    // Additionally, we require the filter operand to be of 4-D tensor type so
    // that we can extract info from the shape (e.g., for constructing bias
    // tensor, for setting depth_multiplier attribute, etc.).
    auto filter = tf_op.filter();
    auto filter_type = filter.getType().template dyn_cast<RankedTensorType>();
-    if (!filter_type || filter_type.getRank() != 4) return matchFailure();
+    if (!filter_type || filter_type.getRank() != 4) return failure();

    // TensorFlow convolution op only has two inputs, while the TFLite one has
    // three, with the bias vector marked as optional. However, TOCO has a
@@ -274,7 +274,7 @@ struct ConvertTFConvOp : public RewritePattern {
                                              bias);

    rewriter.replaceOp(op, conv_op.getResult());
-    return matchSuccess();
+    return success();
  }

  const IntegerAttr intAttrOne;
@@ -418,8 +418,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
  explicit ConvertTFStridedSlice(MLIRContext *context)
      : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}

-  PatternMatchResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
-                                        PatternRewriter &rewriter) const {
+  LogicalResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
+                                   PatternRewriter &rewriter) const {
    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);

    // Insert a new reshape op.
@@ -474,11 +474,11 @@ struct ConvertTFStridedSlice : public RewritePattern {
        rewriter.getI64IntegerAttr(0),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice_op.shrink_axis_mask()));
-    return matchSuccess();
+    return success();
  }

-  PatternMatchResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
-                                         PatternRewriter &rewriter) const {
+  LogicalResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
+                                    PatternRewriter &rewriter) const {
    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);

    DenseIntElementsAttr begin_dense_elem_attr;
@@ -486,7 +486,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
    auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
    if (!begin_ranked_attr_type ||
        !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
-      return matchFailure();
+      return failure();
    }

    DenseIntElementsAttr end_dense_elem_attr;
@@ -494,7 +494,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
    auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
    if (!end_ranked_attr_type ||
        !matchPattern(end, m_Constant(&end_dense_elem_attr))) {
-      return matchFailure();
+      return failure();
    }

    DenseIntElementsAttr stride_dense_elem_attr;
@@ -503,7 +503,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
        stride.getType().dyn_cast<RankedTensorType>();
    if (!stride_ranked_attr_type ||
        !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
-      return matchFailure();
+      return failure();
    }

    Value input = strided_slice_op.input();
@@ -516,7 +516,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
    const ArrayRef<int64_t> begin_shape = begin_type.getShape();
    const int begin_dim = begin_shape.size();

-    if (begin_dim != 1) return matchFailure();
+    if (begin_dim != 1) return failure();

    const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;

@@ -586,11 +586,11 @@ struct ConvertTFStridedSlice : public RewritePattern {
                                strided_slice_op.new_axis_mask()),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice_op.shrink_axis_mask()));
-    return matchSuccess();
+    return success();
  }

-  PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
    // TODO(renjieliu): Consider expanding the transformation for the shrink
    // mask as well.
    TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
@@ -606,7 +606,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
    if (ellipsis_mask != 0) {
      return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter);
    }
-    return matchFailure();
+    return failure();
  }
};
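RewriteEllipsisMask materializes the dimensions the ellipsis stands for: with a rank-input_size input and a length-begin_shape[0] slice spec, the ellipsis covers input_size - begin_shape[0] + 1 dimensions, each becoming a full-range entry. A sketch of the begin-vector expansion (ExpandEllipsis is illustrative, not a helper from this commit):

#include <cstdint>
#include <vector>

// E.g. input rank 4 and begin = {1, 0} with the ellipsis at position 1:
// it covers 4 - 2 + 1 = 3 dims, so begin expands to {1, 0, 0, 0}
// (0 is the full-range start index for the covered dims).
std::vector<int64_t> ExpandEllipsis(const std::vector<int64_t> &begin,
                                    int ellipsis_pos, int input_rank) {
  const int filled = input_rank - static_cast<int>(begin.size()) + 1;
  std::vector<int64_t> expanded(begin.begin(), begin.begin() + ellipsis_pos);
  expanded.insert(expanded.end(), filled, 0);
  expanded.insert(expanded.end(), begin.begin() + ellipsis_pos + 1,
                  begin.end());
  return expanded;
}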

@@ -19,17 +19,17 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
-#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
-#include "mlir/IR/Attributes.h"  // TF:llvm-project
-#include "mlir/IR/Builders.h"  // TF:llvm-project
-#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
-#include "mlir/IR/Matchers.h"  // TF:llvm-project
-#include "mlir/IR/Module.h"  // TF:llvm-project
-#include "mlir/IR/Operation.h"  // TF:llvm-project
-#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
-#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
-#include "mlir/Pass/Pass.h"  // TF:llvm-project
-#include "mlir/Support/Functional.h"  // TF:llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/Functional.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

-#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
-#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
Some files were not shown because too many files have changed in this diff.