Merge branch 'master' into interface_16x8

Elena Zhelezina, 2020-03-24 10:38:37 +00:00, committed by GitHub
commit fe806513ad
1245 changed files with 68478 additions and 21106 deletions


@@ -356,7 +356,15 @@ build:rbe_cpu_linux --extra_execution_platforms="@org_tensorflow//third_party/to
build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --config=rbe_linux
build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
@@ -365,13 +373,20 @@ build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda1
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_CUDA=1
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_clang --define=using_cuda_clang=true
test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc


@@ -138,3 +138,7 @@ load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
bazel_version_repository(name = "bazel_version")
load("//third_party/googleapis:repository_rules.bzl", "config_googleapis")
config_googleapis()


@@ -1155,7 +1155,7 @@ def set_trisycl_include_dir(environ_cp):
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def system_specific_test_config(env):
def system_specific_test_config(environ_cp):
"""Add default build and test flags required for TF tests to bazelrc."""
write_to_bazelrc('test --flaky_test_attempts=3')
write_to_bazelrc('test --test_size_filters=small,medium')
@@ -1171,14 +1171,14 @@ def system_specific_test_config(env):
test_only_filters = ['-oss_serial']
if is_windows():
test_and_build_filters.append('-no_windows')
if env.get('TF_NEED_CUDA', None) == '1':
if environ_cp.get('TF_NEED_CUDA', None) == '1':
test_and_build_filters += ['-no_windows_gpu', '-no_gpu']
else:
test_and_build_filters.append('-gpu')
elif is_macos():
test_and_build_filters += ['-gpu', '-nomac', '-no_mac']
elif is_linux():
if env.get('TF_NEED_CUDA', None) == '1':
if environ_cp.get('TF_NEED_CUDA', None) == '1':
test_and_build_filters.append('-no_gpu')
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
else:
@@ -1522,7 +1522,7 @@ def main():
create_android_ndk_rule(environ_cp)
create_android_sdk_rule(environ_cp)
system_specific_test_config(os.environ)
system_specific_test_config(environ_cp)
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
if environ_cp.get('TF_CONFIGURE_IOS') == '1':


@@ -702,6 +702,7 @@ tf_cc_shared_object(
"//tensorflow/c:exported_symbols.lds",
"//tensorflow/c:version_script.lds",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:tensorflow",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
],


@@ -651,16 +651,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
grpc_server->worker_env()->session_mgr->UpdateSession(
session_name, server_def, base_request.cluster_device_attributes(),
true));
TF_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, context,
worker_session.get());
LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
grpc_server->worker_env(), std::move(remote_eager_workers),
added_workers, removed_workers, context_id, r, device_mgr,
keep_alive_secs, cluster_flr));
added_workers, removed_workers, context_id, r));
}
#undef LOG_AND_RETURN_IF_ERROR


@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // TF:dlpack
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"


@@ -24,9 +24,14 @@ cc_library(
srcs = [
"modular_filesystem.cc",
"modular_filesystem_registration.cc",
],
hdrs = [
"modular_filesystem.h",
"modular_filesystem_registration.h",
],
hdrs = ["modular_filesystem.h"],
# TODO(mihaimaruseac): Visibility should be more restrictive once we
# convert to modular filesystems everywhere
visibility = ["//visibility:public"],
deps = [
":filesystem_interface",
"//tensorflow/c:tf_status_helper",


@@ -440,7 +440,25 @@ Status ModularWritableFile::Tell(int64* position) {
}
Status RegisterFilesystemPlugin(const std::string& dso_path) {
return filesystem_registration::RegisterFilesystemPluginImpl(dso_path);
// Step 1: Load plugin
Env* env = Env::Default();
void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;
TF_RETURN_IF_ERROR(
env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
// Step 3: Call `TF_InitPlugin`
TF_FilesystemPluginInfo info;
memset(&info, 0, sizeof(info));
auto TF_InitPlugin =
reinterpret_cast<int (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
TF_InitPlugin(&info);
// Step 4: Do the actual registration
return filesystem_registration::RegisterFilesystemPluginImpl(&info);
}
} // namespace tensorflow
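A minimal caller sketch for the reworked loader above (not part of this diff), assuming RegisterFilesystemPlugin is declared in modular_filesystem.h; the DSO path is a placeholder:

#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical caller: loads a plugin DSO at the given path. The DSO must
// export TF_InitPlugin, which RegisterFilesystemPlugin (above) locates,
// invokes, and hands off to RegisterFilesystemPluginImpl.
void LoadMyFilesystemPlugin() {
  tensorflow::Status s =
      tensorflow::RegisterFilesystemPlugin("/tmp/libmy_filesystem.so");
  if (!s.ok()) LOG(WARNING) << "Filesystem plugin registration failed: " << s;
}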


@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/env.h"
@@ -304,40 +303,22 @@ static Status ValidatePluginMemoryRoutines(
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
// Step 1: Load plugin
Env* env = Env::Default();
void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info) {
TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(info));
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;
TF_RETURN_IF_ERROR(
env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
// Step 3: Call `TF_InitPlugin`
TF_FilesystemPluginInfo info;
memset(&info, 0, sizeof(info));
auto TF_InitPlugin =
reinterpret_cast<int (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
TF_InitPlugin(&info);
// Step 4: Ensure plugin provides the memory management functions.
TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(&info));
// Step 5: Validate and register all filesystems
// Validate and register all filesystems
// Try to register as many filesystems as possible.
// Free memory once we no longer need it
Status status;
for (int i = 0; i < info.num_schemes; i++) {
status.Update(ValidateAndRegisterFilesystems(&info, i));
info.plugin_memory_free(info.ops[i].scheme);
info.plugin_memory_free(info.ops[i].filesystem_ops);
info.plugin_memory_free(info.ops[i].random_access_file_ops);
info.plugin_memory_free(info.ops[i].writable_file_ops);
info.plugin_memory_free(info.ops[i].read_only_memory_region_ops);
for (int i = 0; i < info->num_schemes; i++) {
status.Update(ValidateAndRegisterFilesystems(info, i));
info->plugin_memory_free(info->ops[i].scheme);
info->plugin_memory_free(info->ops[i].filesystem_ops);
info->plugin_memory_free(info->ops[i].random_access_file_ops);
info->plugin_memory_free(info->ops[i].writable_file_ops);
info->plugin_memory_free(info->ops[i].read_only_memory_region_ops);
}
info.plugin_memory_free(info.ops);
info->plugin_memory_free(info->ops);
return status;
}


@@ -15,12 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace filesystem_registration {
Status RegisterFilesystemPluginImpl(const std::string& dso_path);
// Implementation for filesystem registration
//
// Don't call this directly. Instead call `RegisterFilesystemPlugin`.
// Exposed only for static registration of local filesystems.
Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info);
} // namespace filesystem_registration
} // namespace tensorflow


@@ -19,6 +19,7 @@ tf_cc_shared_object(
cc_library(
name = "posix_filesystem_impl",
srcs = ["posix_filesystem.cc"],
hdrs = ["posix_filesystem.h"],
deps = [
":posix_filesystem_helper",
"//tensorflow/c:tf_status",
@@ -26,6 +27,20 @@ cc_library(
],
)
# Since building pip package and API tests require a filesystem, we provide a
# static registration target that they should link against.
cc_library(
name = "posix_filesystem_static",
srcs = ["posix_filesystem_static.cc"],
visibility = ["//visibility:public"],
deps = [
":posix_filesystem_impl",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//tensorflow/c/experimental/filesystem:modular_filesystem",
],
alwayslink = 1,
)
# Library implementing helper functionality, so that the above only contains
# the API implementation for modular filesystems.
cc_library(


@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"
#include <dirent.h>
#include <errno.h>
#include <fcntl.h>
@@ -24,7 +26,6 @@ limitations under the License.
#include <sys/stat.h>
#include <unistd.h>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"
#include "tensorflow/c/tf_status.h"


@@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
// Initialize the POSIX filesystem.
//
// In general, the `TF_InitPlugin` symbol doesn't need to be exposed in a header
// file, since the plugin registration will look for the symbol in the DSO file
// that provides the filesystem functionality. However, the POSIX filesystem
// needs to be statically registered in some tests and utilities for building
// the API files at the time of creating the pip package. Hence, we need to
// expose this function so that this filesystem can be statically registered
// when needed.
void TF_InitPlugin(TF_FilesystemPluginInfo* info);
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_


@@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h"
#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h"
namespace tensorflow {
// Register the POSIX filesystems statically.
// Return value will be unused
bool StaticallyRegisterLocalFilesystems() {
TF_FilesystemPluginInfo info;
TF_InitPlugin(&info);
Status status = filesystem_registration::RegisterFilesystemPluginImpl(&info);
if (!status.ok()) {
VLOG(0) << "Static POSIX filesystem could not be registered: " << status;
return false;
}
return true;
}
// Perform the actual registration
static bool unused = StaticallyRegisterLocalFilesystems();
} // namespace tensorflow


@@ -632,6 +632,7 @@ tf_gen_op_wrappers_cc(
"tpu_configuration_ops",
"tpu_cross_replica_ops",
"tpu_embedding_ops",
"tpu_embedding_load_retrieve_ops",
"tpu_functional_ops",
"tpu_heartbeat_ops",
"tpu_host_compute_ops",


@@ -521,15 +521,15 @@ Status SymbolicGradientBuilder::AddGradients() {
// gradient function to the src node/output to which it should be
// backpropped. Maybe grad functions can return a vector of Output pairs to
// make this association explicit.
size_t dx_index = 0;
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) continue;
if (dx_index == dx.size()) {
int dx_index = e->dst_input();
if (dx_index >= dx.size()) {
return errors::Internal(
"Invalid gradient output index: ", dx_index, " size: ", dx.size());
}
TF_RETURN_IF_ERROR(
BackpropAlongEdge(dx[dx_index++], {e->src(), e->src_output()}));
BackpropAlongEdge(dx[dx_index], {e->src(), e->src_output()}));
}
}


@@ -503,6 +503,42 @@ TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) {
EXPECT_EQ(grad_result[0].flat<float>()(0), 17610.0f);
}
TEST_F(GradientsTest, AddSymbolicGradientsTest) {
Scope scope = Scope::NewRootScope();
for (int cnt = 0; cnt < 100; ++cnt) {
int N = 5 + rand() % 10;
// Construct forward graph.
OutputList inputs;
for (int i = 0; i < N; ++i) {
auto a = Const(scope, i, {1});
inputs.push_back(a);
}
auto pack = Stack(scope, inputs);
TF_ASSERT_OK(scope.status());
// Construct grad inputs.
OutputList output_grads;
Tensor ts(DT_INT32, {N, 1});
auto v = ts.matrix<int32>();
for (int i = 0; i < N; ++i) {
v(i, 0) = i;
}
auto dy = Const(scope, ts);
output_grads.push_back(dy);
// Call AddSymbolicGradients.
std::vector<Output> grad_outputs;
TF_ASSERT_OK(AddSymbolicGradients(scope, {pack.output}, inputs,
output_grads, &grad_outputs));
ClientSession session((scope));
std::vector<Tensor> in_grad;
TF_ASSERT_OK(session.Run(grad_outputs, &in_grad));
for (int i = 0; i < N; ++i) {
test::ExpectTensorEqual<int>(in_grad[i], test::AsTensor<int>({i}, {1}));
}
}
}
// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
// 'NoGradient' (induced by StopGradient op) returned along multiple edges from
// a single node's output.


@@ -170,7 +170,9 @@ Status GenArgMethods(const tf2xla::Config& config,
const xla::ProgramShapeProto& ps,
const CompileResult& compile_result, string* methods) {
size_t num_args = ps.parameters_size();
if (config.feed_size() + config.variable_size() != num_args) {
// feed_size() + variable_size() is the maximum number of args as an
// implementation may not create an argument for an unused variable.
if (config.feed_size() + config.variable_size() < num_args) {
return errors::InvalidArgument(
"mismatch between feed_size(", config.feed_size(), ")+variable_size(",
config.variable_size(), ") and num_args(", num_args, ")");


@@ -154,7 +154,7 @@ static void CompareWithGoldenFile(
// To update the golden file, flip update_golden to true and run the
// following:
// bazel test --test_strategy=local \
// third_party/tensorflow/compiler/aot:codegen_test
// "third_party/tensorflow/compiler/aot:codegen_test"
const bool update_golden = false;
string golden_file_name =
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);


@@ -157,6 +157,7 @@ def tftop_k(_):
def tfvariable_readonly(_):
x = variables.Variable(1000.0, name='x')
unused_y = variables.Variable(1000.0, name='y')
old_x = x.value()
with ops.control_dependencies([old_x]):
new_value = math_ops.add(old_x, 42.0)


@@ -10,3 +10,11 @@ variable {
type: DT_FLOAT
readonly: true
}
variable {
node_name: "y"
shape {
}
type: DT_FLOAT
readonly: true
}


@@ -338,6 +338,7 @@ cc_library(
deps = [
":xla_activity_listener",
":xla_activity_proto_cc",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",


@@ -299,7 +299,11 @@ Status ComputeIncompatibleResourceOperationPairs(
result->push_back({incoming_op.first, n->id()});
}
resource_op_set->Add({n->id(), *op_kind});
// Some graphs might have a lot of 'kRead' kinds, but they are always safe
// for incoming ops, so not storing them might save a lot of memory.
if (op_kind != XlaResourceOpKind::kRead) {
resource_op_set->Add({n->id(), *op_kind});
}
}
if (vlog) {


@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -33,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -40,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
@@ -273,8 +276,30 @@ Status XlaCompilationCache::CompileSingleOp(
const NodeDef& node_def = ctx->op_kernel().def();
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) {
return arg.kind == XlaCompiler::Argument::kParameter;
});
const ConfigProto* config = ctx->function_library()->config_proto();
bool use_mlir = config && config->experimental().enable_mlir_bridge();
// Use MLIR bridge if all the arguments are parameters.
// TODO(hinsu): Support other argument types instead of silently falling
// back to the XLA compiler.
if (!are_params || !use_mlir) {
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
}
absl::InlinedVector<TensorShape, 4> arg_shapes;
arg_shapes.reserve(args.size());
for (const XlaCompiler::Argument& arg : args) {
arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
}
GraphDebugInfo debug_info;
return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
compile_options.use_tuple_arg,
*options.flib_def, debug_info,
options.shape_representation_fn, result);
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,
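For context on the CompileSingleOp hunk above (not part of this diff): the new MLIR path is taken only when the bridge is enabled in the session's ConfigProto. A minimal sketch of setting that flag; the surrounding session plumbing is assumed:

#include "tensorflow/core/protobuf/config.pb.h"

// Builds a ConfigProto with the MLIR bridge turned on; CompileSingleOp reads
// this back via config->experimental().enable_mlir_bridge().
tensorflow::ConfigProto MakeMlirBridgeConfig() {
  tensorflow::ConfigProto config;
  config.mutable_experimental()->set_enable_mlir_bridge(true);
  return config;
}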


@@ -78,7 +78,9 @@ class XlaCompilationCache : public ResourceBase {
xla::LocalExecutable** out_executable);
// As above, but calls XlaCompiler::CompileSingleOp instead of
// XlaCompiler::CompileFunction.
// XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto
// in OpKernelContext, then uses MLIR bridge for compilation instead of
// XlaCompiler, if possible.
Status CompileSingleOp(
const XlaCompiler::Options& options,
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,


@@ -71,6 +71,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/xla:lhlo",
@@ -88,7 +89,7 @@ cc_library(
"//tensorflow/compiler/mlir/xla:xla_lower",
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
"//tensorflow/compiler/mlir/xla:xla_test_passes",
"@llvm-project//mlir:AffineOps",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:QuantOps",
],
)


@@ -26,6 +26,7 @@ package_group(
filegroup(
name = "tensorflow_lite_ops_td_files",
srcs = [
"experimental/tfl_hardware_interfaces.td",
"ir/tfl_op_interfaces.td",
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
@@ -204,6 +205,8 @@ cc_library(
cc_library(
name = "tensorflow_lite",
srcs = [
"experimental/estimators/estimator.h",
"experimental/estimators/gpu_estimator.h.inc",
"ir/tfl_ops.cc",
"ir/tfl_ops.cc.inc",
"ir/tfl_ops.h.inc",
@@ -213,6 +216,7 @@ cc_library(
"utils/attribute_utils.cc",
],
hdrs = [
"experimental/estimators/hardware.h",
"ir/tfl_ops.h",
"transforms/passes.h",
"utils/attribute_utils.h",
@@ -222,7 +226,6 @@ cc_library(
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",
@@ -419,7 +422,9 @@ cc_library(
],
deps = [
":tensorflow_lite",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@@ -439,6 +444,7 @@ genrule(
srcs = [
"ir/tfl_ops.td",
"ir/tfl_op_interfaces.td",
"experimental/tfl_hardware_interfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
],
@@ -551,14 +557,14 @@ cc_library(
cc_library(
name = "flatbuffer_translate_lib",
srcs = [
"flatbuffer_export.cc",
"flatbuffer_import.cc",
"flatbuffer_translate.cc",
"utils/convert_type.cc",
],
hdrs = [
"flatbuffer_export.h",
"flatbuffer_export_flags.h",
"flatbuffer_import.h",
"flatbuffer_translate.h",
"flatbuffer_translate_flags.h",
"utils/convert_type.h",
],
deps = [
@@ -575,9 +581,10 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:status",
"//tensorflow/lite:framework",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",
@@ -598,15 +605,37 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
],
)
cc_library(
name = "flatbuffer_translate_registeration",
srcs = [
"flatbuffer_translate.cc",
],
deps = [
":flatbuffer_translate_lib",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LoopOpsTransforms",
"@llvm-project//mlir:MlirTranslateMain",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
],
alwayslink = 1,
)
tf_cc_binary(
name = "flatbuffer_translate",
deps = [
":flatbuffer_translate_lib",
"@llvm-project//mlir:LoopOpsTransforms",
"@llvm-project//mlir:MlirTranslateMain",
":flatbuffer_translate_registeration",
],
)
@@ -644,10 +673,13 @@ filegroup(
tf_cc_binary(
name = "tf_tfl_translate",
srcs = [":tf_tfl_translate_main"],
srcs = [
":tf_tfl_translate_main",
],
deps = [
":common",
":flatbuffer_translate_lib",
":flatbuffer_translate_registeration",
":tensorflow_lite",
":tf_tfl_passes",
":tf_tfl_translate_cl_options",
@@ -669,15 +701,18 @@ tf_cc_binary(
tf_cc_binary(
name = "mlir-tflite-runner",
srcs = ["mlir_tflite_runner.cc"],
srcs = [
"mlir_tflite_runner.cc",
],
deps = [
":flatbuffer_translate_lib",
":flatbuffer_translate_registeration",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//tensorflow/lite:framework",
"//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/kernels:builtin_ops",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",


@@ -27,10 +27,10 @@ limitations under the License.
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h" // TF:llvm-project
#include "mlir/TableGen/Format.h" // TF:llvm-project
#include "mlir/TableGen/Operator.h" // TF:llvm-project
#include "mlir/TableGen/Predicate.h" // TF:llvm-project
#include "mlir/TableGen/Attribute.h" // from @llvm-project
#include "mlir/TableGen/Format.h" // from @llvm-project
#include "mlir/TableGen/Operator.h" // from @llvm-project
#include "mlir/TableGen/Predicate.h" // from @llvm-project
using llvm::DefInit;
using llvm::dyn_cast;
@@ -442,28 +442,26 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
verify_ctx.withOp("top");
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
auto &value = op.getOperand(i);
// Skip from first variadic operands for now. Else getOperand index
// used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
}
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
auto &value = op.getResult(i);
// Skip from first variadic results for now. Else getResult index
// used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
}
auto &value = op.getOperand(i);
// Skip from first variadic operands for now. Else getOperand index
// used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
}
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
auto &value = op.getResult(i);
// Skip from first variadic results for now. Else getResult index
// used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
}
GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
"operand");
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
"result");
os << " return mlir::success();\n}\n";
os << " return top.verify();\n}\n";
}
return false;


@@ -18,7 +18,7 @@ limitations under the License.
#include <cstdarg>
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/lite/core/api/error_reporter.h"
namespace tflite {


@@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
template <typename Op, typename TargetHardware>
class TFLiteCostEstimator {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) {
llvm::errs() << "No defined support for op: "
<< op->getName().getStringRef().str();
return false;
}
};
// All ops on CPU are supported.
// TODO(karimnosseir): Only allow TFL ops in the "TFL_OP" param.
template <typename TFL_OP>
class TFLiteCostEstimator<TFL_OP, hardware::CPU> {
public:
// TODO(karimnosseir): Update and use table based method and lookup
// cost from a loadable table?
static double GetCost(mlir::Operation* op) { return 0.0; }
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_
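A usage sketch for the estimator template above (not part of this diff), assuming the namespace layout of tfl_ops.h, which includes these headers inside namespace TFL; the op type and the op pointer are illustrative:

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// The CPU partial specialization above accepts every TFL op at zero cost,
// so this returns true for any TFL op type substituted in.
bool SupportedOnCpu(mlir::Operation* op) {
  return mlir::TFL::TFLiteCostEstimator<
      mlir::TFL::AveragePool2DOp, mlir::TFL::hardware::CPU>::IsSupported(op);
}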


@@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_
template <>
class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATOR_H_


@@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
#include "tensorflow/core/framework/dataset.h"
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_HARDWARE_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_HARDWARE_H_
#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
namespace hardware {
// Empty classes that represent hardware types.
class CPU {};
class GPU {};
} // namespace hardware
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_HARDWARE_H_


@@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// WARNING: This Interface is experimental, DO NOT USE.
// This is the Target Hardware operation interface definition file
// for TensorFlow Lite.
#ifndef TFL_TARGET_HARDWARE_OP_INTERFACES
#define TFL_TARGET_HARDWARE_OP_INTERFACES
def TFL_CpuTargetOp : OpInterface<"CpuOpTargetInterface"> {
let description = [{
Interface for ops to run on CPU.
}];
let methods = [
InterfaceMethod<
[{Returns the cost of running this op on CPU.}],
// TODO(karimnosseir): Change to return Cost object instead.
"double", "GetOpCost", (ins "mlir::Operation*":$op_to_check), [{
// TODO(karimnosseir): Consider changing to another way that doesn't
// rely on template param name.
return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::CPU>::GetCost(op_to_check);
}]
>,
InterfaceMethod<
[{Returns whether this op can be run on CPU.}],
"bool", "IsSupported", (ins "mlir::Operation*":$op_to_check), [{
// TODO(karimnosseir): Consider changing to another way that doesn't
// rely on template param name.
return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::CPU>::IsSupported(op_to_check);
}]
>,
];
}
def TFL_GpuTargetOp : OpInterface<"GpuOpTargetInterface"> {
let description = [{
Interface for ops to run on GPU.
}];
let methods = [
InterfaceMethod<
[{Returns the cost of running this op on GPU.}],
// TODO(karimnosseir): Change to return Cost object instead.
"double", "GetOpCost", (ins "Operation*":$op_to_check), [{
// TODO(karimnosseir): Consider changing to another way that doesn't
// rely on template param name.
return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::GPU>::GetCost(op_to_check);
}]
>,
InterfaceMethod<
[{Returns whether this op can be run on GPU.}],
"bool", "IsSupported", (ins "Operation*":$op_to_check), [{
// TODO(karimnosseir): Consider changing to another way that doesn't
// rely on template param name.
return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::GPU>::IsSupported(op_to_check);
}]
>,
];
}
#endif // TFL_TARGET_HARDWARE_OP_INTERFACES
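A hedged sketch of how a pass might consume the generated interfaces above; the walk, the cost accumulation, and the function name are illustrative, not part of this commit:

#include "llvm/Support/Casting.h"
#include "mlir/IR/Function.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// Sums the estimated GPU cost over all ops implementing the generated
// GpuOpTargetInterface, skipping ops reported as unsupported on GPU.
double EstimateGpuCost(mlir::FuncOp func) {
  double total = 0.0;
  func.walk([&](mlir::Operation* op) {
    if (auto target = llvm::dyn_cast<mlir::TFL::GpuOpTargetInterface>(op)) {
      if (target.IsSupported(op)) total += target.GetOpCost(op);
    }
  });
  return total;
}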

File diff suppressed because it is too large.


@@ -0,0 +1,43 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
#include <string>
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
namespace tflite {
// Translates the given MLIR `module` into a FlatBuffer and stores the
// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to
// convert the location of an op to a name in the flatbuffer. Returns true if translation
// fails, otherwise returns false.
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops,
bool emit_select_tf_ops,
bool emit_custom_ops);
// Same as the above but with a custom op name mapper.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
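A short usage sketch for the first overload declared above (not part of this diff); the wrapper name and flag values are illustrative. Note the inverted convention: the function returns true on failure.

#include <string>

#include "mlir/IR/Module.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"

// Serializes an already-built MLIR module into a TFLite flatbuffer string,
// emitting builtin TFLite ops only.
bool ExportModule(mlir::ModuleOp module, std::string* serialized) {
  return tflite::MlirToFlatBufferTranslateFunction(
      module, serialized,
      /*emit_builtin_tflite_ops=*/true,
      /*emit_select_tf_ops=*/false,
      /*emit_custom_ops=*/false);
}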


@@ -0,0 +1,31 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_
#include <string>
// These flags control whether different kinds of ops are emitted during the
// flatbuffer translation.
extern bool emit_builtin_tflite_ops;
extern bool emit_select_tf_ops;
extern bool emit_custom_ops;
// The flag to control whether to lower tensorlist ops into TF ops.
extern bool lower_tensor_list_ops;
// The flag to control whether debug info gets stripped on export.
extern bool strip_debug_info;
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_


@@ -44,39 +44,35 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Translation.h" // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
@@ -100,45 +96,6 @@ using xla::StatusOr;
namespace errors = tensorflow::errors;
namespace tfl = mlir::TFL;
using llvm::cl::opt;
// Commandline flag to enable the control of flatbuffer import.
bool use_external_constant;
// Commandline flag to enable graph pruning.
bool experimental_prune_unreachable_nodes_unconditionally;
// NOLINTNEXTLINE
static opt<bool, true> use_external_constant_flag(
"use-external-constant",
llvm::cl::desc("Use external constant during flatbuffer import"),
llvm::cl::location(use_external_constant), llvm::cl::init(false));
// TODO(b/147111261): After the importer supports generic custom ops, we should
// change the flag to a more lightweight flag, e.g.
// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE prune
// the operations.
// NOLINTNEXTLINE
static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
"experimental-prune-unreachable-nodes-unconditionally",
llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
llvm::cl::init(false));
// NOLINTNEXTLINE
static opt<std::string> input_arrays_flag(
"input-arrays",
llvm::cl::desc(
"List of input tensors, if different from the default inputs"),
llvm::cl::init(""));
// NOLINTNEXTLINE
static opt<std::string> output_arrays_flag(
"output-arrays",
llvm::cl::desc(
"List of output tensors, if different from the default outputs"),
llvm::cl::init(""));
namespace {
bool IsScalar(const TensorT& tensor) {
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
@@ -1063,42 +1020,3 @@ OwningModuleRef tflite::FlatBufferToMlir(
return OwningModuleRef(module);
}
static OwningModuleRef FlatBufferFileToMlirTrans(
llvm::SourceMgr* source_mgr, MLIRContext* context,
bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
const llvm::MemoryBuffer* input =
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
std::string error;
auto loc =
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
// Parses input/output names from command line options.
std::vector<std::string> inputs;
std::vector<std::string> outputs;
// Use output parser since we only have tensor names.
if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
return emitError(loc, "parsing input array info failed ")
<< input_arrays_flag,
nullptr;
}
if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
return emitError(loc, "parsing output array info failed ")
<< output_arrays_flag,
nullptr;
}
return tflite::FlatBufferToMlir(
absl::string_view(input->getBufferStart(), input->getBufferSize()),
context, loc, use_external_constant, inputs, outputs,
experimental_prune_unreachable_nodes_unconditionally);
}
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
"tflite-flatbuffer-to-mlir",
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
return FlatBufferFileToMlirTrans(
&source_mgr, context, use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
});


@@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
#include "absl/strings/string_view.h"
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
namespace tflite {
// Converts a TFLite flatbuffer stored in `buffer` to a MLIR module


@@ -18,12 +18,12 @@ limitations under the License.
#include <vector>
#include "absl/strings/str_cat.h"
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/statusor.h"


@@ -23,12 +23,12 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/schema/schema_generated.h"


@@ -23,8 +23,8 @@ limitations under the License.
#include <iostream>
#include <string>
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
#include "flatbuffers/minireflect.h" // TF:flatbuffers
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "flatbuffers/minireflect.h" // from @flatbuffers
#include "tensorflow/lite/schema/reflection/schema_generated.h"
namespace tflite {

File diff suppressed because it is too large.


@@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
namespace tflite {


@@ -19,6 +19,7 @@ limitations under the License.
#define TFL_OP_INTERFACES
include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td"
//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.


@@ -26,19 +26,19 @@ limitations under the License.
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/FoldUtils.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
@@ -804,10 +804,10 @@ struct RemoveAdjacentReshape : public RewritePattern {
RemoveAdjacentReshape(MLIRContext *context)
: RewritePattern(ReshapeOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override {
LogicalResult match(Operation *op) const override {
auto thisOp = cast<ReshapeOp>(op);
auto prevOp = thisOp.getOperand(0).getDefiningOp();
return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
return isa_and_nonnull<ReshapeOp>(prevOp) ? success() : failure();
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
@@ -884,28 +884,27 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
explicit RemoveRedundantUnpackPack(MLIRContext *context)
: RewritePattern(PackOp::getOperationName(), 2, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TFL::PackOp pack_op = cast<TFL::PackOp>(op);
Operation *first_input = pack_op.getOperand(0).getDefiningOp();
if (!first_input) return matchFailure();
if (!first_input) return failure();
auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
if (!input_unpack_op) return matchFailure();
if (!input_unpack_op) return failure();
// The unpack & pack should have the same axis & num inputs/outputs.
if (pack_op.axis() != input_unpack_op.axis() ||
pack_op.values_count() != input_unpack_op.num())
return matchFailure();
return failure();
const int total_pack_inputs = pack_op.getNumOperands();
if (total_pack_inputs != input_unpack_op.getNumResults())
return matchFailure();
if (total_pack_inputs != input_unpack_op.getNumResults()) return failure();
for (auto input_output :
llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
Value pack_input = std::get<0>(input_output);
Value unpack_output = std::get<1>(input_output);
// Make sure the ordering is the same for the pack op & unpack op.
if (pack_input != unpack_output) return matchFailure();
if (pack_input != unpack_output) return failure();
}
// Replace the pack's output to the unpack's input.
@@ -913,7 +912,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
// At this point, we don't manually remove the redundant pack op & unpack op
// (we cannot actually), but trust the PatternRewriter to garbage collect
// these two ops.
return matchSuccess();
return success();
}
};
@@ -1050,17 +1049,17 @@ struct DropFakeQuant : public RewritePattern {
explicit DropFakeQuant(MLIRContext *context)
: RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override {
LogicalResult match(Operation *op) const override {
// We only match the op with valid "minmax" attribute.
if (!HasValidMinMaxAttribute(op)) return matchFailure();
if (!HasValidMinMaxAttribute(op)) return failure();
// If all the users of this op have valid "minmax" attributes, it is matched
// and can be removed.
auto fakeQuantOp = cast<FakeQuantOp>(op);
for (auto *operand : fakeQuantOp.getResult().getUsers())
if (!HasValidMinMaxAttribute(operand)) return matchFailure();
if (!HasValidMinMaxAttribute(operand)) return failure();
return matchSuccess();
return success();
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
@@ -1789,8 +1788,8 @@ struct WhileResultOperandsMatchAndImplicitCapture
: public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(WhileOp while_op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(WhileOp while_op,
PatternRewriter &rewriter) const override {
// Replace values simply passed through the body with extern values. The
// block arguments of body and while match and so the corresponding cond
// argument can be easily found.
@@ -1843,7 +1842,7 @@ struct WhileResultOperandsMatchAndImplicitCapture
}
// Done if no values removed from blocks and operands & results match.
if (unchanged) return matchFailure();
if (unchanged) return failure();
// Replace with new While with matching operands and results.
Operation *op = while_op.getOperation();
@@ -1866,7 +1865,7 @@ struct WhileResultOperandsMatchAndImplicitCapture
rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
new_body_yield);
return matchSuccess();
return success();
}
};


@@ -18,18 +18,18 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/Traits.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Dialect.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // TF:llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // TF:llvm-project
#include "mlir/Interfaces/SideEffects.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -49,9 +49,12 @@ class TensorFlowLiteDialect : public Dialect {
Location loc) override;
};
#include "tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
// Include all specialized estimators below this line
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimator.h.inc"
} // end namespace TFL
} // end namespace mlir

View File

@ -285,7 +285,10 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TFL_Dialect, mnemonic, !listconcat(traits,
[DeclareOpInterfaceMethods<TFL_RuntimeVerification>])> {
[DeclareOpInterfaceMethods<TFL_RuntimeVerification>,
// All TFL ops are supported on CPU.
DeclareOpInterfaceMethods<TFL_CpuTargetOp>
])> {
// FlatBuffer generation specific information.
// -------------------------------------------
// When generating the FlatBuffer output some operations have
@ -477,7 +480,10 @@ Note this is a custom op that is not supported in the standard runtime.
}
def TFL_AveragePool2DOp:
TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> {
TFL_Op<"average_pool_2d",
[NoSideEffect,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
let summary = "Average_pool_2d operator";
let description = [{
@ -690,7 +696,7 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
let extraClassDeclaration = [{
// StatefulOpInterface:
// ChannelDimIndexInterface:
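// (TFLite Conv2D filters are laid out as [output_channels, H, W,
// input_channels], so the output channel dimension is 0.)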
int GetChannelDimIndex() { return 0; }
}];
}
@ -715,7 +721,7 @@ def TFL_DepthwiseConv2DOp :
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
let extraClassDeclaration = [{
// StatefulOpInterface:
// ChannelDimIndexInterface:
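// (TFLite DepthwiseConv2D filters are laid out as [1, H, W,
// output_channels], so the output channel dimension is 3.)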
int GetChannelDimIndex() { return 3; }
}];
}
@ -1083,7 +1089,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}
def TFL_DivOp : TFL_Op<"div", [ResultsBroadcastableShape, NoSideEffect]> {
def TFL_DivOp : TFL_Op<"div", [
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
let summary = "Division operator";
let description = [{

View File

@ -30,12 +30,12 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Parser.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/delegate.h"

View File

@ -19,11 +19,11 @@ limitations under the License.
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"

View File

@ -17,11 +17,11 @@ limitations under the License.
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"

View File

@ -18,11 +18,11 @@ limitations under the License.
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <ostream>
#include <utility>
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/lite/toco/model_flags.pb.h"

View File

@ -23,18 +23,18 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
#include "mlir/IR/AffineMap.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"

View File

@ -17,14 +17,14 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {

View File

@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
namespace mlir {
namespace TFL {

View File

@ -20,7 +20,7 @@ limitations under the License.
#define TF_Quantization
include "mlir/IR/OpBase.td"
include "mlir/Dialect/QuantOps/QuantOpsBase.td"
include "mlir/Dialect/Quant/QuantOpsBase.td"
//===----------------------------------------------------------------------===//
// QuantizedType definitions.

View File

@ -24,18 +24,18 @@ limitations under the License.
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
namespace mlir {
namespace quant {

View File

@ -18,8 +18,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace OpTrait {

View File

@ -22,15 +22,15 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantizeUtils.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace quant {

View File

@ -23,18 +23,18 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir {
@ -82,17 +82,17 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
narrow_range(narrow_range),
is_signed(is_signed) {}
PatternMatchResult matchAndRewrite(quant::StatisticsOp op,
PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(quant::StatisticsOp op,
PatternRewriter& rewriter) const override {
Type expressed = op.getType().cast<ShapedType>().getElementType();
quant::QuantizedType quant_type;
SmallVector<double, 4> mins, maxs;
if (op.axisStats().hasValue()) {
int stats_num = op.axisStats()->getNumElements();
if (stats_num == 0 || stats_num % 2 != 0) return this->matchFailure();
if (stats_num == 0 || stats_num % 2 != 0) return failure();
auto stats = op.axisStats()->dyn_cast<DenseFPElementsAttr>();
if (!stats) return this->matchFailure();
if (!stats) return failure();
for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
mins.push_back(FloatAttr::getValueAsDouble(*it++));
@ -108,7 +108,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
quant::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
narrow_range, expressed, is_signed);
} else {
return this->matchFailure();
return failure();
}
rewriter.setInsertionPointAfter(op);
@ -119,7 +119,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
q.getOperation()->replaceUsesOfWith(dq, op.arg());
op.erase();
return this->matchSuccess();
return success();
}
private:
@ -156,16 +156,16 @@ struct QuantizationPattern : public RewritePattern {
error_tolerance(error_tolerance),
single_layer_verify(single_layer_verify) {}
PatternMatchResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
if (op->getNumResults() != 1) {
return matchFailure();
return failure();
}
Value quantized_value = op->getResult(0);
for (Operation* quantized_op : quantized_value.getUsers()) {
// If it is a requantize op, we shouldn't rewrite this op.
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
return matchFailure();
return failure();
}
// If it is a terminator, or not quantizable, or any op from the mlir quant
@ -174,7 +174,7 @@ struct QuantizationPattern : public RewritePattern {
quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::QuantizeCastOp>(quantized_op) ||
llvm::isa<quant::DequantizeCastOp>(quantized_op)) {
return matchFailure();
return failure();
}
// Collect all the quantized inputs and "clone" the matched op by these
@ -198,7 +198,7 @@ struct QuantizationPattern : public RewritePattern {
} else if (static_cast<const ConcretTy*>(this)->AllowHybridOperand()) {
inputs.push_back(operand);
} else {
return matchFailure();
return failure();
}
}
@ -234,7 +234,7 @@ struct QuantizationPattern : public RewritePattern {
outputs_replaced.insert({result, enumerated_result.index()});
output_types.push_back(result.getType());
} else {
return matchFailure();
return failure();
}
}
@ -299,7 +299,7 @@ struct QuantizationPattern : public RewritePattern {
}
}
}
return matchSuccess();
return success();
}
bool enable_verify;
@ -317,11 +317,11 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
explicit ConvertUnsignedToSigned(MLIRContext* context)
: OpRewritePattern<Q>(context, 1) {}
PatternMatchResult matchAndRewrite(Q op,
PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(Q op,
PatternRewriter& rewriter) const override {
Type output_type = op.getResult().getType();
auto qtype = QType::getQuantizedElementType(output_type);
if (!qtype || qtype.isSigned()) return this->matchFailure();
if (!qtype || qtype.isSigned()) return failure();
int num_bits = qtype.getStorageTypeIntegralWidth();
// This is a positive value, and will be applied on zero points and fixed
@ -352,14 +352,14 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
aqtype.getStorageTypeMin() - offset,
aqtype.getStorageTypeMax() - offset, op.getLoc());
} else {
return this->matchFailure();
return failure();
}
if (!new_qtype) return this->matchFailure();
if (!new_qtype) return failure();
Type new_output_type = new_qtype.castFromExpressedType(
QType::castToExpressedType(output_type));
rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
return this->matchSuccess();
return success();
}
};

View File

@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {
namespace TF {

View File

@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -73,12 +73,12 @@ struct InsertQuantOpsAfterTFFakeQuantOp
MLIRContext *ctx)
: OpRewritePattern<TFFakeQuantOp>(ctx) {}
PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
PatternRewriter &rewriter) const override {
// We don't want to insert quantize/dequantize if a quantize op already exists.
auto res = tf_op.outputs();
if (!res.hasOneUse() || isa<quant::QuantizeCastOp>(*res.user_begin()))
return this->matchFailure();
return failure();
// Extract the min/max constant values from the operands. We also consider
// the special case where there are tf.Identity ops between the min/max
@ -95,8 +95,8 @@ struct InsertQuantOpsAfterTFFakeQuantOp
max = tf_op.max();
rewriter.eraseOp(id2);
}
if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
if (!matchPattern(min, m_Constant(&min_value))) return failure();
if (!matchPattern(max, m_Constant(&max_value))) return failure();
int quant_dim = -1;
if (PerAxis) {
@ -114,7 +114,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp
TypeAttr qtype = quant::GetQuantizedTypeAttr(
rewriter, res_type, min_value, max_value, quant_dim, num_bits,
narrow_range, /*is_signed=*/true);
if (!qtype) this->matchFailure();
if (!qtype) failure();
// Finally, use the quantization parameter to create the quantize and
// dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
@ -127,7 +127,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp
value.replaceAllUsesWith(dequantize);
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
return this->matchSuccess();
return success();
}
};

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Operator.h" // TF:llvm-project
#include "mlir/TableGen/Operator.h" // from @llvm-project
using llvm::LessRecord;
using llvm::raw_ostream;

View File

@ -1,3 +1,8 @@
load(
"//third_party/mlir:tblgen.bzl",
"gentbl",
)
package(
default_visibility = [
":friends",
@ -18,6 +23,8 @@ package_group(
cc_library(
name = "hlo_xla_quantization_passes",
srcs = [
"cpu_kernel_fusion.cc",
"generated_cpu_kernel_fusion.inc",
"materialize.cc",
"op_quant_spec.inc",
"propagate.cc",
@ -36,6 +43,8 @@ cc_library(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
@ -52,7 +61,6 @@ cc_library(
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
@ -62,6 +70,24 @@ cc_library(
"//tensorflow/core/platform:status",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
)
gentbl(
name = "cpu_kernel_fusion_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"generated_cpu_kernel_fusion.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "cpu_kernel_fusion.td",
td_srcs = [
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/xla:hlo_ops_td_files",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
],
)

View File

@ -0,0 +1,252 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
#define DEBUG_TYPE "quant-kernel-fusion"
constexpr int kFakeQuantOperandsNum = 5;
constexpr int kFakeQuantPerChannelOperandsNum = 6;
namespace mlir {
namespace xla_hlo {
namespace {
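// Returns the quantized type encoded by a "fake_quant_with_min_max_vars"
// custom call as a TypeAttr, or a null attribute if 'op' is not such a call
// or its min/max/bit_width/narrow_range operands are not constants.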
TypeAttr GetQuantSpec(Operation* op) {
auto fake_quant = llvm::dyn_cast_or_null<CustomCallOp>(op);
if (!fake_quant || fake_quant.getNumOperands() < kFakeQuantOperandsNum ||
fake_quant.getNumOperands() > kFakeQuantPerChannelOperandsNum ||
fake_quant.call_target_name() != "fake_quant_with_min_max_vars")
return {};
DenseFPElementsAttr min, max;
DenseIntElementsAttr bit_width, narrow_range, quant_dim;
if (!matchPattern(fake_quant.getOperand(1), m_Constant(&min)) ||
!matchPattern(fake_quant.getOperand(2), m_Constant(&max)) ||
!matchPattern(fake_quant.getOperand(3), m_Constant(&bit_width)) ||
!matchPattern(fake_quant.getOperand(4), m_Constant(&narrow_range)))
return {};
auto bit_width_val = (*bit_width.attr_value_begin()).cast<IntegerAttr>();
auto narrow_range_val = (*narrow_range.int_value_begin()).getSExtValue();
int quant_dim_val = -1;
if (fake_quant.getNumOperands() == kFakeQuantPerChannelOperandsNum &&
matchPattern(fake_quant.getOperand(kFakeQuantPerChannelOperandsNum - 1),
m_Constant(&quant_dim))) {
quant_dim_val = (*quant_dim.int_value_begin()).getSExtValue();
}
OpBuilder builder(op);
Type input_type =
fake_quant.getOperand(0).getType().cast<ShapedType>().getElementType();
return quant::GetQuantizedTypeAttr(
builder, input_type, min, max, quant_dim_val, bit_width_val,
builder.getBoolAttr(narrow_range_val), /*is_signed=*/true);
}
// Collects the input values that flow into 'ops' from outside.
void CollectInputs(llvm::ArrayRef<Operation*> ops,
llvm::SmallVectorImpl<Value>* inputs,
llvm::SmallVectorImpl<Attribute>* input_specs) {
for (Operation* op : ops) {
for (Value operand : op->getOperands()) {
if (std::find(inputs->begin(), inputs->end(), operand) != inputs->end()) {
continue;
}
if (Operation* def_op = operand.getDefiningOp()) {
if (std::find(ops.begin(), ops.end(), def_op) == ops.end()) {
inputs->push_back(operand);
}
} else { // argument value
inputs->push_back(operand);
}
}
}
for (Value input : *inputs) {
ShapedType input_type = input.getType().cast<ShapedType>();
if (TypeAttr spec = GetQuantSpec(input.getDefiningOp())) {
input_specs->push_back(spec);
} else {
input_specs->push_back(TypeAttr::get(input_type.getElementType()));
}
}
}
// Collects values that are produced by 'ops' and have uses outside of 'ops'.
// TODO(fengliuai): if it is a single user and QDQ, write that to the specs.
void CollectRets(llvm::ArrayRef<Operation*> ops,
llvm::SmallVectorImpl<Value>* rets,
llvm::SmallVectorImpl<Type>* ret_types,
llvm::SmallVectorImpl<Attribute>* ret_specs) {
for (Operation* op : ops) {
for (Value result : op->getResults()) {
for (Operation* user : result.getUsers()) {
// If there is any user outside of 'ops'
if (std::find(ops.begin(), ops.end(), user) == ops.end()) {
ShapedType ret_type = result.getType().cast<ShapedType>();
rets->push_back(result);
ret_types->push_back(ret_type);
if (TypeAttr spec = GetQuantSpec(user)) {
ret_specs->push_back(spec);
} else {
ret_specs->push_back(TypeAttr::get(ret_type.getElementType()));
}
break;
}
}
}
}
}
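// Outlines the defining ops of 'results' into a single quant.region op whose
// logical_kernel attribute is 'kernel'. Quantization specs for the region's
// inputs and outputs are recovered from adjacent fake-quant custom calls when
// present. All outside uses are redirected to the region's results, so the
// returned vector only carries placeholder values for the pattern rewriter.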
llvm::SmallVector<Value, 0> fuseOps(PatternRewriter* rewriter,
const std::initializer_list<Value>& results,
StringRef kernel) {
// Collect all the operations to be fused.
llvm::SmallVector<Operation*, 4> fused;
llvm::SmallVector<Location, 4> locs;
fused.reserve(results.size());
locs.reserve(results.size());
for (auto value : results) {
Operation* op = value.getDefiningOp();
fused.push_back(op);
locs.push_back(op->getLoc());
}
// Collect the inputs coming into 'ops' from outside.
llvm::SmallVector<Value, 4> inputs;
llvm::SmallVector<Attribute, 4> input_specs;
CollectInputs(fused, &inputs, &input_specs);
// Collect the outputs of 'ops' that are used outside.
llvm::SmallVector<Value, 4> rets;
llvm::SmallVector<Type, 4> ret_types;
llvm::SmallVector<Attribute, 4> ret_specs;
CollectRets(fused, &rets, &ret_types, &ret_specs);
// Create the region op with the return.
auto region = rewriter->create<quant::QuantizeRegionOp>(
rewriter->getFusedLoc(locs), ret_types, inputs,
rewriter->getArrayAttr(input_specs), rewriter->getArrayAttr(ret_specs),
kernel);
auto* body = new Block();
region.body().push_back(body);
OpBuilder builder(body);
BlockAndValueMapping mapping;
// Create the block arguments and add them to the block value mapping.
for (Value input : inputs) {
mapping.map(input, body->addArgument(input.getType()));
}
// Clone the fused operations into the region.
for (Operation* op : fused) {
builder.clone(*op, mapping);
}
llvm::SmallVector<Value, 4> new_rets;
new_rets.reserve(rets.size());
for (auto ret : llvm::enumerate(rets)) {
Value new_ret = mapping.lookupOrNull(ret.value());
assert(new_ret && "couldn't find return value.");
new_rets.push_back(new_ret);
ret.value().replaceAllUsesWith(region.getResult(ret.index()));
}
builder.create<quant::ReturnOp>(builder.getUnknownLoc(), new_rets);
LLVM_DEBUG({
assert(succeeded(region.verify()) && "failed to create quant region.");
llvm::dbgs() << "\ncreated region: ";
region.print(llvm::dbgs());
llvm::dbgs() << "\n\n\n";
});
SmallVector<Value, 0> new_values(fused.back()->getNumResults());
return new_values;
}
struct CpuKernelFusionPass : public FunctionPass<CpuKernelFusionPass> {
explicit CpuKernelFusionPass() = default;
CpuKernelFusionPass(const CpuKernelFusionPass&) {}
void runOnFunction() override;
private:
LogicalResult fuseCpuKernels(Operation* op);
};
#include "tensorflow/compiler/mlir/lite/quantization/xla/generated_cpu_kernel_fusion.inc"
LogicalResult CpuKernelFusionPass::fuseCpuKernels(Operation* op) {
MLIRContext* ctx = op->getContext();
OwningRewritePatternList patterns;
populateWithGenerated(ctx, &patterns);
ConversionTarget target(*ctx);
target.addLegalDialect<quant::QuantizationDialect>();
target.addLegalOp<CallOp, ModuleOp, FuncOp, ModuleTerminatorOp,
::mlir::ReturnOp>();
return applyPartialConversion(op, target, patterns);
}
void CpuKernelFusionPass::runOnFunction() {
if (failed(fuseCpuKernels(getFunction()))) signalPassFailure();
}
} // namespace
// Creates an instance of the xla_hlo cpu kernel fusion pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateCpuKernelFusionPass() {
return std::make_unique<CpuKernelFusionPass>();
}
static PassRegistration<CpuKernelFusionPass> pass(
"xla-hlo-cpu-fusion", "Fuse xla hlo ops into cpu kernels");
} // namespace xla_hlo
} // namespace mlir

View File

@ -0,0 +1,24 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
class Fused2Ops<string kernel> : NativeCodeCall<
"fuseOps(&$_builder, {$0, $1}, \"" # kernel # "\")">;
class Fused3Ops<string kernel> : NativeCodeCall<
"fuseOps(&$_builder, {$0, $1, $2}, \"" # kernel # "\")">;
def : Pat<(HLO_AddOp:$add (HLO_MulOp:$mul $_, $_, $_), $_, $_),
(Fused2Ops<"generic.mul_add"> $mul, $add)>;

View File

@ -25,14 +25,14 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -58,15 +58,15 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
explicit RewriteDequantize(int64_t size, MLIRContext *context)
: OpRewritePattern<quant::DequantizeCastOp>(context), size_(size) {}
PatternMatchResult matchAndRewrite(quant::DequantizeCastOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(quant::DequantizeCastOp op,
PatternRewriter &rewriter) const override {
// quant.dcast
// xla_hlo dequantize only takes min/max, so let's recover them from
// the quantization parameters.
Value dcast = op.arg();
auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType());
if (!type || !type.isa<quant::UniformQuantizedType>()) {
return matchFailure();
return failure();
}
auto qtype = type.cast<quant::UniformQuantizedType>();
double scale = qtype.getScale();
@ -77,7 +77,7 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
// quant.qcast
auto qcast =
llvm::dyn_cast_or_null<quant::QuantizeCastOp>(dcast.getDefiningOp());
if (!qcast) return matchFailure();
if (!qcast) return failure();
// constant
DenseFPElementsAttr attr;
@ -88,7 +88,7 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
attr.getNumElements() <= size_ ||
attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) {
op.getResult().replaceAllUsesWith(qcast.arg());
return matchSuccess();
return success();
}
// TODO(fengliuai): implement transpose if it has high dimension.
@ -96,7 +96,7 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
auto quantized_result =
quant::Quantize(attr, qtype).dyn_cast_or_null<DenseIntElementsAttr>();
if (!quantized_result) {
return matchFailure();
return failure();
}
// Pack the uint8 bits into uint32. The shape is changed from
@ -133,7 +133,7 @@ class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
// Convert bf16 output back to f32
rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getResult().getType(),
dequantize);
return matchSuccess();
return success();
}
private:

View File

@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {
namespace xla_hlo {

View File

@ -21,10 +21,10 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

View File

@ -14,26 +14,38 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
namespace mlir {
namespace xla_hlo {
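// Registers the dialects the quantizer needs; the static initializer
// ensures registration runs exactly once, even with concurrent callers.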
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
return true;
}();
(void)init_once;
}
// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
computation->Snapshot());
RegisterDialects();
MLIRContext context;
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
auto status = xla::ConvertHloToMlirHlo(

View File

@ -0,0 +1,46 @@
// RUN: tf-opt -xla-hlo-cpu-fusion %s | FileCheck %s
// CHECK-LABEL: @mul_add_source
func @mul_add_source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
// CHECK: %[[region:.*]] = "quant.region"(%arg0, %arg1, %arg2) ( {
// CHECK: ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
// CHECK: "quant.return"(%[[add]]) : (tensor<4xf32>) -> ()
// CHECK: }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[region]] : tensor<4xf32>
}
// CHECK-LABEL: @mul_add_annotated
func @mul_add_annotated(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) {
%cst = constant dense<0.0> : tensor<f32>
%cst_0 = constant dense<255.0> : tensor<f32>
%cst_1 = constant dense<8> : tensor<i32>
%cst_2 = constant dense<false> : tensor<i1>
%qin = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.1"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
%qw = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.2"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
%0 = "xla_hlo.multiply"(%qin, %qw) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
%1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
%r = "xla_hlo.custom_call"(%1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.3"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
return %r : tensor<2x4xf32>
// CHECK: %[[region:.*]] = "quant.region"
// CHECK: ^bb0(%arg3: tensor<2x4xf32>, %arg4: tensor<2x4xf32>, %arg5: tensor<2x4xf32>): // no predecessors
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
// CHECK: "quant.return"(%[[add]]) : (tensor<2x4xf32>) -> ()
// CHECK: }) {input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32],
// CHECK-SAME: logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]} :
// CHECK-SAME: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[r:.*]] = "xla_hlo.custom_call"(%[[region]]
// CHECK: return %[[r]] : tensor<2x4xf32>
}

View File

@ -6,49 +6,49 @@ func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32,
// CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16>
// CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[cast]] : tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[cast]] : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>
%w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32>
%q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
return %mul: tensor<2x4xf32>
}
// CHECK-LABEL: func @quantize_small
func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<1x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<1x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<1x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<1x4xf32>
%w = constant dense<1.0> : tensor<1x4xf32>
%q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<1x4xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<1x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<1x4xf32>
return %mul: tensor<1x4xf32>
}
// CHECK-LABEL: func @quantize_non_cst
func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %arg0 : tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %arg0 : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>
%q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
return %mul: tensor<2x4xf32>
}
// CHECK-LABEL: func @quantize_non_4x
func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<2x5xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<2x5xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x5xf32>
%w = constant dense<1.0> : tensor<2x5xf32>
%q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x5xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<2x5xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x5xf32>
return %mul: tensor<2x5xf32>
}

View File

@ -5,10 +5,10 @@ func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x2xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[dq]] : tensor<2x2xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[dq]] : tensor<2x2xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x2xf32>
%w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32>
%mul = xla_hlo.mul %arg0, %w : tensor<2x2xf32>
%mul = xla_hlo.multiply %arg0, %w : tensor<2x2xf32>
return %mul: tensor<2x2xf32>
}

View File

@ -17,14 +17,14 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/types.pb.h"

View File

@ -0,0 +1,15 @@
// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s --dump-input-on-failure
// CHECK: error: 'tf.MyCustomOp' op is neither a custom op nor a flex op
// CHECK: error: failed while converting: 'main'
// CHECK: Ops that need custom implementation (enabled via setting the -emit-custom-ops flag):
// CHECK: tf.MyCustomOp {name = "MyCustomOp"}
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
%0 = "tfl.pseudo_const" () {name = "Const", value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
%1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE", name = "mul"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%2 = "tf.MyCustomOp"(%1, %0) {name = "MyCustomOp"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%3 = "tfl.exp"(%2) {name = "exp"} : (tensor<4xf32>) -> tensor<4xf32>
return %3 : tensor<4xf32>
}

View File

@ -1,8 +1,9 @@
// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s
// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s --dump-input-on-failure
// CHECK: error: 'tf.Div' op is neither a custom op nor a flex op
// CHECK: error: failed while converting: 'main'
// CHECK: Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): Div.
// CHECK: Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag):
// CHECK: tf.Div {name = "div"}
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):

View File

@ -15,11 +15,11 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
namespace tensorflow {

View File

@ -20,16 +20,16 @@ limitations under the License.
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

View File

@ -19,13 +19,13 @@ limitations under the License.
#include <unordered_set>
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Parser.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"

View File

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/stream_executor/lib/statusor.h"

View File

@ -25,9 +25,9 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

View File

@ -16,9 +16,9 @@ limitations under the License.
// This transformation pass converts dense tensors to sparse format.
#include "absl/memory/memory.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
//===----------------------------------------------------------------------===//

View File

@ -21,12 +21,12 @@ limitations under the License.
#include <cstdint>
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -74,31 +74,31 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
PatternRewriter& rewriter) const;
public:
PatternMatchResult matchAndRewrite(Conv2dOpTy op,
PatternRewriter& rewriter) const override;
LogicalResult matchAndRewrite(Conv2dOpTy op,
PatternRewriter& rewriter) const override;
};
template <typename Conv2dOpTy>
PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
Conv2dOpTy op, PatternRewriter& rewriter) const {
// Make sure Conv2D has 'VALID' padding.
if (op.template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
return Pattern::matchFailure();
return failure();
}
// Make sure dilations are all ones if set.
const ArrayAttr& dilations =
op.template getAttrOfType<ArrayAttr>("dilations");
if (dilations && !TFIntListIsAllOnes(dilations)) {
return Pattern::matchFailure();
return failure();
}
// Check if the ConvOp is preceded by an `Expand` op and followed by a
// `Squeeze` op.
Operation* prev_op = op.getOperation()->getPrevNode();
if (!prev_op) return Pattern::matchFailure();
if (!prev_op) return failure();
Operation* next_op = op.getOperation()->getNextNode();
if (!next_op) return Pattern::matchFailure();
if (!next_op) return failure();
TF::ExpandDimsOp expand_op;
TF::SqueezeOp squeeze_op;
@ -107,7 +107,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
// Expand and Squeeze ops must come as a pair.
return Pattern::matchFailure();
return failure();
}
expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
@ -119,24 +119,24 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
(*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
.getSExtValue();
} else {
return Pattern::matchFailure();
return failure();
}
// Make sure that the `squeeze_dims` is equal to `expand_axis`.
auto squeeze_dims = squeeze_op.squeeze_dims();
if (squeeze_dims.size() != 1 ||
squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) {
return Pattern::matchFailure();
return failure();
}
// Update previous/next op pointer.
prev_op = prev_op->getPrevNode();
if (!prev_op) return Pattern::matchFailure();
if (!prev_op) return failure();
next_op = next_op->getNextNode();
if (!next_op) return Pattern::matchFailure();
if (!next_op) return failure();
}
// SpaceToBatchND op.
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure();
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return failure();
// TODO(b/149936532): Check `padding` input, currently ignored.
TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
@ -148,7 +148,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
if (llvm::isa<TF::PadOp>(next_op)) {
pad_op = llvm::cast<TF::PadOp>(next_op);
next_op = next_op->getNextNode();
if (!next_op) return Pattern::matchFailure();
if (!next_op) return failure();
}
// BatchToSpaceND + BiasAdd.
@ -160,8 +160,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
// Must be BiasAdd + BatchToSpaceND.
biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
next_op = next_op->getNextNode();
if (!next_op || !llvm::isa<TF::BatchToSpaceNDOp>(next_op))
return Pattern::matchFailure();
if (!next_op || !llvm::isa<TF::BatchToSpaceNDOp>(next_op)) return failure();
bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
} else if (llvm::isa<TF::BatchToSpaceNDOp>(next_op)) {
// BatchToSpaceND + (optional) BiasAdd.
@ -172,12 +171,12 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
final_op_is_bts = false;
}
} else {
return Pattern::matchFailure();
return failure();
}
llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
stb_op.block_shape(), bts_op.block_shape(), rewriter);
if (!dilations_attr.hasValue()) return Pattern::matchFailure();
if (!dilations_attr.hasValue()) return failure();
op.setAttr("dilations", dilations_attr.getValue());
// Padding is set to 'SAME' when `stb_op` has non-zero paddings.
@ -228,7 +227,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
}
stb_op.getResult().dropAllUses();
return Pattern::matchSuccess();
return success();
}
template <typename Conv2dOpTy>
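The change repeated throughout these hunks is MLIR's rewrite-pattern API migration: matchAndRewrite now returns LogicalResult, and the free functions failure()/success() replace Pattern::matchFailure()/Pattern::matchSuccess(). A minimal post-migration skeleton, as a sketch only (the op name "tf.Example" is illustrative, not taken from this diff):

#include "mlir/IR/PatternMatch.h"        // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project

namespace {
// Hypothetical pattern showing the migrated matchAndRewrite shape.
struct ExamplePattern : public mlir::RewritePattern {
  explicit ExamplePattern(mlir::MLIRContext* context)
      : RewritePattern("tf.Example", /*benefit=*/1, context) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::Operation* op, mlir::PatternRewriter& rewriter) const override {
    // Preconditions bail out with failure(), formerly Pattern::matchFailure().
    if (op->getNumResults() != 1) return mlir::failure();
    rewriter.eraseOp(op);
    return mlir::success();  // formerly Pattern::matchSuccess()
  }
};
}  // namespace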

View File

@ -21,26 +21,26 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"

View File

@ -15,23 +15,23 @@ limitations under the License.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {

View File

@ -28,17 +28,17 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -98,12 +98,12 @@ bool HasSameStaticShapes(Operation* op) {
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
#define DECL_CONVERT_OP(tf_op) \
struct ConvertTF##tf_op##Op : public RewritePattern { \
explicit ConvertTF##tf_op##Op(MLIRContext* context) \
: RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \
PatternMatchResult matchAndRewrite( \
Operation* op, PatternRewriter& rewriter) const override; \
#define DECL_CONVERT_OP(tf_op) \
struct ConvertTF##tf_op##Op : public RewritePattern { \
explicit ConvertTF##tf_op##Op(MLIRContext* context) \
: RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {} \
LogicalResult matchAndRewrite(Operation* op, \
PatternRewriter& rewriter) const override; \
}
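For reference, instantiating the migrated macro for one of the ops declared below expands, modulo whitespace, to roughly:

// DECL_CONVERT_OP(MatMul) expands to approximately:
struct ConvertTFMatMulOp : public RewritePattern {
  explicit ConvertTFMatMulOp(MLIRContext* context)
      : RewritePattern(TF::MatMulOp::getOperationName(), 1, context) {}
  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override;
};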
// TODO(antiagainst): Define this pattern in a table-driven manner once variadic
@ -127,14 +127,14 @@ DECL_CONVERT_OP(BroadcastTo);
#undef DECL_CONVERT_OP
PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite(
LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto random_uniform_op = cast<TF::RandomUniformOp>(op);
if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) {
return matchFailure();
return failure();
}
if (!random_uniform_op.dtype().isF32()) {
return matchFailure();
return failure();
}
typedef tensorflow::random::UniformDistribution<
tensorflow::random::PhiloxRandom, float>
@ -149,7 +149,7 @@ PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite(
random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) {
if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) {
if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) {
return matchFailure();
return failure();
}
num_elements = output_type.getNumElements();
size_t offset = 0;
@ -165,13 +165,13 @@ PatternMatchResult ConvertTFRandomUniformOp::matchAndRewrite(
}
auto output_data = DenseFPElementsAttr::get(output_type, data);
rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data);
return matchSuccess();
return success();
}
}
return matchFailure();
return failure();
}
PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
LogicalResult ConvertTFConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatOp>(op);
@ -180,17 +180,17 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
// Extract axis attribute from constant concat_dims tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
return matchFailure();
return failure();
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis),
fused_activation_function);
return matchSuccess();
return success();
}
PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
@ -198,15 +198,14 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
auto output_type = tf_concat_op.output().getType();
// Extract axis attribute from constant axis tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis)))
return matchFailure();
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<ConcatenationOp>(
op, output_type, values, ExtractSingleElementAsInteger(axis),
fused_activation_function);
return matchSuccess();
return success();
}
// The following is effectively:
@ -215,11 +214,11 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
// ConstBoolAttrTrue:$transpose_b),
// (TFL_FullyConnectedOp:$__0 $a, $b,
// NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>;
PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
LogicalResult ConvertTFMatMulOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_matmul_op = cast<TF::MatMulOp>(op);
if (tf_matmul_op.transpose_a()) return matchFailure();
if (!tf_matmul_op.transpose_b()) return matchFailure();
if (tf_matmul_op.transpose_a()) return failure();
if (!tf_matmul_op.transpose_b()) return failure();
Type output_type = tf_matmul_op.getResult().getType();
// TODO(jpienaar): Follow up post shuffle discussion.
@ -230,10 +229,10 @@ PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
op->getOperand(1), no_input, rewriter.getStringAttr("NONE"),
rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false));
rewriter.replaceOp(op, {fc_op.getResult(0)});
return matchSuccess();
return success();
}
PatternMatchResult ConvertTFPackOp::matchAndRewrite(
LogicalResult ConvertTFPackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_pack_op = cast<TF::PackOp>(op);
@ -245,10 +244,10 @@ PatternMatchResult ConvertTFPackOp::matchAndRewrite(
rewriter.replaceOpWithNewOp<PackOp>(op, output_type, values, values_count,
axis);
return matchSuccess();
return success();
}
PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
LogicalResult ConvertTFReshapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_reshape_op = cast<TF::ReshapeOp>(op);
@ -269,10 +268,10 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
}
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
input, shape);
return matchSuccess();
return success();
}
PatternMatchResult ConvertTFSplitOp::matchAndRewrite(
LogicalResult ConvertTFSplitOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_split_op = cast<TF::SplitOp>(op);
@ -284,10 +283,10 @@ PatternMatchResult ConvertTFSplitOp::matchAndRewrite(
rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, output_types,
tf_split_op.split_dim(),
tf_split_op.value(), num_split);
return matchSuccess();
return success();
}
PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
LogicalResult ConvertTFSplitVOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_splitv_op = cast<TF::SplitVOp>(op);
@ -299,7 +298,7 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
rewriter.replaceOpWithNewOp<TFL::SplitVOp>(
op, output_types, tf_splitv_op.value(), tf_splitv_op.size_splits(),
tf_splitv_op.split_dim(), num_split);
return matchSuccess();
return success();
}
Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
@ -330,7 +329,7 @@ Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
return rewriter.create<ConstantOp>(op->getLoc(), type, attr);
}
PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
auto ranked_input_type =
@ -352,7 +351,7 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
tf_strided_slice_op.new_axis_mask().getSExtValue()),
rewriter.getI32IntegerAttr(
tf_strided_slice_op.shrink_axis_mask().getSExtValue()));
return matchSuccess();
return success();
}
int num_input_dims = ranked_input_type.getRank();
@ -382,10 +381,10 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
tf_strided_slice_op.new_axis_mask().getSExtValue()),
rewriter.getI32IntegerAttr(
tf_strided_slice_op.shrink_axis_mask().getSExtValue()));
return matchSuccess();
return success();
}
PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
LogicalResult ConvertTFUnpackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_unpack_op = cast<TF::UnpackOp>(op);
@ -397,7 +396,7 @@ PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue());
rewriter.replaceOpWithNewOp<UnpackOp>(op, output_types, input, num, axis);
return matchSuccess();
return success();
}
// MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. This attribute
@ -449,25 +448,25 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
return true;
}
PatternMatchResult ConvertTFMatrixDiagV2Op::matchAndRewrite(
LogicalResult ConvertTFMatrixDiagV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV2Op>(op, &rewriter))
return matchSuccess();
return matchFailure();
return success();
return failure();
}
PatternMatchResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
LogicalResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV3Op>(op, &rewriter))
return matchSuccess();
return matchFailure();
return success();
return failure();
}
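Both wrappers above bridge a bool-returning helper to the new API with a two-branch return; assuming this MLIR snapshot provides the success(bool) overload declared in LogicalResult.h, each collapses to a one-liner:

// Equivalent to the two-branch return above, assuming mlir::success(bool):
return success(ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV2Op>(op, &rewriter));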
// TF Lite doesn't support Assert, so we just drop the assert from the graph.
PatternMatchResult ConvertTFAssertOp::matchAndRewrite(
LogicalResult ConvertTFAssertOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
@ -545,7 +544,7 @@ StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
}
PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_reciprocal_op = cast<TF::ReciprocalOp>(op);
@ -553,7 +552,7 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
&rewriter, op->getLoc(),
tf_reciprocal_op.x().getType().cast<ShapedType>(), 1);
if (!status_or_const_op.ok()) {
return matchFailure();
return failure();
}
StringAttr fused_activation_function =
@ -562,10 +561,10 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
rewriter.replaceOpWithNewOp<TFL::DivOp>(op, status_or_const_op.ValueOrDie(),
tf_reciprocal_op.x(),
fused_activation_function);
return matchSuccess();
return success();
}
PatternMatchResult ConvertTFBroadcastToOp::matchAndRewrite(
LogicalResult ConvertTFBroadcastToOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
auto element_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
@ -574,7 +573,7 @@ PatternMatchResult ConvertTFBroadcastToOp::matchAndRewrite(
auto status_or_const_op =
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1);
if (!status_or_const_op.ok()) {
return matchFailure();
return failure();
}
auto tfl_fill_op = rewriter.create<TFL::FillOp>(
@ -587,7 +586,7 @@ PatternMatchResult ConvertTFBroadcastToOp::matchAndRewrite(
rewriter.replaceOpWithNewOp<TFL::MulOp>(
op, output_type, tf_broadcast_to_op.input(), tfl_fill_op,
fused_activation_function);
return matchSuccess();
return success();
}
// Legalize unidirectional sequence lstm.
@ -595,11 +594,11 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
: RewritePattern(kUnidirectionalSequenceLstm, 1, context) {}
PatternMatchResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
auto tflite_indices_attr =
op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
if (!tflite_indices_attr) return matchFailure();
if (!tflite_indices_attr) return failure();
SmallVector<int64_t, 20> tflite_indices;
for (auto index_attr : tflite_indices_attr.getValue()) {
@ -654,7 +653,7 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
// Rewire the output.
op->getResult(2).replaceAllUsesWith(lstm_op.getResult());
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
@ -663,24 +662,24 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context)
: RewritePattern(kUnidirectionalSequenceRnn, 1, context) {}
PatternMatchResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
auto tflite_indices_attr =
op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
if (!tflite_indices_attr) return matchFailure();
if (!tflite_indices_attr) return failure();
if (op->getNumOperands() != 5) {
op->emitError()
<< "We're expecting 5 inputs for UnidirectionalSequenceRNN, only "
<< op->getNumOperands() << " provided";
return matchFailure();
return failure();
}
if (op->getNumResults() != 2) {
op->emitError()
<< "We're expecting 2 inputs for UnidirectionalSequenceRNN, only "
<< op->getNumResults() << " found";
return matchFailure();
return failure();
}
// Populate inputs.
@ -714,7 +713,7 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
op->getResult(1).replaceAllUsesWith(rnn_op.getResult());
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
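These hunks never show how the patterns are driven. As a sketch only, under the greedy-rewrite API of this vintage of MLIR (the function name and wiring are assumptions, not taken from this diff), the legalization patterns would be applied roughly as:

// Assumed driver, for orientation only; API names reflect early-2020 MLIR.
void ApplyLegalizePatterns(mlir::FuncOp func, mlir::MLIRContext* context) {
  mlir::OwningRewritePatternList patterns;
  patterns.insert<ConvertTFMatMulOp, ConvertTFAssertOp,
                  LegalizeUnidirectionalSequenceRnn>(context);
  // Repeatedly applies the patterns until a fixed point is reached.
  mlir::applyPatternsGreedily(func, patterns);
}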

View File

@ -15,12 +15,12 @@ limitations under the License.
// Converts TF While to TFL While with a single call in body and cond.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

View File

@ -19,11 +19,11 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

View File

@ -31,28 +31,28 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -175,33 +175,33 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
TF::ConstOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Verify that the opaque elements attribute contains a tensor of variant
// type and scalar shape. The variant type should hold a TensorList.
auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
if (!opaque_attr) return matchFailure();
if (!opaque_attr) return failure();
tensorflow::Tensor tensor;
if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
return matchFailure();
if (tensor.dtype() != tensorflow::DT_VARIANT) return matchFailure();
return failure();
if (tensor.dtype() != tensorflow::DT_VARIANT) return failure();
if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
return matchFailure();
return failure();
const tensorflow::TensorList *list =
tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
if (!list) return matchFailure();
if (!list) return failure();
// Verify output type is variant and contains exactly one ranked subtype.
auto variant_ty =
getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
if (!variant_ty) return matchFailure();
if (!variant_ty) return failure();
ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
if (subtypes.size() != 1) return matchFailure();
if (subtypes.size() != 1) return failure();
RankedTensorType list_element_ty =
subtypes.front().dyn_cast<RankedTensorType>();
if (!list_element_ty) return matchFailure();
if (!list_element_ty) return failure();
// Extract tensor elements for the TensorList and construct result type
// based on the number of elements and element shape.
@ -225,9 +225,9 @@ struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
tensorflow::Tensor tensor(list->element_dtype,
tensorflow::TensorShape(tf_shape));
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
if (!attr_or.ok()) return matchFailure();
if (!attr_or.ok()) return failure();
rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
return matchSuccess();
return success();
}
// Extract individual tensor list elements and combine them using the tf.Pack
@ -237,14 +237,14 @@ struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
values.reserve(tensors.size());
for (const tensorflow::Tensor &tensor : tensors) {
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
if (!attr_or.ok()) return matchFailure();
if (!attr_or.ok()) return failure();
auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
values.push_back(value);
}
rewriter.replaceOpWithNewOp<TF::PackOp>(
op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
return matchSuccess();
return success();
}
};
@ -264,7 +264,7 @@ struct ConvertTensorListSetItem
// (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
TF::TensorListSetItemOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
@ -311,7 +311,7 @@ struct ConvertTensorListSetItem
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
op, input.getType(), scalar_zero,
ArrayRef<Value>({slice1, expanded_item, slice2}));
return matchSuccess();
return success();
}
};
@ -330,7 +330,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
// Rewrites the original op into `tf.fill`. The result tensor shape is
// [num_element, element_shape]. All the values in the result tensor will be
// initialized to 0.
PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
OpT op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type dtype = op.element_dtype();
@ -342,7 +342,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
"transformation pass");
return ConversionPattern::matchFailure();
return failure();
}
Value element_shape = operands[0];
@ -354,7 +354,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
op.emitError(
"requires element_shape to be 1D tensor during TF Lite "
"transformation pass");
return ConversionPattern::matchFailure();
return failure();
}
}
@ -434,7 +434,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
auto zero = rewriter.create<ConstantOp>(loc, zero_type, zero_attr);
rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
return Pattern::matchSuccess();
return success();
}
};
@ -472,7 +472,7 @@ struct ConvertTensorListPushBack
: public OpConversionPattern<TF::TensorListPushBackOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
TF::TensorListPushBackOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Value input_handle = operands[0];
@ -498,7 +498,7 @@ struct ConvertTensorListPushBack
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
op, result_type, scalar_zero,
ArrayRef<Value>({input_handle, expanded_item}));
return matchSuccess();
return success();
}
};
@ -516,7 +516,7 @@ struct ConvertTensorListResize
: public OpConversionPattern<TF::TensorListResizeOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
TF::TensorListResizeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Value input_handle = operands[0];
@ -582,7 +582,7 @@ struct ConvertTensorListResize
/*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
/*output_shapes=*/rewriter.getStrArrayAttr({"{}"}),
/*is_stateless=*/rewriter.getBoolAttr(true));
return matchSuccess();
return success();
}
private:
@ -660,14 +660,14 @@ struct ConvertTensorListGetItem
: public OpConversionPattern<TF::TensorListGetItemOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
TF::TensorListGetItemOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Value input = operands[0];
Value index = operands[1];
rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
rewriter.getBoolAttr(true));
return matchSuccess();
return success();
}
};
@ -675,7 +675,7 @@ struct ConvertTensorListLength
: public OpConversionPattern<TF::TensorListLengthOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
TF::TensorListLengthOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
@ -687,7 +687,7 @@ struct ConvertTensorListLength
rewriter.replaceOpWithNewOp<TF::GatherOp>(
op, op.getType(), shape, CreateI32SplatConst(loc, &rewriter, {}, 0),
/*validate_indices=*/true_attr);
return matchSuccess();
return success();
}
};
@ -695,7 +695,7 @@ struct ConvertTensorListStack
: public OpConversionPattern<TF::TensorListStackOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
TF::TensorListStackOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
@ -713,7 +713,7 @@ struct ConvertTensorListStack
!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
// If no constant is spotted, just forward the operand.
rewriter.replaceOp(op, {input});
return matchSuccess();
return success();
}
RankedTensorType shape_type =
@ -726,20 +726,20 @@ struct ConvertTensorListStack
RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
rewriter.replaceOpWithNewOp<TF::ReshapeOp>(op, result_type, input,
new_shape);
return matchSuccess();
return success();
}
};
struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
TF::IdentityOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Value input = operands[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
op.getAttrs());
return matchSuccess();
return success();
}
};
@ -804,7 +804,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
LogicalResult matchAndRewrite(
TF::WhileOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<Type, 8> result_types;
@ -828,7 +828,7 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
UpdateFunctionTypes(cloned);
rewriter.replaceOp(op, cloned.getResults());
return matchSuccess();
return success();
}
};
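The OpConversionPatterns in this file run under the dialect-conversion framework rather than the greedy driver. A schematic driver, with assumed names and early-2020 signatures (a sketch, not the pass's actual plumbing):

// Assumed conversion driver for the TensorList lowering patterns above.
void RunTensorListLowering(mlir::FuncOp func, mlir::MLIRContext* context) {
  mlir::ConversionTarget target(*context);
  target.addLegalDialect<mlir::StandardOpsDialect>();
  mlir::OwningRewritePatternList patterns;
  patterns.insert<ConvertIdentity, ConvertWhile>(context);
  if (failed(mlir::applyPartialConversion(func, target, patterns)))
    func.emitError("TensorList lowering failed");
}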

View File

@ -31,14 +31,14 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -211,18 +211,17 @@ DenseElementsAttr GetShape(Value output_val) {
struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
using OpRewritePattern<TFL::AddOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(TFL::AddOp add_op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(TFL::AddOp add_op,
PatternRewriter &rewriter) const override {
// Match Add.
DenseElementsAttr added_value;
Value constant_val = add_op.rhs();
if (!matchPattern(constant_val, m_Constant(&added_value)))
return matchFailure();
if (!matchPattern(constant_val, m_Constant(&added_value))) return failure();
// Match Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
if (!fc_op) return matchFailure();
if (!fc_op) return failure();
// Check if the constant RHS is either 0D (scalar), or a 1D with
// `{num_channels}` shape.
@ -236,17 +235,17 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
if (constant_val_type.getRank() == 0) {
is_scalar_rhs = true;
} else if (constant_val_type.getRank() != 1) {
return matchFailure();
return failure();
}
Value filter = fc_op.filter();
Value bias = fc_op.bias();
ElementsAttr bias_value;
const bool is_none_bias = bias.getType().isa<NoneType>();
if (fc_op.fused_activation_function() != "NONE") return matchFailure();
if (fc_op.fused_activation_function() != "NONE") return failure();
if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
return matchFailure();
return failure();
// Rewrite
Location loc = fc_op.getLoc();
@ -261,7 +260,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
// Filter must be a `2D` tensor with `{num_channels, num_features}`
// shape. The following check rejects unknown rank (-1).
if (filter_type.getRank() != 2) {
return matchFailure();
return failure();
}
int num_channels = filter_type.getShape()[0];
@ -297,7 +296,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
/*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
/*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
return matchSuccess();
return success();
}
};
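The rewrite is justified by a one-line identity: with no fused activation, adding a broadcastable constant $c$ after a fully connected layer folds into the bias. With filter $W$ and bias $b$:

$$(xW^{\top} + b) + c = xW^{\top} + (b + c)$$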
@ -305,13 +304,13 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
using OpRewritePattern<TFL::ReluOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(TFL::ReluOp relu_op,
PatternRewriter &rewriter) const override {
Operation *input = relu_op.getOperand().getDefiningOp();
if (!isa_and_nonnull<FullyConnectedOp>(input)) return matchFailure();
if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
auto fully_connected_op = cast<FullyConnectedOp>(input);
if (fully_connected_op.fused_activation_function() != "NONE")
return matchFailure();
return failure();
auto new_activation_func = rewriter.getStringAttr("RELU");
auto new_weights_format =
@ -323,7 +322,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
fully_connected_op.filter(), fully_connected_op.bias(),
new_activation_func, new_weights_format, new_keep_num_dims);
return matchSuccess();
return success();
}
};
@ -332,25 +331,25 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(TFL::MulOp mul_op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(TFL::MulOp mul_op,
PatternRewriter &rewriter) const override {
// Mul.
DenseElementsAttr cst;
Value constant_val = mul_op.rhs();
if (!matchPattern(constant_val, m_Constant(&cst))) return matchFailure();
if (!matchPattern(constant_val, m_Constant(&cst))) return failure();
// Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
if (!fc_op) return matchFailure();
if (!fc_op) return failure();
Value filter = fc_op.filter();
Value bias = fc_op.bias();
ElementsAttr cst_tmp;
if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&cst_tmp)))
return matchFailure();
if (fc_op.fused_activation_function() != "NONE") return matchFailure();
return failure();
if (fc_op.fused_activation_function() != "NONE") return failure();
// Broadcast the constant operand of Mul if it isn't compatible with the
// filter input. We only support broadcasting the operand along the depth
@ -365,7 +364,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
normalized_shape, cst.getType().getElementType()));
Type new_type = new_cst.getType();
if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
return matchFailure();
return failure();
}
auto new_op =
rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
@ -393,7 +392,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
/*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
/*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
return matchSuccess();
return success();
}
};
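Analogously for Mul: a per-output-channel scale $\gamma$ applied after the fully connected op folds into both the filter and the bias, which is exactly what the broadcasting logic above prepares:

$$(xW^{\top} + b) \odot \gamma = x(\gamma \odot W)^{\top} + (b \odot \gamma)$$

where $\odot$ broadcasts $\gamma$ across the rows of $W$, one scale per output channel.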
@ -425,36 +424,36 @@ template <typename AffineOpType>
struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(TFL::MulOp mul_op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(TFL::MulOp mul_op,
PatternRewriter &rewriter) const override {
// Mul. A 1-D rhs is required for batch normalization.
DenseElementsAttr gamma_cst;
Value gamma = mul_op.rhs();
if (!matchPattern(gamma, m_Constant(&gamma_cst))) return matchFailure();
if (gamma_cst.getType().getRank() != 1) return matchFailure();
if (!matchPattern(gamma, m_Constant(&gamma_cst))) return failure();
if (gamma_cst.getType().getRank() != 1) return failure();
// Affine op
Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
if (!fc_op) return matchFailure();
if (!fc_op) return failure();
Value filter = fc_op.filter();
Value bias = fc_op.bias();
// QDQs
auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
if (!dq_op) return matchFailure();
if (!dq_op) return failure();
auto q_op =
dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
if (!q_op) return matchFailure();
if (!q_op) return failure();
filter = q_op.input();
// weight constant
ElementsAttr cst_tmp;
if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&cst_tmp)))
return matchFailure();
if (fc_op.fused_activation_function() != "NONE") return matchFailure();
return failure();
if (fc_op.fused_activation_function() != "NONE") return failure();
// Broadcast the constant operand of Mul if it isn't compatible with the
// filter input. We only support broadcasting the operand along the depth
@ -469,7 +468,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
} else {
return matchFailure();
return failure();
}
// Rewrite filter constant. Since the folder of TFL::MulOp couldn't
@ -478,7 +477,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
// Update the scale in the quantize op.
auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
if (!new_qtype) return matchFailure();
if (!new_qtype) return failure();
rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
new_filter, new_qtype);
@ -491,7 +490,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
// Remove the trailing mul op.
mul_op.replaceAllUsesWith(fc_op.getResult());
return matchSuccess();
return success();
}
};
@ -504,20 +503,19 @@ template <typename AffineOpType>
struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
using OpRewritePattern<AffineOpType>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineOpType fc_op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineOpType fc_op,
PatternRewriter &rewriter) const override {
// Binary op.
Operation *binary_op = fc_op.input().getDefiningOp();
if (!binary_op || binary_op->getNumOperands() != 2)
return this->matchFailure();
if (!binary_op || binary_op->getNumOperands() != 2) return failure();
// We only handle the case where the RHS is a scalar.
// TODO(fengliuai): Currently the canonicalizer pass couldn't guarantee that
// the constant operands are on the RHS; we need to consider the LHS constant
// operand if necessary.
DenseFPElementsAttr cst;
if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst)))
return this->matchFailure();
if (cst.getNumElements() != 1) return this->matchFailure();
return failure();
if (cst.getNumElements() != 1) return failure();
APFloat cst_value = *cst.float_value_begin();
// Affine op.
@ -527,21 +525,21 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
if (!matchPattern(filter, m_Constant(&filter_cst))) {
// The filter may be quantized; if so, trace through the Q/DQ ops to the
// real constant.
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
if (!dq) return this->matchFailure();
if (!dq) return failure();
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
return this->matchFailure();
return failure();
}
filter = q.input();
}
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&bias_cst)))
return this->matchFailure();
return failure();
ShapedType filter_type = filter_cst.getType();
if (llvm::isa<AddOp>(binary_op) || llvm::isa<SubOp>(binary_op)) {
auto padding = fc_op.template getAttrOfType<StringAttr>("padding");
if (padding && padding.getValue() != "VALID") return this->matchFailure();
if (padding && padding.getValue() != "VALID") return failure();
// The fusion of add/sub is actually applying the following
// transformation:
@ -568,7 +566,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
bias_cst.float_value_begin(),
bias_cst.float_value_end());
} else {
return this->matchFailure();
return failure();
}
int64_t flatten_index = 0;
@ -610,9 +608,9 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
fc_op.setOperand(1, new_filter_op);
}
} else {
return this->matchFailure();
return failure();
}
return this->matchSuccess();
return success();
}
private:
@ -638,18 +636,17 @@ struct ConvertTrivialTransposeOpToReshapeOp
: public OpRewritePattern<TFL::TransposeOp> {
using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(TFL::TransposeOp transpose_op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op,
PatternRewriter &rewriter) const override {
auto input_type = transpose_op.x().getType().cast<ShapedType>();
auto output_type = transpose_op.y().getType().cast<ShapedType>();
// It's possible to know if the transformation is safe only if the input
// & output shapes are fully known and the permutation is a constant.
if (!input_type.hasStaticShape() || !output_type.hasStaticShape())
return matchFailure();
return failure();
Value perm = transpose_op.perm();
DenseElementsAttr perm_values_attr;
if (!matchPattern(perm, m_Constant(&perm_values_attr)))
return matchFailure();
if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure();
auto input_shape = input_type.getShape();
SmallVector<int64_t, 8> perm_values;
@ -674,7 +671,7 @@ struct ConvertTrivialTransposeOpToReshapeOp
}
}
if (old_major_index_ordering != new_major_index_ordering) {
return matchFailure();
return failure();
}
// Rewrite.
@ -693,7 +690,7 @@ struct ConvertTrivialTransposeOpToReshapeOp
rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(
transpose_op, transpose_op.y().getType(), transpose_op.x(), new_shape);
return matchSuccess();
return success();
}
};
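The safety check compares the relative order of non-unit dimensions before and after the permutation. Distilled into a free function that mirrors (rather than copies) the logic above, with assumed names:

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// A transpose is a pure reshape iff the non-unit dimensions keep their
// relative order. E.g. shape [1, 128, 12, 1] with perm [3, 1, 2, 0]
// yields [1, 128, 12, 1]: the non-unit dims (indices 1, 2) stay in order.
bool IsTrivialTranspose(llvm::ArrayRef<int64_t> shape,
                        llvm::ArrayRef<int64_t> perm) {
  llvm::SmallVector<int64_t, 8> old_order, new_order;
  for (int64_t i = 0, e = shape.size(); i != e; ++i)
    if (shape[i] != 1) old_order.push_back(i);
  for (int64_t p : perm)
    if (shape[p] != 1) new_order.push_back(p);
  return old_order == new_order;
}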

View File

@ -17,15 +17,15 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
@ -75,14 +75,14 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
explicit FoldIfOp(MLIRContext* context, FuncSet* inlined_funcs)
: OpRewritePattern<TF::IfOp>(context), inlined_funcs_(inlined_funcs) {}
PatternMatchResult matchAndRewrite(TF::IfOp op,
PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(TF::IfOp op,
PatternRewriter& rewriter) const override {
// This pattern is restricted to if ops in functions with exactly one block
// and therefore one terminator op, so that the function return type can be
// updated if operands' shapes change after inlining. Without this
// restriction, it would require tensor cast ops.
FuncOp parent_op = op.getParentOfType<FuncOp>();
if (parent_op.getBlocks().size() != 1) return matchFailure();
if (parent_op.getBlocks().size() != 1) return failure();
// Find the then and else branch functions.
SymbolTable table(op.getParentOfType<ModuleOp>());
@ -98,18 +98,18 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
inlined_funcs_->insert(then_branch);
inlined_funcs_->insert(else_branch);
rewriter.eraseOp(op.getOperation());
return matchSuccess();
return success();
}
// Extract the constant cond value.
DenseElementsAttr cond;
if (!matchPattern(op.cond(), m_Constant(&cond))) return matchFailure();
if (!matchPattern(op.cond(), m_Constant(&cond))) return failure();
// TODO(hinsu): Handle constants that are not scalar booleans.
auto cond_type = cond.getType().dyn_cast<RankedTensorType>();
if (!cond_type || !cond_type.getShape().equals({}) ||
!cond_type.getElementType().isInteger(/*width=*/1))
return matchFailure();
return failure();
// Identify the branch to inline.
bool cond_value = (*cond.int_value_begin()).getSExtValue();
@ -118,7 +118,7 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// Make sure that the function has exactly one block to simplify inlining.
// TFLite doesn't use control flow with blocks, so functions with more than
// one block are not encountered in practice.
if (func.getBody().getBlocks().size() != 1) return matchFailure();
if (func.getBody().getBlocks().size() != 1) return failure();
BlockAndValueMapping mapper;
for (int i = 0, e = func.getNumArguments(); i != e; ++i)
@ -149,7 +149,7 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// of the function.
inlined_funcs_->insert(then_branch);
inlined_funcs_->insert(else_branch);
return matchSuccess();
return success();
}
private:

View File

@ -16,8 +16,8 @@ limitations under the License.
// This transformation pass applies some clean up steps after quantization.
#include "llvm/Support/Casting.h"
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

View File

@ -23,21 +23,21 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Identifier.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"

View File

@ -22,11 +22,11 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"

View File

@ -38,17 +38,17 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
@ -121,12 +121,12 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
MLIRContext *ctx)
: OpRewritePattern<TFFakeQuantOp>(ctx) {}
PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
PatternRewriter &rewriter) const override {
// We don't want to insert quantize/dequantize if a quantize op already exists.
auto res = tf_op.outputs();
if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin()))
return this->matchFailure();
return failure();
// Extract the min/max constant values from the operands. We also consider
// a special case where there are tf.Identity ops between the min/max
@ -137,8 +137,8 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
min = id1.input();
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
max = id2.input();
if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
if (!matchPattern(min, m_Constant(&min_value))) return failure();
if (!matchPattern(max, m_Constant(&max_value))) return failure();
int quant_dim = -1;
if (PerAxis) {
@ -155,7 +155,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
TypeAttr qtype = quant::GetQuantizedTypeAttr(
rewriter, res_type, min_value, max_value, quant_dim, num_bits,
narrow_range, /*is_signed=*/false);
if (!qtype) this->matchFailure();
if (!qtype) return failure();
// Finally, use the quantization parameter to create the quantize and
// dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
@ -168,7 +168,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
value.replaceAllUsesWith(dequantize);
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
return this->matchSuccess();
return success();
}
};
@ -208,8 +208,8 @@ struct ConvertTFConvOp : public RewritePattern {
: RewritePattern(TFConvOpType::getOperationName(), 1, context),
intAttrOne(Builder(context).getI32IntegerAttr(1)) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Assumes the TensorFlow convolution op has already been verified to be
// in valid form.
@ -223,10 +223,10 @@ struct ConvertTFConvOp : public RewritePattern {
TFConvOpType tf_op = cast<TFConvOpType>(op);
if (!TFTypeIsFloatTensor(tf_op.input()) || !TFDataFormatIsNHWC(op))
return matchFailure();
return failure();
IntegerAttr height, width;
if (!TFIntListIs1XY1(op, "strides", &height, &width)) return matchFailure();
if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure();
ConvertTFConvOpMatchState state;
state.stride_height = height;
@ -242,14 +242,14 @@ struct ConvertTFConvOp : public RewritePattern {
state.dilation_width_factor = intAttrOne;
}
if (!TFPaddingIsSameOrValid(op, &state.padding)) return matchFailure();
if (!TFPaddingIsSameOrValid(op, &state.padding)) return failure();
// Additionally, we require the filter operand to be of 4-D tensor type so
// that we can extract info from the shape (e.g., for constructing the bias
// tensor, for setting the depth_multiplier attribute, etc.).
auto filter = tf_op.filter();
auto filter_type = filter.getType().template dyn_cast<RankedTensorType>();
if (!filter_type || filter_type.getRank() != 4) return matchFailure();
if (!filter_type || filter_type.getRank() != 4) return failure();
// The TensorFlow convolution op has only two inputs, while the TFLite one
// has three, with the bias vector marked as optional. However, TOCO has a
@ -274,7 +274,7 @@ struct ConvertTFConvOp : public RewritePattern {
bias);
rewriter.replaceOp(op, conv_op.getResult());
return matchSuccess();
return success();
}
const IntegerAttr intAttrOne;
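The conversion gates on several attribute checks before rewriting: NHWC data format, strides of the form [1, H, W, 1], SAME/VALID padding, and a 4-D filter. A hedged sketch in the spirit of the strides check, assuming the MLIR attribute API at this commit; this is not the actual TFIntListIs1XY1 implementation.

// Illustrative check: an array attribute must be [1, H, W, 1]; H and W
// are returned as IntegerAttrs, matching how the pattern consumes them.
static bool IntListIs1XY1(mlir::Operation *op, llvm::StringRef name,
                          mlir::IntegerAttr *h, mlir::IntegerAttr *w) {
  auto attr = op->getAttrOfType<mlir::ArrayAttr>(name);
  if (!attr || attr.size() != 4) return false;
  auto as_int = [](mlir::Attribute a) {
    return a.cast<mlir::IntegerAttr>().getInt();
  };
  auto vals = attr.getValue();
  if (as_int(vals[0]) != 1 || as_int(vals[3]) != 1) return false;
  *h = vals[1].cast<mlir::IntegerAttr>();
  *w = vals[2].cast<mlir::IntegerAttr>();
  return true;
}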
@ -418,8 +418,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
explicit ConvertTFStridedSlice(MLIRContext *context)
: RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
PatternMatchResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
PatternRewriter &rewriter) const {
LogicalResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
PatternRewriter &rewriter) const {
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
// Insert a new reshape op.
@ -474,11 +474,11 @@ struct ConvertTFStridedSlice : public RewritePattern {
rewriter.getI64IntegerAttr(0),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.shrink_axis_mask()));
return matchSuccess();
return success();
}
PatternMatchResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
PatternRewriter &rewriter) const {
LogicalResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
PatternRewriter &rewriter) const {
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
DenseIntElementsAttr begin_dense_elem_attr;
@ -486,7 +486,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
if (!begin_ranked_attr_type ||
!matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
return matchFailure();
return failure();
}
DenseIntElementsAttr end_dense_elem_attr;
@ -494,7 +494,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
if (!end_ranked_attr_type ||
!matchPattern(end, m_Constant(&end_dense_elem_attr))) {
return matchFailure();
return failure();
}
DenseIntElementsAttr stride_dense_elem_attr;
@ -503,7 +503,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
stride.getType().dyn_cast<RankedTensorType>();
if (!stride_ranked_attr_type ||
!matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
return matchFailure();
return failure();
}
Value input = strided_slice_op.input();
@ -516,7 +516,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
const ArrayRef<int64_t> begin_shape = begin_type.getShape();
const int begin_dim = begin_shape.size();
if (begin_dim != 1) return matchFailure();
if (begin_dim != 1) return failure();
const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
@ -586,11 +586,11 @@ struct ConvertTFStridedSlice : public RewritePattern {
strided_slice_op.new_axis_mask()),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.shrink_axis_mask()));
return matchSuccess();
return success();
}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// TODO(renjieliu): Consider expanding the transformation for the shrink
// mask as well.
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
@ -606,7 +606,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
if (ellipsis_mask != 0) {
return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter);
}
return matchFailure();
return failure();
}
};
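To make the ellipsis-mask arithmetic concrete, here is a hedged worked example of the ellipsis_filled_dim_size computation above; the shapes are illustrative.

// input: tensor<4x8x16x32xf32>        -> input_size (input rank) = 4
// begin: [0, 0]                       -> begin_shape[0] = 2
// One of the two begin entries is the ellipsis placeholder, so:
//   ellipsis_filled_dim_size = input_size - begin_shape[0] + 1
//                            = 4 - 2 + 1 = 3
// i.e. the ellipsis expands to three dimensions, and the rewritten
// strided slice carries (2 - 1) explicit + 3 expanded = 4 dims, matching
// the input rank.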

View File

@ -19,17 +19,17 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/Functional.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {

Some files were not shown because too many files have changed in this diff.