Merge branch 'master' into master
This commit is contained in:
commit
fc8a94750e
11
.bazelrc
11
.bazelrc
@ -238,6 +238,10 @@ build:linux --copt=-w
|
||||
build:macos --copt=-w
|
||||
build:windows --copt=/w
|
||||
|
||||
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
|
||||
# _USE_MATH_DEFINES is defined.
|
||||
build:windows --copt=/D_USE_MATH_DEFINES
|
||||
|
||||
# Default paths for TF_SYSTEM_LIBS
|
||||
build:linux --define=PREFIX=/usr
|
||||
build:linux --define=LIBDIR=$(PREFIX)/lib
|
||||
@ -258,9 +262,8 @@ build:windows --host_cxxopt=/std:c++14
|
||||
# On windows, we still link everything into a single DLL.
|
||||
build:windows --config=monolithic
|
||||
|
||||
# On linux and macos, we dynamically link small amount of kernels
|
||||
# On linux, we dynamically link small amount of kernels
|
||||
build:linux --config=dynamic_kernels
|
||||
build:macos --config=dynamic_kernels
|
||||
|
||||
# Make sure to include as little of windows.h as possible
|
||||
build:windows --copt=-DWIN32_LEAN_AND_MEAN
|
||||
@ -378,9 +381,9 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
|
||||
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
|
||||
|
||||
build:rbe_win --config=rbe
|
||||
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:toolchain"
|
||||
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
|
||||
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
||||
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:cc-toolchain-x64_windows"
|
||||
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:cc-toolchain-x64_windows"
|
||||
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
|
||||
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
|
||||
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
|
||||
|
@ -72,7 +72,7 @@ TensorFlow coding style.
|
||||
[tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core)
|
||||
and
|
||||
[tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python).
|
||||
TensorFlow has reached version 1 and hence cannot make
|
||||
TensorFlow has passed version 1.0 and hence cannot make
|
||||
non-backward-compatible API changes without a major release. Reviewers of
|
||||
your pull request will comment on any API compatibility issues.
|
||||
* When you contribute a new feature to TensorFlow, the maintenance burden is
|
||||
|
26
README.md
26
README.md
@ -37,18 +37,18 @@ See the [TensorFlow install guide](https://www.tensorflow.org/install) for the
|
||||
[Docker container](https://www.tensorflow.org/install/docker), and
|
||||
[build from source](https://www.tensorflow.org/install/source).
|
||||
|
||||
To install the current release for CPU-only:
|
||||
To install the current release, which includes support for
|
||||
[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu) *(Ubuntu and
|
||||
Windows)*:
|
||||
|
||||
```
|
||||
$ pip install tensorflow
|
||||
```
|
||||
|
||||
Use the GPU package for
|
||||
[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu) *(Ubuntu and
|
||||
Windows)*:
|
||||
A smaller CPU-only package is also available:
|
||||
|
||||
```
|
||||
$ pip install tensorflow-gpu
|
||||
$ pip install tensorflow-cpu
|
||||
```
|
||||
|
||||
To update TensorFlow to the latest version, add `--upgrade` flag to the above
|
||||
@ -56,7 +56,7 @@ commands.
|
||||
|
||||
*Nightly binaries are available for testing using the
|
||||
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
|
||||
[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) packages on PyPi.*
|
||||
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPI.*
|
||||
|
||||
#### *Try your first TensorFlow program*
|
||||
|
||||
@ -150,17 +150,3 @@ Learn more about the
|
||||
## License
|
||||
|
||||
[Apache License 2.0](LICENSE)
|
||||
|
||||
## Feature Prioritization Survey
|
||||
|
||||
The TensorFlow team is working on building/improving features, and understands
|
||||
that it is very important to prioritize these efforts based on what TF users
|
||||
need.
|
||||
|
||||
The goal of this short, < 5min
|
||||
[survey](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad), is to help
|
||||
the TensorFlow team better understand what features to prioritize based on your
|
||||
feedback. Participation is of course optional.
|
||||
|
||||
Take the survey
|
||||
[HERE](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad).
|
||||
|
23
configure.py
23
configure.py
@ -147,14 +147,16 @@ def write_action_env_to_bazelrc(var_name, var):
|
||||
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
|
||||
|
||||
|
||||
def run_shell(cmd, allow_non_zero=False):
|
||||
def run_shell(cmd, allow_non_zero=False, stderr=None):
|
||||
if stderr is None:
|
||||
stderr = sys.stdout
|
||||
if allow_non_zero:
|
||||
try:
|
||||
output = subprocess.check_output(cmd)
|
||||
output = subprocess.check_output(cmd, stderr=stderr)
|
||||
except subprocess.CalledProcessError as e:
|
||||
output = e.output
|
||||
else:
|
||||
output = subprocess.check_output(cmd)
|
||||
output = subprocess.check_output(cmd, stderr=stderr)
|
||||
return output.decode('UTF-8').strip()
|
||||
|
||||
|
||||
@ -169,10 +171,12 @@ def get_python_path(environ_cp, python_bin_path):
|
||||
if environ_cp.get('PYTHONPATH'):
|
||||
python_paths = environ_cp.get('PYTHONPATH').split(':')
|
||||
try:
|
||||
stderr = open(os.devnull, 'wb')
|
||||
library_paths = run_shell([
|
||||
python_bin_path, '-c',
|
||||
'import site; print("\\n".join(site.getsitepackages()))'
|
||||
]).split('\n')
|
||||
],
|
||||
stderr=stderr).split('\n')
|
||||
except subprocess.CalledProcessError:
|
||||
library_paths = [
|
||||
run_shell([
|
||||
@ -1179,10 +1183,17 @@ def system_specific_test_config(env):
|
||||
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
|
||||
else:
|
||||
test_and_build_filters.append('-gpu')
|
||||
write_to_bazelrc('test --test_tag_filters=%s' %
|
||||
|
||||
# Disable tests with "v1only" tag in "v2" Bazel config, but not in "v1" config
|
||||
write_to_bazelrc('test:v1 --test_tag_filters=%s' %
|
||||
','.join(test_and_build_filters + test_only_filters))
|
||||
write_to_bazelrc('test --build_tag_filters=%s' %
|
||||
write_to_bazelrc('test:v1 --build_tag_filters=%s' %
|
||||
','.join(test_and_build_filters))
|
||||
write_to_bazelrc(
|
||||
'test:v2 --test_tag_filters=%s' %
|
||||
','.join(test_and_build_filters + test_only_filters + ['-v1only']))
|
||||
write_to_bazelrc('test:v2 --build_tag_filters=%s' %
|
||||
','.join(test_and_build_filters + ['-v1only']))
|
||||
|
||||
|
||||
def set_system_libs_flag(environ_cp):
|
||||
|
@ -860,7 +860,7 @@ gen_api_init_files(
|
||||
output_files = TENSORFLOW_API_INIT_FILES_V1,
|
||||
output_package = "tensorflow._api.v1",
|
||||
root_file_name = "v1.py",
|
||||
root_init_template = "api_template_v1.__init__.py",
|
||||
root_init_template = "$(location api_template_v1.__init__.py)",
|
||||
)
|
||||
|
||||
gen_api_init_files(
|
||||
@ -883,7 +883,7 @@ gen_api_init_files(
|
||||
output_files = TENSORFLOW_API_INIT_FILES_V2,
|
||||
output_package = "tensorflow._api.v2",
|
||||
root_file_name = "v2.py",
|
||||
root_init_template = "api_template.__init__.py",
|
||||
root_init_template = "$(location api_template.__init__.py)",
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -89,6 +89,7 @@ except ImportError:
|
||||
# Enable TF2 behaviors
|
||||
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
|
||||
_compat.enable_v2_behavior()
|
||||
_major_api_version = 2
|
||||
|
||||
|
||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||
@ -119,8 +120,14 @@ def _running_from_pip_package():
|
||||
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
||||
|
||||
if _running_from_pip_package():
|
||||
for _s in _site_packages_dirs:
|
||||
# TODO(gunan): Add sanity checks to loaded modules here.
|
||||
for _s in _site_packages_dirs:
|
||||
# Load first party dynamic kernels.
|
||||
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
|
||||
if _fi.file_exists(_main_dir):
|
||||
_ll.load_library(_main_dir)
|
||||
|
||||
# Load third party dynamic kernels.
|
||||
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||
if _fi.file_exists(_plugin_dir):
|
||||
_ll.load_library(_plugin_dir)
|
||||
|
@ -104,6 +104,8 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-
|
||||
_current_module.app.flags = flags # pylint: disable=undefined-variable
|
||||
setattr(_current_module, "flags", flags)
|
||||
|
||||
_major_api_version = 1
|
||||
|
||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||
# running under pip.
|
||||
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
|
||||
@ -132,8 +134,14 @@ def _running_from_pip_package():
|
||||
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
||||
|
||||
if _running_from_pip_package():
|
||||
for _s in _site_packages_dirs:
|
||||
# TODO(gunan): Add sanity checks to loaded modules here.
|
||||
for _s in _site_packages_dirs:
|
||||
# Load first party dynamic kernels.
|
||||
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
|
||||
if _fi.file_exists(_main_dir):
|
||||
_ll.load_library(_main_dir)
|
||||
|
||||
# Load third party dynamic kernels.
|
||||
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||
if _fi.file_exists(_plugin_dir):
|
||||
_ll.load_library(_plugin_dir)
|
||||
|
@ -53,6 +53,20 @@ filegroup(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"c_api_internal.h",
|
||||
"tf_status_helper.h",
|
||||
"tf_status_internal.h",
|
||||
"tf_tensor_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
hdrs = [
|
||||
|
@ -88,6 +88,18 @@ tf_cuda_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_eager_hdrs",
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_internal",
|
||||
srcs = ["c_api_experimental.h"],
|
||||
|
@ -464,7 +464,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
&new_remote_device_mgr));
|
||||
remote_device_mgr = new_remote_device_mgr.get();
|
||||
} else {
|
||||
ctx->context->ClearCaches();
|
||||
ctx->context->ClearCachesAndDefaultExecutor();
|
||||
// TODO(b/143914772): Potential memory leak if rendezvous has pending
|
||||
// tensors for removed / replaced workers.
|
||||
|
||||
@ -754,7 +754,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||
return list;
|
||||
}
|
||||
|
||||
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); }
|
||||
void TFE_ContextClearCaches(TFE_Context* ctx) {
|
||||
ctx->context->ClearCachesAndThreadExecutors();
|
||||
}
|
||||
|
||||
// Set server_def on the context, possibly updating it.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/host_info.h"
|
||||
|
||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
@ -26,29 +27,22 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto create_or_reset =
|
||||
[&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
|
||||
bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
|
||||
if (op_to_reset) {
|
||||
status->status = op_to_reset->Reset(ctx, name, is_function, types,
|
||||
raw_device_name, inference_ctx);
|
||||
return op_to_reset;
|
||||
} else {
|
||||
TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
|
||||
status->status = new_op->operation.SetDeviceName(raw_device_name);
|
||||
return new_op;
|
||||
}
|
||||
};
|
||||
|
||||
if (op_to_reset && op_to_reset->ctx != ctx) {
|
||||
status->status = tensorflow::errors::Internal(
|
||||
"Cannot reset a TFE_Op from another TFE_Context");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
||||
if (!is_function) {
|
||||
const tensorflow::OpDef* op_def;
|
||||
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return create_or_reset(false, new TFE_OpInferenceContext(op_def));
|
||||
}
|
||||
if (!ctx->context->FindFunctionByName(name)) {
|
||||
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
|
||||
} else if (!ctx->context->FindFunctionByName(name)) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
"'", name,
|
||||
"' is neither a type of a primitive operation nor a name "
|
||||
@ -58,5 +52,15 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
"registered in the binary running in this process.");
|
||||
return nullptr;
|
||||
}
|
||||
return create_or_reset(true, nullptr);
|
||||
|
||||
if (op_to_reset) {
|
||||
status->status = op_to_reset->Reset(
|
||||
name, is_function, types, raw_device_name, std::move(inference_ctx));
|
||||
return op_to_reset;
|
||||
}
|
||||
|
||||
TFE_Op* new_op =
|
||||
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
|
||||
status->status = new_op->operation.SetDeviceName(raw_device_name);
|
||||
return new_op;
|
||||
}
|
||||
|
@ -125,24 +125,26 @@ struct TFE_OpInferenceContext {
|
||||
struct TFE_Op {
|
||||
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
TFE_OpInferenceContext* inference_ctx)
|
||||
: operation(ctx->context, op, is_function, t),
|
||||
inference_ctx(inference_ctx) {}
|
||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
|
||||
: ctx(ctx),
|
||||
operation(ctx->context, op, is_function, t),
|
||||
inference_ctx(std::move(inference_ctx)) {}
|
||||
|
||||
void Clear() {
|
||||
operation.Clear();
|
||||
inference_ctx.reset();
|
||||
}
|
||||
|
||||
tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function,
|
||||
tensorflow::Status Reset(const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
const char* raw_device_name,
|
||||
TFE_OpInferenceContext* infer_ctx) {
|
||||
inference_ctx.reset(infer_ctx);
|
||||
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
|
||||
inference_ctx = std::move(infer_ctx);
|
||||
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
TFE_Context* ctx;
|
||||
tensorflow::EagerOperation operation;
|
||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
||||
};
|
||||
|
@ -233,6 +233,7 @@ cc_library_with_android_deps(
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_experimental",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
@ -127,6 +127,33 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
|
||||
target_node_names, outputs, run_metadata);
|
||||
}
|
||||
|
||||
// Run() overload that additionally forwards caller-supplied
// `threadpool_options` to the underlying session.
//
// `inputs` is flattened into (tensor name, tensor) pairs, `fetch_outputs` into
// tensor names to fetch, and `run_outputs` into the names of the nodes to run.
// Returns early with the first non-OK status carried by a feed value or
// reported by impl()->MaybeExtendGraph(); otherwise returns whatever the
// session's Run() returns.
Status ClientSession::Run(
    const RunOptions& run_options, const FeedType& inputs,
    const std::vector<Output>& fetch_outputs,
    const std::vector<Operation>& run_outputs, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata,
    const thread::ThreadPoolOptions& threadpool_options) const {
  // Feeds: bail out on the first feed whose value failed to build.
  std::vector<std::pair<string, Tensor>> feed_tensors;
  for (const auto& entry : inputs) {
    TF_RETURN_IF_ERROR(entry.second.status);
    feed_tensors.emplace_back(entry.first.name(), entry.second.tensor);
  }

  // Fetches: tensor names whose values are written to `outputs`.
  std::vector<string> fetch_names;
  fetch_names.reserve(fetch_outputs.size());
  for (const auto& fetch : fetch_outputs) {
    fetch_names.push_back(fetch.name());
  }

  // Targets: operations identified by the name of their node.
  std::vector<string> target_names;
  target_names.reserve(run_outputs.size());
  for (const auto& op : run_outputs) {
    target_names.push_back(op.node()->name());
  }

  TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
  return impl()->session_->Run(run_options, feed_tensors, fetch_names,
                               target_names, outputs, run_metadata,
                               threadpool_options);
}
|
||||
|
||||
Status ClientSession::MakeCallable(const CallableOptions& callable_options,
|
||||
CallableHandle* out_handle) {
|
||||
TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
|
||||
|
@ -93,6 +93,14 @@ class ClientSession {
|
||||
const std::vector<Operation>& run_outputs,
|
||||
std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
|
||||
|
||||
/// Same as above. Additionally allows user to provide custom threadpool
|
||||
/// implementation via ThreadPoolOptions.
|
||||
Status Run(const RunOptions& run_options, const FeedType& inputs,
|
||||
const std::vector<Output>& fetch_outputs,
|
||||
const std::vector<Operation>& run_outputs,
|
||||
std::vector<Tensor>* outputs, RunMetadata* run_metadata,
|
||||
const thread::ThreadPoolOptions& threadpool_options) const;
|
||||
|
||||
/// \brief A handle to a subgraph, created with
|
||||
/// `ClientSession::MakeCallable()`.
|
||||
typedef int64 CallableHandle;
|
||||
|
@ -112,7 +112,7 @@ TEST(ClientSessionTest, Extend) {
|
||||
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({31, 42}, {2}));
|
||||
}
|
||||
|
||||
TEST(ClientSessionTest, MultiThreaded) {
|
||||
TEST(ClientSessionTest, MultiThreadedWithDefaultThreadpool) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto a = Add(root, {1, 2}, {3, 4});
|
||||
auto b = Mul(root, {1, 2}, {3, 4});
|
||||
@ -138,6 +138,49 @@ TEST(ClientSessionTest, MultiThreaded) {
|
||||
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
|
||||
}
|
||||
|
||||
// Exercises the ClientSession::Run overload that accepts ThreadPoolOptions:
// two runs are issued concurrently from a helper thread pool, then a third run
// is issued from the test thread with the custom inter-/intra-op pools.
TEST(ClientSessionTest, MultiThreadedWithCustomThreadpool) {
  Scope root = Scope::NewRootScope();
  int num_threads = 3;
  auto a = Add(root, {1, 2}, {3, 4});
  auto b = Mul(root, {1, 2}, {3, 4});
  ClientSession session(root);

  auto inter_op_threadpool =
      absl::make_unique<CustomThreadPoolImpl>(num_threads);
  ASSERT_EQ(inter_op_threadpool->GetNumScheduleCalled(), 0);

  auto intra_op_threadpool =
      absl::make_unique<CustomThreadPoolImpl>(num_threads);
  ASSERT_EQ(intra_op_threadpool->GetNumScheduleCalled(), 0);

  tensorflow::thread::ThreadPoolOptions threadPoolOptions;
  threadPoolOptions.inter_op_threadpool = inter_op_threadpool.get();
  threadPoolOptions.intra_op_threadpool = intra_op_threadpool.get();

  {
    // NOTE(review): the two concurrent runs keep default-constructed options.
    // Whether CustomThreadPoolImpl tolerates being shared by concurrent Run()
    // calls is not visible from this test -- confirm before passing
    // `threadPoolOptions` here as well.
    thread::ThreadPool thread_pool(Env::Default(), "pool", 2);
    thread_pool.Schedule([&session, a]() {
      std::vector<Tensor> outputs;
      TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {a}, {},
                               &outputs, nullptr, thread::ThreadPoolOptions()));
      test::ExpectTensorEqual<int>(outputs[0],
                                   test::AsTensor<int>({4, 6}, {2}));
    });
    thread_pool.Schedule([&session, b]() {
      std::vector<Tensor> outputs;
      TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {b}, {},
                               &outputs, nullptr, thread::ThreadPoolOptions()));
      test::ExpectTensorEqual<int>(outputs[0],
                                   test::AsTensor<int>({3, 8}, {2}));
    });
  }
  auto c = Sub(root, b, a);
  std::vector<Tensor> outputs;
  // Previously `threadPoolOptions` was configured above but never handed to
  // any Run() call, so the custom pools were never used. The helper pool is
  // out of scope here, so this run is issued from the test thread only.
  TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {c}, {},
                           &outputs, nullptr, threadPoolOptions));
  test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
|
||||
|
||||
TEST(ClientSessionTest, CallableWithDefaultThreadPool) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto a = Placeholder(root, DT_INT32);
|
||||
|
@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define _USE_MATH_DEFINES
|
||||
#include <cmath>
|
||||
|
||||
#include "tensorflow/cc/ops/array_ops_internal.h"
|
||||
|
@ -125,11 +125,11 @@ cc_library(
|
||||
deps = [
|
||||
":constants",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"//tensorflow/core/platform:strcat",
|
||||
"//tensorflow/core/util/tensor_bundle",
|
||||
] + if_not_mobile([
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:strcat",
|
||||
"//tensorflow/core/util/tensor_bundle",
|
||||
]),
|
||||
)
|
||||
|
||||
|
@ -75,8 +75,8 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support", # fixdeps: keep
|
||||
"@llvm//:x86_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:support", # fixdeps: keep
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
],
|
||||
)
|
||||
|
||||
@ -104,11 +104,11 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:aarch64_code_gen", # fixdeps: keep
|
||||
"@llvm//:arm_code_gen", # fixdeps: keep
|
||||
"@llvm//:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm//:target",
|
||||
"@llvm//:x86_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
],
|
||||
)
|
||||
|
||||
@ -205,9 +205,9 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm//:core",
|
||||
"@llvm//:support",
|
||||
"@llvm//:target",
|
||||
"@llvm-project//llvm:core",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:target",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -407,6 +407,7 @@ def target_llvm_triple():
|
||||
"//tensorflow:android_arm64": "aarch64-none-android",
|
||||
"//tensorflow:android_x86": "i686-none-android",
|
||||
"//tensorflow:ios": "arm64-none-ios",
|
||||
"//tensorflow:ios_x86_64": "x86_64-apple-ios",
|
||||
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
|
||||
"//tensorflow:macos": "x86_64-none-darwin",
|
||||
"//conditions:default": "x86_64-pc-linux",
|
||||
|
@ -500,6 +500,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -22,8 +22,9 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and
|
||||
// executes (using XLA) TF function calls marked with "_XlaCompiledKernel".
|
||||
// Replaces TF function calls marked with `_XlaCompiledKernel` with _XlaCompile
|
||||
// and _XlaRun nodes (which compile and launch, respectively, the corresponding
|
||||
// HLO module).
|
||||
class BuildXlaOpsPass : public GraphOptimizationPass {
|
||||
public:
|
||||
// If enable_lazy_compilation is not nullopt then *enable_lazy_compilation
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const kXlaMustCompileAttr = "_XlaMustCompile";
|
||||
|
||||
const char* const kXlaCompileAttr = "_XlaCompile";
|
||||
|
||||
// User-provided through jit_scope APIs. Effective only when auto_jit is OFF.
|
||||
|
@ -22,7 +22,16 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
// Name of attribute used to tag operators for compilation with XLA
|
||||
|
||||
// Implies must-compile semantics: either it will be compiled
|
||||
// with XLA, or an error will be thrown.
|
||||
extern const char* const kXlaMustCompileAttr; // "_XlaMustCompile"
|
||||
|
||||
// Implies auto-clustering: tagged nodes will be clustered and compiled with XLA
|
||||
// on a best-effort basis.
|
||||
extern const char* const kXlaCompileAttr; // "_XlaCompile"
|
||||
|
||||
// Implies auto-clustering within the given scope.
|
||||
extern const char* const kXlaScopeAttr; // "_XlaScope"
|
||||
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"
|
||||
|
||||
|
@ -27,6 +27,15 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// EncapsulateSubgraphs pass takes all the nodes with the same cluster ID
|
||||
// (derived from the kXlaClusterAttr (_XlaCluster=ID) attribute), puts them into
|
||||
// a TF function, and replaces the subgraph in the main graph with a call to
|
||||
// that TF function annotated with kXlaCompiledKernelAttr (_XlaCompiledKernel).
|
||||
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
|
||||
public:
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
};
|
||||
|
||||
// A rewriting function to apply to each subgraph during encapsulation.
|
||||
// 'arg_source_tensors' are the tensors corresponding to the arguments in the
|
||||
// original source graph (*not* 'graph').
|
||||
@ -100,11 +109,6 @@ extern const char* const kXlaHasReferenceVarsAttr;
|
||||
// TODO(hpucha): Move the utilities to a more appropriate place.
|
||||
void SortControlInputs(GraphDef* gdef);
|
||||
|
||||
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
|
||||
public:
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
|
||||
|
@ -2130,6 +2130,53 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Splits constant nodes that are shared across the outside-compilation
// boundary.
//
// For each constant node in `g` that carries `outside_compilation_attr_name`
// and has at least one data edge to a node without that attribute, this adds a
// copy of the constant (fresh name, attribute erased), mirrors the original's
// incoming control edges onto the copy, and re-points those data edges so they
// originate from output 0 of the copy. Constants whose data consumers all
// carry the attribute are left untouched.
//
// Returns the error from Graph::AddNode if adding a copy fails, otherwise OK.
Status CopyOutsideCompilationConstNodes(
    Graph* g, const string& outside_compilation_attr_name) {
  for (Node* node : g->op_nodes()) {
    const bool is_tagged_const =
        node->IsConstant() &&
        HasNodeAttr(node->def(), outside_compilation_attr_name);
    if (!is_tagged_const) {
      continue;
    }

    // Gather the data edges that leave the tagged region before touching the
    // graph; edges are removed further down, so the node's live edge set must
    // not be iterated at that point.
    std::vector<const Edge*> edges_to_move;
    for (const Edge* edge : node->out_edges()) {
      if (edge->IsControlEdge()) {
        continue;
      }
      if (HasNodeAttr(edge->dst()->def(), outside_compilation_attr_name)) {
        continue;
      }
      edges_to_move.push_back(edge);
    }
    if (edges_to_move.empty()) {
      continue;
    }

    // Clone the constant under a new name, without the attribute.
    NodeDef clone_def = node->def();
    clone_def.set_name(g->NewName(node->name()));
    clone_def.mutable_attr()->erase(outside_compilation_attr_name);
    Status add_status;
    Node* clone = g->AddNode(clone_def, &add_status);
    TF_RETURN_IF_ERROR(add_status);

    // The clone inherits the original's incoming control dependencies.
    for (const Edge* in_edge : node->in_edges()) {
      if (in_edge->IsControlEdge()) {
        g->AddControlEdge(in_edge->src(), clone);
      }
    }

    // Re-point each collected data edge at the clone. The destination and
    // input slot are read before RemoveEdge() is called on the edge.
    for (const Edge* edge : edges_to_move) {
      Node* consumer = edge->dst();
      int consumer_input = edge->dst_input();
      g->RemoveEdge(edge);
      g->AddEdge(clone, 0, consumer, consumer_input);
    }
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status RewriteOutsideCompilationSubgraphFn::operator()(
|
||||
@ -2279,6 +2326,10 @@ Status ExtractOutsideCompilationForFunction(
|
||||
std::vector<string> outside_compilation_host_graphs;
|
||||
std::vector<string> shape_inference_graphs_to_rewrite;
|
||||
if (*has_outside_compilation) {
|
||||
// Copy outside compilation Const nodes with non outside compilation users.
|
||||
TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
|
||||
fbody->graph, outside_compilation_attr_name));
|
||||
|
||||
// Find dependencies between outside compilation clusters.
|
||||
TF_ASSIGN_OR_RETURN(auto cluster_deps,
|
||||
OutsideCompilationClusterDependencies(
|
||||
|
@ -1187,7 +1187,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
}
|
||||
|
||||
if (!whitelist.empty() && !whitelist.contains(node->def().op())) {
|
||||
VLOG(1) << "Rejecting " << node->name()
|
||||
VLOG(1) << "Rejecting TF operation " << node->def().op()
|
||||
<< " as it is not listed in --tf_xla_ops_to_cluster.";
|
||||
continue;
|
||||
}
|
||||
@ -1770,9 +1770,10 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
||||
{"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
|
||||
"Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
|
||||
"Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
|
||||
"Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", "Rsqrt",
|
||||
"Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", "Square",
|
||||
"Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Lgamma", "Digamma",
|
||||
"Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round",
|
||||
"Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
|
||||
"Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv",
|
||||
"Lgamma", "Digamma",
|
||||
// Binary
|
||||
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
|
||||
"MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
|
||||
@ -2035,6 +2036,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"XlaDynamicSlice",
|
||||
"XlaDynamicUpdateSlice",
|
||||
"XlaEinsum",
|
||||
"XlaGather",
|
||||
"XlaIf",
|
||||
"XlaKeyValueSort",
|
||||
"XlaPad",
|
||||
@ -2042,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"XlaReduce",
|
||||
"XlaReduceWindow",
|
||||
"XlaReplicaId",
|
||||
"XlaScatter",
|
||||
"XlaSelectAndScatter",
|
||||
"XlaSelfAdjointEig",
|
||||
"XlaSend",
|
||||
|
@ -34,8 +34,9 @@ extern const char* const kXlaClusterAttr;
|
||||
// compilation by the encapsulate subgraphs pass.
|
||||
extern const char* const kXlaOutsideCompilationAttr;
|
||||
|
||||
// Pass that marks a subset of operators in the graph with attribute
|
||||
// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
|
||||
// Marks a subset of nodes in the graph which are to be clustered
|
||||
// with an attribute _XlaCluster=<cluster id> so they are picked up by the
|
||||
// EncapsulateSubgraphsPass.
|
||||
class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
public:
|
||||
MarkForCompilationPass() = default;
|
||||
|
@ -17,7 +17,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
|
||||
@ -39,7 +42,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
|
||||
return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
|
||||
}
|
||||
|
||||
Status PropagateShapes(const Graph& graph,
|
||||
Status PropagateShapes(Graph* graph,
|
||||
const std::map<int, InferredShape>& arg_shapes,
|
||||
const std::vector<BackEdgeHelper::BackEdge>& back_edges,
|
||||
ShapeRefiner* shape_refiner) {
|
||||
@ -54,7 +57,7 @@ Status PropagateShapes(const Graph& graph,
|
||||
// shapes.
|
||||
// TODO(phawkins): handle cyclic graphs.
|
||||
std::vector<Node*> order;
|
||||
GetReversePostOrder(graph, &order);
|
||||
GetReversePostOrder(*graph, &order);
|
||||
|
||||
for (Node* n : order) {
|
||||
// Ignore the status returned by the shape_refiner. We want the best effort
|
||||
@ -99,6 +102,67 @@ Status PropagateShapes(const Graph& graph,
|
||||
}
|
||||
}
|
||||
|
||||
// Sometimes we have VariableShape nodes in while loop (after Enter nodes).
|
||||
// They won't be constant-folded because TensorFlow constant folding does
|
||||
// not handle Enter nodes (and thus does not handle any nodes after Enter
|
||||
// nodes). We try to replace such VariableShape nodes with Const nodes here.
|
||||
if (n->type_string() == "VariableShape") {
|
||||
shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
|
||||
auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
|
||||
if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
|
||||
shape_inference::ShapeHandle handle =
|
||||
handle_shapes_and_types->at(0).shape;
|
||||
TensorShapeProto shape_proto;
|
||||
context->ShapeHandleToProto(handle, &shape_proto);
|
||||
if (!shape_proto.unknown_rank()) {
|
||||
NodeDef const_def;
|
||||
const_def.set_op("Const");
|
||||
Node* var_node;
|
||||
TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
|
||||
const_def.set_name(
|
||||
graph->NewName(absl::StrCat("var_shape_", var_node->name())));
|
||||
DataType dtype = n->output_type(0);
|
||||
AddNodeAttr("dtype", dtype, &const_def);
|
||||
TensorProto value;
|
||||
value.set_dtype(dtype);
|
||||
value.mutable_tensor_shape()->add_dim()->set_size(
|
||||
shape_proto.dim_size());
|
||||
for (const auto& dim : shape_proto.dim()) {
|
||||
if (dtype == DT_INT32) {
|
||||
value.add_int_val(dim.size());
|
||||
} else {
|
||||
value.add_int64_val(dim.size());
|
||||
}
|
||||
}
|
||||
AddNodeAttr("value", value, &const_def);
|
||||
for (auto const& attr : n->attrs()) {
|
||||
if (*attr.first.begin() == '_') {
|
||||
AddNodeAttr(attr.first, attr.second, &const_def);
|
||||
}
|
||||
}
|
||||
|
||||
Status s;
|
||||
Node* const_node = graph->AddNode(const_def, &s);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
|
||||
graph->AddControlEdge(var_node, const_node);
|
||||
std::vector<const Edge*> out_edges(n->out_edges().begin(),
|
||||
n->out_edges().end());
|
||||
for (const Edge* e : out_edges) {
|
||||
if (e->IsControlEdge()) {
|
||||
graph->AddControlEdge(const_node, e->dst());
|
||||
graph->RemoveEdge(e);
|
||||
} else {
|
||||
Node* dst = e->dst();
|
||||
int dst_input = e->dst_input();
|
||||
graph->RemoveEdge(e);
|
||||
graph->AddEdge(const_node, 0, dst, dst_input);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Merge node causes a loop so we remove NextIteration->Merge edge before
|
||||
// performing shape inference. But removing those edges also prevents us
|
||||
// from inferring output shape for Merge node (we need shapes for all its
|
||||
@ -196,7 +260,7 @@ Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
|
||||
// the shape inference is complete.
|
||||
BackEdgeHelper back_edge;
|
||||
TF_RETURN_IF_ERROR(back_edge.Remove(graph));
|
||||
TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes,
|
||||
TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
|
||||
back_edge.RemovedEdges(), &shape_refiner));
|
||||
TF_RETURN_IF_ERROR(back_edge.Replace());
|
||||
|
||||
|
@ -191,7 +191,7 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
|
||||
data::IteratorGetNextAsOptionalOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \
|
||||
data::IteratorGetNextSyncOp); \
|
||||
data::IteratorGetNextOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("string_handle"), \
|
||||
|
@ -21,7 +21,7 @@ namespace tensorflow {
|
||||
|
||||
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
|
||||
const NodeDef& node_def) const {
|
||||
return CanCreateXlaKernel(flr, node_def);
|
||||
return CanCreateXlaKernel(node_def);
|
||||
}
|
||||
|
||||
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
|
||||
|
@ -95,15 +95,17 @@ AttrValue BoolAttr(bool b) {
|
||||
|
||||
TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
|
||||
FunctionDef fdef = XTimesY();
|
||||
(*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
|
||||
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
|
||||
Init({fdef});
|
||||
XlaKernelCreator xla_kernel_creator;
|
||||
|
||||
Status status = xla_kernel_creator.CreateKernel(
|
||||
flr_, ToNodeDef(R"pb(
|
||||
NodeDef callsite =
|
||||
ToNodeDef(R"pb(
|
||||
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
|
||||
)pb"),
|
||||
&kernel_);
|
||||
)pb");
|
||||
(*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
|
||||
|
||||
// Note: need to set attribute on the created node.
|
||||
Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
|
||||
ASSERT_TRUE(status.ok()) << status.ToString();
|
||||
|
||||
EXPECT_EQ("XTimesY", kernel_->name());
|
||||
@ -137,7 +139,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
|
||||
|
||||
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
|
||||
FunctionDef fdef = XTimesY();
|
||||
(*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
|
||||
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(false);
|
||||
Init({fdef});
|
||||
XlaKernelCreator xla_kernel_creator;
|
||||
|
||||
|
@ -23,7 +23,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
@ -68,40 +70,10 @@ class SinglePassSearch {
|
||||
};
|
||||
} // namespace
|
||||
|
||||
bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
|
||||
const NodeDef& node_def) {
|
||||
const FunctionDef* function_def =
|
||||
flr.GetFunctionLibraryDefinition()->Find(node_def.name());
|
||||
if (function_def == nullptr) {
|
||||
// The node def is not calling a function. Individual ops can be
|
||||
// run directly using on-demand mode, no need to create XlaLaunch
|
||||
// kernel for them.
|
||||
return false;
|
||||
}
|
||||
|
||||
// If kXlaCompileAttr is set on the node_def, use its value.
|
||||
const auto& it = node_def.attr().find(kXlaCompileAttr);
|
||||
if (it != node_def.attr().end()) {
|
||||
return it->second.b();
|
||||
}
|
||||
|
||||
// kXlaCompileAttr is not set on node_def, check if it is set on
|
||||
// FunctionDef.
|
||||
bool xla_compile = false;
|
||||
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
|
||||
node_def, kXlaCompileAttr, &xla_compile);
|
||||
if (!status.ok() || !xla_compile) {
|
||||
if (VLOG_IS_ON(3)) {
|
||||
if (!status.ok()) {
|
||||
VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
|
||||
<< node_def.op() << ". status=" << status.ToString();
|
||||
} else {
|
||||
VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def) {
|
||||
// True only when kXlaMustCompileAttr is present on node_def and set to true.
|
||||
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
|
||||
return it != node_def.attr().end() && it->second.b();
|
||||
}
|
||||
|
||||
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
|
||||
@ -118,8 +90,11 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// If node_def is not instantiable, e.g., the function does not exist,
|
||||
// simply bail out.
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
|
||||
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
|
||||
*fbody = flr->GetFunctionBody(handle);
|
||||
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
|
||||
const DataTypeVector& arg_types = (*fbody)->arg_types;
|
||||
@ -149,7 +124,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
|
||||
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
std::unique_ptr<OpKernel>* kernel) {
|
||||
if (!CanCreateXlaKernel(*flr, node_def)) {
|
||||
if (!CanCreateXlaKernel(node_def)) {
|
||||
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
|
||||
}
|
||||
|
||||
@ -241,9 +216,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
|
||||
// Create the kernel.
|
||||
NameAttrList function;
|
||||
function.set_name(node_def.op());
|
||||
*(function.mutable_attr()) = node_def.attr();
|
||||
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
Device* dev = flr->device();
|
||||
Status s;
|
||||
OpKernelConstruction construction(
|
||||
|
@ -24,11 +24,9 @@ namespace tensorflow {
|
||||
class FunctionLibraryRuntime;
|
||||
class OpKernel;
|
||||
|
||||
// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
|
||||
// true if 'node_def' is a call to a compilable function defined in 'flr',
|
||||
// with the kXlaCompileAttr set.
|
||||
bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
|
||||
const NodeDef& node_def);
|
||||
// Given a NodeDef `node_def`, returns true iff `node_def` has
|
||||
// kXlaMustCompileAttr set to true.
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def);
|
||||
|
||||
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
|
||||
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
|
@ -6,7 +6,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/compiler/tf2xla:__subpackages__",
|
||||
"@local_config_mlir//:friends",
|
||||
"@llvm-project//mlir:friends",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
@ -30,8 +30,8 @@ cc_library(
|
||||
hdrs = ["op_or_arg_name_mapper.h"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
@ -43,11 +43,11 @@ cc_library(
|
||||
":passes",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:MlirOptLib",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//test:TestTransforms",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:MlirOptLib",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir/test:TestTransforms",
|
||||
],
|
||||
)
|
||||
|
||||
@ -80,9 +80,9 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
||||
"//tensorflow/compiler/mlir/xla:xla_lower",
|
||||
"@local_config_mlir//:AffineDialectRegistration",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:QuantOpsDialectRegistration",
|
||||
"@llvm-project//mlir:AffineDialectRegistration",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||
],
|
||||
)
|
||||
|
||||
@ -92,7 +92,7 @@ cc_library(
|
||||
hdrs = ["init_mlir.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm//:support",
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -122,11 +122,11 @@ tf_cc_binary(
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:TranslateClParser",
|
||||
"@local_config_mlir//:Translation",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TranslateClParser",
|
||||
"@llvm-project//mlir:Translation",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
# MLIR dialects and utilities for TensorFlow, TensorFlow Lite and XLA.
|
||||
|
||||
This module contains the MLIR
|
||||
([Multi-Level Intermediate Representation](https://github.com/tensorflow/mlir))
|
||||
([Multi-Level Intermediate Representation](https://mlir.llvm.org))
|
||||
dialects and utilities for
|
||||
|
||||
1. TensorFlow
|
||||
2. XLA
|
||||
3. TF Lite
|
||||
|
||||
See [MLIR repo](https://github.com/tensorflow/mlir) for complete documentation.
|
||||
See [MLIR's website](https://mlir.llvm.org) for complete documentation.
|
||||
|
@ -10,7 +10,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths")
|
||||
|
||||
# Default values used by the test runner.
|
||||
_default_test_file_exts = ["mlir", ".pbtxt", ".td"]
|
||||
_default_driver = "@local_config_mlir//:run_lit.sh"
|
||||
_default_driver = "@llvm-project//mlir:run_lit.sh"
|
||||
_default_size = "small"
|
||||
_default_tags = ["no_rocm"]
|
||||
|
||||
@ -50,16 +50,16 @@ def _run_lit_test(name, data, size, tags, driver, features):
|
||||
|
||||
native.py_test(
|
||||
name = name,
|
||||
srcs = ["@llvm//:lit"],
|
||||
srcs = ["@llvm-project//llvm:lit"],
|
||||
tags = tags,
|
||||
args = [
|
||||
"tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
|
||||
] + features,
|
||||
data = data + [
|
||||
"//tensorflow/compiler/mlir:litfiles",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm//:count",
|
||||
"@llvm//:not",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:count",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
size = size,
|
||||
main = "lit.py",
|
||||
|
@ -1,6 +1,6 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
|
||||
load(
|
||||
"@local_config_mlir//:tblgen.bzl",
|
||||
"//third_party/mlir:tblgen.bzl",
|
||||
"gentbl",
|
||||
)
|
||||
|
||||
@ -8,13 +8,14 @@ package(
|
||||
default_visibility = [
|
||||
# TODO(jpienaar): Make the visibility more restrictive.
|
||||
":friends",
|
||||
"//tensorflow/lite/experimental/tf_runtime:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = ["@local_config_mlir//:subpackages"],
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//learning/brain/experimental/mlir/...",
|
||||
"//learning/brain/google/xla/...",
|
||||
@ -27,7 +28,7 @@ filegroup(
|
||||
srcs = [
|
||||
"ir/tfl_ops.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@local_config_mlir//:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
@ -47,7 +48,7 @@ gentbl(
|
||||
"g3doc/tfl_ops.md",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "ir/tfl_ops.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
@ -62,11 +63,11 @@ gentbl(
|
||||
"transforms/generated_prepare_tf.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/prepare_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_optimize_td_files",
|
||||
],
|
||||
@ -80,11 +81,11 @@ gentbl(
|
||||
"transforms/generated_lower_static_tensor_list.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/tensorlist_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
|
||||
],
|
||||
)
|
||||
@ -97,11 +98,11 @@ gentbl(
|
||||
"transforms/generated_legalize_tf.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/legalize_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
|
||||
],
|
||||
)
|
||||
@ -114,11 +115,12 @@ gentbl(
|
||||
"transforms/generated_optimize.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/optimize_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
|
||||
],
|
||||
)
|
||||
|
||||
@ -130,11 +132,11 @@ gentbl(
|
||||
"transforms/generated_quantize.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/quantize_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
@ -146,11 +148,11 @@ gentbl(
|
||||
"transforms/generated_post_quantize.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/post_quantize_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
@ -163,9 +165,9 @@ cc_library(
|
||||
"utils/validators.h",
|
||||
],
|
||||
deps = [
|
||||
"@local_config_mlir//:Dialect",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
)
|
||||
|
||||
@ -183,21 +185,21 @@ cc_library(
|
||||
"transforms/passes.h",
|
||||
"utils/attribute_utils.h",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
|
||||
"@local_config_mlir//:include/mlir/Transforms/InliningUtils.h",
|
||||
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_lite_ops_inc_gen",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:Dialect",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -214,10 +216,10 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -231,9 +233,9 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
)
|
||||
|
||||
@ -246,10 +248,10 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -290,14 +292,14 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:Transforms",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -315,12 +317,12 @@ cc_library(
|
||||
":tensorflow_lite",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -346,13 +348,13 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -374,7 +376,7 @@ genrule(
|
||||
"utils/generated_op_quant_spec_getters.inc",
|
||||
],
|
||||
cmd = ("$(location //tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen) " +
|
||||
"-I external/local_config_mlir/include " +
|
||||
"-I external/llvm-project/mlir/include " +
|
||||
"-I external/org_tensorflow " +
|
||||
"$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"),
|
||||
tools = ["//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen"],
|
||||
@ -388,7 +390,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"@local_config_mlir//:IR",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -399,9 +401,9 @@ tf_native_cc_binary(
|
||||
"operator_converter_gen.cc",
|
||||
],
|
||||
deps = [
|
||||
"@llvm//:support",
|
||||
"@llvm//:tablegen",
|
||||
"@local_config_mlir//:TableGen",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:tablegen",
|
||||
"@llvm-project//mlir:TableGen",
|
||||
],
|
||||
)
|
||||
|
||||
@ -434,12 +436,17 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:TransformUtils",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
)
|
||||
|
||||
@ -462,7 +469,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite/core/api",
|
||||
"@local_config_mlir//:IR",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
@ -507,14 +514,14 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:QuantOpsDialectRegistration",
|
||||
"@local_config_mlir//:StandardDialectRegistration",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:Translation",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||
"@llvm-project//mlir:StandardDialectRegistration",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Translation",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -523,7 +530,7 @@ tf_cc_binary(
|
||||
name = "flatbuffer_translate",
|
||||
deps = [
|
||||
":flatbuffer_translate_lib",
|
||||
"@local_config_mlir//:MlirTranslateMain",
|
||||
"@llvm-project//mlir:MlirTranslateMain",
|
||||
],
|
||||
)
|
||||
|
||||
@ -536,7 +543,7 @@ cc_library(
|
||||
"tf_tfl_translate_cl.h",
|
||||
],
|
||||
deps = [
|
||||
"@llvm//:support",
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -548,7 +555,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||
"@llvm//:support",
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -576,9 +583,9 @@ tf_cc_binary(
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -589,16 +596,15 @@ tf_cc_binary(
|
||||
":flatbuffer_translate_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform/default/build_config:base",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/delegates/flex:delegate",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Parser",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -621,12 +627,12 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:QuantOpsDialectRegistration",
|
||||
"@local_config_mlir//:Transforms",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
@ -653,15 +659,15 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/lite/tools/optimize:quantize_weights",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Parser",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:QuantOpsDialectRegistration",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:Transforms",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:QuantOpsDialectRegistration",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <cstdarg>
|
||||
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
|
||||
namespace tflite {
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@ -43,24 +44,24 @@ limitations under the License.
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Translation.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Types.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Translation.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
|
||||
@ -103,12 +104,26 @@ using llvm::cl::opt;
|
||||
// Commandline flag to enable the control of flatbuffer import.
|
||||
bool use_external_constant;
|
||||
|
||||
// Commandline flag to enable graph pruning.
|
||||
bool experimental_prune_unreachable_nodes_unconditionally;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static opt<bool, true> use_external_constant_flag(
|
||||
"use-external-constant",
|
||||
llvm::cl::desc("Use external constant during flatbuffer import"),
|
||||
llvm::cl::location(use_external_constant), llvm::cl::init(false));
|
||||
|
||||
// TODO(b/147111261): After the importer supports generic custom ops, we should
|
||||
// change the flag to a more lightweight flag, e.g.
|
||||
// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune
|
||||
// the operations.
|
||||
// NOLINTNEXTLINE
|
||||
static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
|
||||
"experimental-prune-unreachable-nodes-unconditionally",
|
||||
llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
|
||||
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
|
||||
llvm::cl::init(false));
|
||||
|
||||
namespace {
|
||||
bool IsScalar(const TensorT& tensor) {
|
||||
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
|
||||
@ -212,12 +227,12 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
|
||||
// type, thus no stats op is required and nullptr is returned.
|
||||
// If the min max information is invalid, nullptr is returned.
|
||||
mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
|
||||
Value* res) {
|
||||
Value res) {
|
||||
// If the `tensor` has scale/zero_point, it must have been quantized, then the
|
||||
// min/max stats is just for comments, so ignore it.
|
||||
if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
|
||||
// If the result isn't float and unquantizable, the min/max is ignored.
|
||||
if (!res->getType()
|
||||
if (!res.getType()
|
||||
.cast<mlir::ShapedType>()
|
||||
.getElementType()
|
||||
.isa<mlir::FloatType>()) {
|
||||
@ -255,10 +270,23 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
|
||||
}
|
||||
|
||||
StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
|
||||
// TODO(krzysd) Support custom ops
|
||||
// TODO(b/143872630): Support custom ops
|
||||
if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
|
||||
return errors::Unimplemented("unsupported custom operation: ",
|
||||
opcode.custom_code);
|
||||
// Adding some custom op supported on GPU.
|
||||
const absl::string_view custom_name = opcode.custom_code;
|
||||
if (custom_name == "MaxPoolingWithArgmax2D") {
|
||||
return std::string("tfl.max_pooling_with_argmax_2d");
|
||||
}
|
||||
if (custom_name == "Convolution2DTransposeBias") {
|
||||
return std::string("tfl.convolution_2d_transpose_bias");
|
||||
}
|
||||
if (custom_name == "MaxUnpooling2D") {
|
||||
return std::string("tfl.max_unpooling_2d");
|
||||
}
|
||||
// Use an unsupported op name instead of throwing an error here in case the
|
||||
// op is pruned during the import.
|
||||
return std::string(
|
||||
llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str());
|
||||
}
|
||||
if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
|
||||
return std::string("tf.If");
|
||||
@ -495,14 +523,21 @@ bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
|
||||
}
|
||||
}
|
||||
|
||||
// Reports whether `op_name` is one of the three custom op names that are
// treated as custom ops: max pooling with argmax, max unpooling, and transposed
// convolution with bias.
bool IsCustomOp(const std::string& op_name) {
  if (op_name == "tfl.max_pooling_with_argmax_2d") return true;
  if (op_name == "tfl.max_unpooling_2d") return true;
  return op_name == "tfl.convolution_2d_transpose_bias";
}
|
||||
|
||||
// TODO(krzysd) Handle function calls
|
||||
StatusOr<Operation*> ConvertOp(
|
||||
const tflite::OperatorT& op, const std::vector<Value*>& vals_map,
|
||||
Value* optional_arg_marker, const std::vector<std::string>& op_names,
|
||||
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
|
||||
Value optional_arg_marker, const std::vector<std::string>& op_names,
|
||||
const std::vector<std::string>& func_names,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
|
||||
OpBuilder builder) {
|
||||
llvm::SmallVector<Value*, 4> operands;
|
||||
llvm::SmallVector<Value, 4> operands;
|
||||
llvm::SmallVector<mlir::Type, 2> outputTypes;
|
||||
|
||||
if (op.outputs.empty()) {
|
||||
@ -557,7 +592,15 @@ StatusOr<Operation*> ConvertOp(
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
|
||||
if (IsCustomOp(op_name)) {
|
||||
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
|
||||
builder, loc, &attrs);
|
||||
if (!status.ok()) {
|
||||
return emitError(loc, status.ToString()), status;
|
||||
}
|
||||
} else {
|
||||
mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
|
||||
}
|
||||
op_state.addAttributes(attrs);
|
||||
|
||||
// Handle the conversion from subgraph index to functions for If and While
|
||||
@ -619,6 +662,49 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
|
||||
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
|
||||
}
|
||||
|
||||
// Given a list of output indices, traverses the subgraph and returns the set of
|
||||
// ops that are ancestors of the output tensors.
|
||||
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
|
||||
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
|
||||
// Create a map from tensor index to defining op.
|
||||
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
|
||||
for (const auto& op : subgraph.operators) {
|
||||
for (int32_t output : op->outputs) {
|
||||
defining_op[output] = op.get();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const tflite::OperatorT*> queue;
|
||||
for (int32_t output : output_indices) {
|
||||
if (auto& op = defining_op[output]) {
|
||||
queue.push_back(op);
|
||||
} else {
|
||||
return errors::InvalidArgument("Output tensor doesn't have defining op");
|
||||
}
|
||||
}
|
||||
|
||||
// Traverse the graph towards inputs.
|
||||
absl::flat_hash_set<const tflite::OperatorT*> visited;
|
||||
while (!queue.empty()) {
|
||||
const tflite::OperatorT* op = queue.back();
|
||||
queue.pop_back();
|
||||
if (!visited.insert(op).second) {
|
||||
// The node has already been visited.
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int32_t input : op->inputs) {
|
||||
// Input tensor may not have a defining op in case it is a subgraph input
|
||||
// or a constant tensor.
|
||||
if (auto& op = defining_op[input]) {
|
||||
queue.push_back(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return visited;
|
||||
}
|
||||
|
||||
// Build a FuncOp from a tflite SubGraph
|
||||
// The op_names are a mapping from indexes into the TFLite operators array to
|
||||
// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
|
||||
@ -635,7 +721,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
||||
Location base_loc, Builder builder,
|
||||
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
|
||||
bool use_external_constant) {
|
||||
bool use_external_constant,
|
||||
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||
llvm::SmallVector<mlir::Type, 2> ret_types;
|
||||
llvm::SmallVector<mlir::Type, 4> input_types;
|
||||
|
||||
@ -692,19 +779,19 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
auto& body = func.getBody();
|
||||
OpBuilder op_builder{body};
|
||||
|
||||
std::vector<Value*> vals_map(subgraph.tensors.size(), nullptr);
|
||||
Value* maybe_optional_arg_marker = nullptr;
|
||||
std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
|
||||
Value maybe_optional_arg_marker = nullptr;
|
||||
|
||||
// Get or construct MLIR values for each input
|
||||
for (int i = 0, e = subgraph.inputs.size(); i < e; i++) {
|
||||
auto input_tensor = subgraph.inputs[i];
|
||||
const auto& tensor = *subgraph.tensors.at(input_tensor);
|
||||
auto loc = TensorLoc(tensor, builder, base_loc);
|
||||
if (nullptr != vals_map[input_tensor]) {
|
||||
if (vals_map[input_tensor]) {
|
||||
auto err = errors::FailedPrecondition("duplicate input arguments");
|
||||
return emitError(loc, err.ToString()), err;
|
||||
}
|
||||
Value* input_value = func.getArgument(i);
|
||||
Value input_value = func.getArgument(i);
|
||||
|
||||
// If the `tensor` has min/max and doesn't have scale/zero_point
|
||||
// information, a stats op is created to use the input_value, then the
|
||||
@ -731,8 +818,19 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
|
||||
}
|
||||
|
||||
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
|
||||
if (experimental_prune_unreachable_nodes_unconditionally) {
|
||||
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
|
||||
PruneSubgraph(subgraph, func_outputs));
|
||||
}
|
||||
|
||||
// Construct MLIR operators from TFLite operators
|
||||
for (auto& op : subgraph.operators) {
|
||||
if (experimental_prune_unreachable_nodes_unconditionally &&
|
||||
!pruned_subgraph_ops.contains(op)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto input_num : op->inputs) {
|
||||
// The operators in a graph are topologically sorted
|
||||
// and so if no previous operation has produced a tensor
|
||||
@ -745,7 +843,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
builder.getUnitAttr())
|
||||
.getResult();
|
||||
}
|
||||
} else if (nullptr == vals_map.at(input_num)) {
|
||||
} else if (!vals_map.at(input_num)) {
|
||||
auto& const_tensor = *subgraph.tensors[input_num];
|
||||
auto const_loc = TensorLoc(const_tensor, builder, base_loc);
|
||||
auto op_or_err =
|
||||
@ -768,7 +866,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
? base_loc
|
||||
: TensorLoc(*subgraph.tensors[op->outputs[0]], builder, base_loc);
|
||||
// If there's an optional argument, maybe_optional_arg_marker has been set
|
||||
// to a valid Value*
|
||||
// to a valid Value
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto* mlir_op,
|
||||
ConvertOp(*op, vals_map, maybe_optional_arg_marker, op_names,
|
||||
@ -791,9 +889,9 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
|
||||
// Construct return values
|
||||
llvm::SmallVector<Value*, 4> return_operands;
|
||||
llvm::SmallVector<Value, 4> return_operands;
|
||||
for (auto index : func_outputs) {
|
||||
if (nullptr == vals_map.at(index)) {
|
||||
if (!vals_map.at(index)) {
|
||||
auto& const_tensor = *subgraph.tensors[index];
|
||||
auto const_loc = TensorLoc(const_tensor, builder, base_loc);
|
||||
auto op_or_err =
|
||||
@ -837,7 +935,8 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
|
||||
OwningModuleRef tflite::FlatBufferToMlir(
|
||||
absl::string_view buffer, MLIRContext* context, Location base_loc,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool use_external_constant) {
|
||||
bool use_external_constant,
|
||||
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||
auto model_ptr =
|
||||
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
|
||||
if (nullptr == model_ptr) {
|
||||
@ -892,7 +991,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
||||
// TODO(b/131175224,b/132239787) Support multiple entry points
|
||||
builder, ordered_output_arrays,
|
||||
/*is_entry_point=*/e.index() == 0,
|
||||
/*use_external_constant=*/use_external_constant);
|
||||
/*use_external_constant=*/use_external_constant,
|
||||
experimental_prune_unreachable_nodes_unconditionally);
|
||||
if (!func_or_error.ok()) {
|
||||
return emitError(base_loc, "could not translate function ")
|
||||
<< subgraph->name,
|
||||
@ -905,9 +1005,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
||||
return OwningModuleRef(module);
|
||||
}
|
||||
|
||||
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
||||
MLIRContext* context,
|
||||
bool use_external_constant) {
|
||||
static OwningModuleRef FlatBufferFileToMlirTrans(
|
||||
llvm::SourceMgr* source_mgr, MLIRContext* context,
|
||||
bool use_external_constant,
|
||||
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||
const llvm::MemoryBuffer* input =
|
||||
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
|
||||
std::string error;
|
||||
@ -924,12 +1025,14 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
||||
|
||||
return tflite::FlatBufferToMlir(
|
||||
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
||||
context, loc, outputs, use_external_constant);
|
||||
context, loc, outputs, use_external_constant,
|
||||
experimental_prune_unreachable_nodes_unconditionally);
|
||||
}
|
||||
|
||||
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
|
||||
"tflite-flatbuffer-to-mlir",
|
||||
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
|
||||
return FlatBufferFileToMlirTrans(&source_mgr, context,
|
||||
use_external_constant);
|
||||
return FlatBufferFileToMlirTrans(
|
||||
&source_mgr, context, use_external_constant,
|
||||
experimental_prune_unreachable_nodes_unconditionally);
|
||||
});
|
||||
|
@ -17,9 +17,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
|
||||
namespace tflite {
|
||||
// Converts a TFLite flatbuffer stored in `buffer` to a MLIR module
|
||||
@ -31,11 +31,14 @@ namespace tflite {
|
||||
// on failure, and more specific errors will be emitted via the context.
|
||||
// If `use_external_constant` is true, it will create `tfl.external_const`
|
||||
// instead of `tfl.const`.
|
||||
// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes that
|
||||
// are not ancestors of the output nodes will be pruned.
|
||||
mlir::OwningModuleRef FlatBufferToMlir(
|
||||
absl::string_view buffer, mlir::MLIRContext* context,
|
||||
mlir::Location base_loc,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool use_external_constant = false);
|
||||
bool use_external_constant = false,
|
||||
bool experimental_prune_unreachable_nodes_unconditionally = false);
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
|
||||
|
@ -17,15 +17,45 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::Status;
|
||||
using ::tensorflow::errors::InvalidArgument;
|
||||
using ::xla::StatusOr;
|
||||
|
||||
// Translates a TfLitePadding value into a string attribute holding the name of
// the matching tflite::Padding enumerator (as returned by
// tflite::EnumNamePadding).
//
// Only kTfLitePaddingSame and kTfLitePaddingValid are accepted; any other
// value yields an InvalidArgument status that includes the numeric value.
// `loc` is accepted for signature symmetry but is not used here.
StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
                                          mlir::Builder builder,
                                          mlir::Location loc) {
  auto padding = tflite::Padding::Padding_VALID;
  if (pad_params == TfLitePadding::kTfLitePaddingSame) {
    padding = tflite::Padding_SAME;
  } else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
    padding = tflite::Padding_VALID;
  } else {
    // Separate the text from the numeric value so the message stays readable.
    return InvalidArgument(absl::StrCat("Invalid padding type: ",
                                        static_cast<int>(pad_params)));
  }

  return builder.getStringAttr(tflite::EnumNamePadding(padding));
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// TODO(jpienaar): This is a placeholder. This should be done in more efficient
|
||||
// way when part of the translation of module.
|
||||
static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
|
||||
@ -212,5 +242,44 @@ static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
|
||||
return builder.getStringAttr(option_name);
|
||||
}
|
||||
|
||||
// Decodes the raw `custom_options` payload of a supported TFLite custom op
// into MLIR named attributes, which are appended to `attributes`.
//
// Supported `op_name` values:
//   * "tfl.max_pooling_with_argmax_2d" and "tfl.max_unpooling_2d": the payload
//     is read as a TfLitePoolParams and yields `padding`, `stride_h`,
//     `stride_w`, `filter_w` and `filter_h`.
//   * "tfl.convolution_2d_transpose_bias": the payload is read as a
//     TfLiteTransposeConvParams and yields `padding`, `stride_h` and
//     `stride_w`.
// Any other name, or an unrecognized padding value, produces an
// InvalidArgument status.
//
// NOTE(review): the payload is reinterpreted in place without checking that
// `custom_options` is large enough for the params struct; presumably callers
// only pass options serialized from that same struct -- confirm, or validate
// the size before reading.
Status mlir::CustomOptionsToAttributes(
    const std::string& op_name, const std::vector<uint8_t>& custom_options,
    mlir::Builder builder, mlir::Location loc,
    llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
  // Appends a 32-bit integer attribute called `name`.
  auto add_i32_attr = [&](const char* name, int32_t value) {
    attributes->emplace_back(
        builder.getNamedAttr(name, builder.getI32IntegerAttr(value)));
  };

  if (op_name == "tfl.max_pooling_with_argmax_2d" ||
      op_name == "tfl.max_unpooling_2d") {
    auto* pool_params =
        reinterpret_cast<const TfLitePoolParams*>(custom_options.data());
    TF_ASSIGN_OR_RETURN(auto padding_attribute,
                        GetPaddingAttr(pool_params->padding, builder, loc));
    attributes->emplace_back(
        builder.getNamedAttr("padding", padding_attribute));
    add_i32_attr("stride_h", pool_params->stride_height);
    add_i32_attr("stride_w", pool_params->stride_width);
    // Width goes to `filter_w` and height to `filter_h`; these two were
    // previously crossed, which transposed non-square pooling windows.
    add_i32_attr("filter_w", pool_params->filter_width);
    add_i32_attr("filter_h", pool_params->filter_height);
    return Status::OK();
  }

  if (op_name == "tfl.convolution_2d_transpose_bias") {
    auto* conv_params = reinterpret_cast<const TfLiteTransposeConvParams*>(
        custom_options.data());
    TF_ASSIGN_OR_RETURN(auto padding_attribute,
                        GetPaddingAttr(conv_params->padding, builder, loc));
    attributes->emplace_back(
        builder.getNamedAttr("padding", padding_attribute));
    add_i32_attr("stride_h", conv_params->stride_height);
    add_i32_attr("stride_w", conv_params->stride_width);
    return Status::OK();
  }

  return InvalidArgument(absl::StrCat("invalid custom op type: ", op_name));
}
|
||||
|
||||
// Pull in FlatBuffer writers for TFLite generated using TableGen
|
||||
#include "tensorflow/compiler/mlir/lite/operator_converters.inc"
|
||||
|
@ -26,9 +26,10 @@ limitations under the License.
|
||||
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -45,7 +46,7 @@ llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
|
||||
const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
|
||||
flatbuffers::FlatBufferBuilder *fbb);
|
||||
|
||||
// Populate the array of mlir::NamedAttributes corresponding to the given
|
||||
// Populates the array of mlir::NamedAttributes corresponding to the given
|
||||
// tflite::FlatbufferOptionsUnion.
|
||||
// We use an out parameter per LLVM convention
|
||||
void BuiltinOptionsToAttributes(
|
||||
@ -53,6 +54,15 @@ void BuiltinOptionsToAttributes(
|
||||
// NOLINTNEXTLINE
|
||||
llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes);
|
||||
|
||||
// Populates the array of mlir::NamedAttributes corresponding to the given
|
||||
// custom_options.
|
||||
// We use an out parameter per LLVM convention
|
||||
tensorflow::Status CustomOptionsToAttributes(
|
||||
const std::string &op_name, const std::vector<uint8_t> &custom_options,
|
||||
mlir::Builder builder,
|
||||
// NOLINTNEXTLINE
|
||||
Location loc, llvm::SmallVectorImpl<mlir::NamedAttribute> *attributes);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_
|
||||
|
@ -41,21 +41,22 @@ limitations under the License.
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Translation.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Types.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "mlir/Translation.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
|
||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
@ -230,19 +231,19 @@ static bool IsConst(Operation* op) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static bool HasValidTFLiteType(Value* value, T& error_handler) {
|
||||
static bool HasValidTFLiteType(Value value, T& error_handler) {
|
||||
// None type is allowed to represent unspecified operands.
|
||||
if (value->getType().isa<NoneType>()) return true;
|
||||
if (value.getType().isa<NoneType>()) return true;
|
||||
|
||||
auto type = value->getType().dyn_cast<TensorType>();
|
||||
auto type = value.getType().dyn_cast<TensorType>();
|
||||
if (!type) {
|
||||
if (auto op = value->getDefiningOp()) {
|
||||
if (auto op = value.getDefiningOp()) {
|
||||
error_handler.emitError()
|
||||
<< '\'' << op << "' should produce value of tensor type instead of "
|
||||
<< value->getType();
|
||||
<< value.getType();
|
||||
return false;
|
||||
}
|
||||
error_handler.emitError("expected tensor type, got ") << value->getType();
|
||||
error_handler.emitError("expected tensor type, got ") << value.getType();
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -279,9 +280,9 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
|
||||
}
|
||||
auto& bb = fn.getBlocks().front();
|
||||
|
||||
for (auto* arg : bb.getArguments()) {
|
||||
for (auto arg : bb.getArguments()) {
|
||||
if (!HasValidTFLiteType(arg, fn))
|
||||
return fn.emitError("invalid TFLite type: ") << arg->getType(), false;
|
||||
return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
|
||||
}
|
||||
|
||||
// Verify that all operations except the terminator have exactly one
|
||||
@ -289,9 +290,9 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
|
||||
for (auto& inst : bb) {
|
||||
if (inst.isKnownTerminator()) break;
|
||||
|
||||
for (auto* result : inst.getResults()) {
|
||||
for (auto result : inst.getResults()) {
|
||||
if (!HasValidTFLiteType(result, inst))
|
||||
return fn.emitError("invalid TFLite type: ") << result->getType(),
|
||||
return fn.emitError("invalid TFLite type: ") << result.getType(),
|
||||
false;
|
||||
}
|
||||
}
|
||||
@ -361,7 +362,7 @@ class Translator {
|
||||
|
||||
// Builds TFLite tensor from the given value. `buffer_idx` is index of the
|
||||
// corresponding buffer. Emits error and returns llvm::None on failure.
|
||||
Optional<BufferOffset<tflite::Tensor>> BuildTensor(Value* value,
|
||||
Optional<BufferOffset<tflite::Tensor>> BuildTensor(Value value,
|
||||
const std::string& name,
|
||||
unsigned buffer_idx);
|
||||
|
||||
@ -419,7 +420,7 @@ class Translator {
|
||||
bool IsStatefulOperand(mlir::Operation* op, int operand_index);
|
||||
|
||||
// Returns a unique name for `val`.
|
||||
std::string UniqueName(mlir::Value* val);
|
||||
std::string UniqueName(mlir::Value val);
|
||||
|
||||
ModuleOp module_;
|
||||
|
||||
@ -449,7 +450,7 @@ class Translator {
|
||||
std::vector<std::string> failed_custom_ops_;
|
||||
};
|
||||
|
||||
std::string Translator::UniqueName(mlir::Value* val) {
|
||||
std::string Translator::UniqueName(mlir::Value val) {
|
||||
return name_mapper_.GetUniqueName(val);
|
||||
}
|
||||
|
||||
@ -502,8 +503,8 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
|
||||
}
|
||||
|
||||
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
||||
Value* value, const std::string& name, unsigned buffer_idx) {
|
||||
auto type = value->getType().cast<TensorType>();
|
||||
Value value, const std::string& name, unsigned buffer_idx) {
|
||||
auto type = value.getType().cast<TensorType>();
|
||||
|
||||
// TFLite requires tensor shape only for the inputs and constants.
|
||||
// However, we output all known shapes for better round-tripping
|
||||
@ -515,7 +516,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
||||
|
||||
if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
|
||||
return mlir::emitError(
|
||||
value->getLoc(),
|
||||
value.getLoc(),
|
||||
"result shape dimensions out of 32 bit int type range");
|
||||
|
||||
return mlir::success();
|
||||
@ -527,7 +528,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
||||
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||
|
||||
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||
} else if (auto* inst = value->getDefiningOp()) {
|
||||
} else if (auto* inst = value.getDefiningOp()) {
|
||||
if (IsConst(inst)) {
|
||||
// Const op can have a result of dynamic shaped type (e.g. due to constant
|
||||
// folding), but we can still derive the shape of a constant tensor for
|
||||
@ -570,7 +571,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
||||
// marked as a stateful. If so, set the tensor's is_variable as true
|
||||
// This is v1 ref variable semantics in the TFLite runtime.
|
||||
bool is_variable = false;
|
||||
for (auto& use : value->getUses()) {
|
||||
for (auto& use : value.getUses()) {
|
||||
is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
|
||||
if (is_variable) {
|
||||
break;
|
||||
@ -669,6 +670,16 @@ Translator::CreateFlexBuilderWithNodeAttrs(
|
||||
case ::tensorflow::AttrValue::kS:
|
||||
flex_builder->String(key, attr.s());
|
||||
break;
|
||||
case ::tensorflow::AttrValue::kType: {
|
||||
auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type());
|
||||
if (status_or_tfl_type.ok()) {
|
||||
flex_builder->Int(key, status_or_tfl_type.ValueOrDie());
|
||||
} else {
|
||||
emitWarning(loc, "ignoring unsupported tensorflow type: ")
|
||||
<< std::to_string(attr.type());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ::tensorflow::AttrValue::kI:
|
||||
flex_builder->Int(key, attr.i());
|
||||
break;
|
||||
@ -906,13 +917,13 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
bool has_input_attr = false;
|
||||
InitializeNamesFromAttribute(fn, &has_input_attr);
|
||||
std::vector<BufferOffset<tflite::Tensor>> tensors;
|
||||
llvm::DenseMap<Value*, int> tensor_index_map;
|
||||
llvm::DenseMap<Value, int> tensor_index_map;
|
||||
|
||||
// Builds tensor and buffer for argument or operation result. Returns false
|
||||
// on failure.
|
||||
auto build_tensor_and_buffer = [&](Value* value, const std::string& name) {
|
||||
auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
|
||||
// NoneType represents optional and may be skipped here.
|
||||
if (value->getType().isa<NoneType>()) {
|
||||
if (value.getType().isa<NoneType>()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -925,7 +936,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
// make the Buffer empty apart from setting the buffer_idx=0 in the Tensor.
|
||||
// This does not seem to affect runtime behavior for RNN/LSTM, but would be
|
||||
// good for reducing memory footprint.
|
||||
if (auto* inst = value->getDefiningOp()) {
|
||||
if (auto* inst = value.getDefiningOp()) {
|
||||
auto buffer_or = BuildBuffer(inst);
|
||||
if (!buffer_or) return false;
|
||||
buffers_.push_back(*buffer_or);
|
||||
@ -942,7 +953,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
// have associated tensor and buffer. Build FlatBuffer tensor and buffer for
|
||||
// other functions.
|
||||
for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
|
||||
mlir::BlockArgument* arg = bb.getArgument(i);
|
||||
mlir::BlockArgument arg = bb.getArgument(i);
|
||||
std::string name;
|
||||
if (has_input_attr) name = name_mapper_.GetUniqueName(arg);
|
||||
if (name.empty()) name = absl::StrCat("arg", i);
|
||||
@ -964,15 +975,15 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
// Fetch operand and result tensor indices.
|
||||
std::vector<int32_t> operands;
|
||||
operands.reserve(inst.getNumOperands());
|
||||
for (auto* operand : inst.getOperands()) {
|
||||
if (operand->getType().isa<NoneType>())
|
||||
for (auto operand : inst.getOperands()) {
|
||||
if (operand.getType().isa<NoneType>())
|
||||
operands.push_back(kTfLiteOptionalTensor);
|
||||
else
|
||||
operands.push_back(tensor_index_map.lookup(operand));
|
||||
}
|
||||
std::vector<int32_t> results;
|
||||
// Size the result index list by the number of results (not operands) of the
// instruction, since one entry is appended per result below.
results.reserve(inst.getNumResults());
|
||||
for (auto* result : inst.getResults()) {
|
||||
for (auto result : inst.getResults()) {
|
||||
results.push_back(tensor_index_map.lookup(result));
|
||||
}
|
||||
|
||||
@ -986,10 +997,10 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
|
||||
// Get input and output tensor indices for the subgraph.
|
||||
std::vector<int32_t> inputs, outputs;
|
||||
for (auto* arg : bb.getArguments()) {
|
||||
for (auto arg : bb.getArguments()) {
|
||||
inputs.push_back(tensor_index_map[arg]);
|
||||
}
|
||||
for (auto* result : bb.getTerminator()->getOperands()) {
|
||||
for (auto result : bb.getTerminator()->getOperands()) {
|
||||
outputs.push_back(tensor_index_map[result]);
|
||||
}
|
||||
|
||||
|
@ -18,14 +18,15 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Translates the given MLIR `module` into a FlatBuffer and stores the
|
||||
// serialized flatbuffer into the string. This uses OpLocNameMapper to convert
|
||||
// location of the op to name in flatbuffer.
|
||||
// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to
|
||||
// convert location of the op to name in flatbuffer. Returns true if translation
|
||||
// fails, otherwise returns false.
|
||||
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
|
||||
std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops,
|
||||
|
@ -25,17 +25,17 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -301,14 +301,14 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
|
||||
return {};
|
||||
}
|
||||
|
||||
void buildComparisonBinOp(Builder *builder, OperationState &result, Value *lhs,
|
||||
Value *rhs) {
|
||||
void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
|
||||
Value rhs) {
|
||||
auto result_type =
|
||||
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
|
||||
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
|
||||
if (!result_type)
|
||||
emitError(result.location)
|
||||
<< "non-broadcastable operands: " << lhs->getType() << " and "
|
||||
<< rhs->getType();
|
||||
<< "non-broadcastable operands: " << lhs.getType() << " and "
|
||||
<< rhs.getType();
|
||||
result.addOperands({lhs, rhs});
|
||||
// Comparison binary ops always return i1 tensor.
|
||||
if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
|
||||
@ -321,15 +321,15 @@ void buildComparisonBinOp(Builder *builder, OperationState &result, Value *lhs,
|
||||
}
|
||||
|
||||
void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
|
||||
Value *lhs, Value *rhs,
|
||||
Value lhs, Value rhs,
|
||||
StringAttr fused_activation_function) {
|
||||
auto result_type =
|
||||
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
|
||||
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
|
||||
|
||||
if (!result_type)
|
||||
emitError(result.location)
|
||||
<< "non-broadcastable operands: " << lhs->getType() << " and "
|
||||
<< rhs->getType();
|
||||
<< "non-broadcastable operands: " << lhs.getType() << " and "
|
||||
<< rhs.getType();
|
||||
|
||||
result.addOperands({lhs, rhs});
|
||||
result.addAttribute("fused_activation_function", fused_activation_function);
|
||||
@ -358,7 +358,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
|
||||
namespace {
|
||||
|
||||
int64_t GetConcatenationOpAxis(ConcatenationOp op) {
|
||||
auto output_type = op.output()->getType().cast<RankedTensorType>();
|
||||
auto output_type = op.output().getType().cast<RankedTensorType>();
|
||||
int64_t axis = op.axis().getSExtValue();
|
||||
if (axis < 0) axis += output_type.getRank();
|
||||
return axis;
|
||||
@ -452,7 +452,7 @@ LogicalResult VerifyConcatenationOpTypes(Operation *op,
|
||||
}
|
||||
|
||||
LogicalResult Verify(ConcatenationOp op) {
|
||||
auto output_type = op.output()->getType().dyn_cast<RankedTensorType>();
|
||||
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
// If the output type is unranked, there is nothing else to be verified.
|
||||
if (!output_type) return success();
|
||||
@ -462,8 +462,8 @@ LogicalResult Verify(ConcatenationOp op) {
|
||||
return op.emitOpError("concatenation dimension must be in [-rank, rank)");
|
||||
|
||||
SmallVector<TensorType, 4> operand_types;
|
||||
for (Value *operand : op.values())
|
||||
operand_types.push_back(operand->getType().cast<TensorType>());
|
||||
for (Value operand : op.values())
|
||||
operand_types.push_back(operand.getType().cast<TensorType>());
|
||||
|
||||
return VerifyConcatenationOpTypes(op.getOperation(), output_type,
|
||||
operand_types, axis);
|
||||
@ -520,7 +520,7 @@ DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
|
||||
|
||||
OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (fused_activation_function() == "NONE") {
|
||||
if (auto output_type = output()->getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
|
||||
const int64_t axis = GetConcatenationOpAxis(*this);
|
||||
if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
|
||||
return ConstFoldConcatenateOpDense(operands, output_type, axis);
|
||||
@ -528,9 +528,9 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
||||
}
|
||||
|
||||
// Remove all empty values.
|
||||
SmallVector<Value *, 4> non_empty_values;
|
||||
for (Value *value : this->values()) {
|
||||
const auto shaped_type = value->getType().cast<ShapedType>();
|
||||
SmallVector<Value, 4> non_empty_values;
|
||||
for (Value value : this->values()) {
|
||||
const auto shaped_type = value.getType().cast<ShapedType>();
|
||||
if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
|
||||
continue;
|
||||
}
|
||||
@ -559,8 +559,8 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult Verify(FullyConnectedOp op) {
|
||||
ShapedType input_type = op.input()->getType().cast<ShapedType>();
|
||||
ShapedType filter_type = op.filter()->getType().cast<ShapedType>();
|
||||
ShapedType input_type = op.input().getType().cast<ShapedType>();
|
||||
ShapedType filter_type = op.filter().getType().cast<ShapedType>();
|
||||
if (filter_type.hasRank() && filter_type.getRank() != 2) {
|
||||
return op.emitOpError("expect 2d filter, got ") << filter_type;
|
||||
}
|
||||
@ -582,7 +582,7 @@ LogicalResult Verify(FullyConnectedOp op) {
|
||||
// format.
|
||||
if (op.weights_format() == "DEFAULT") {
|
||||
ShapedType output_type =
|
||||
(*op.output().begin())->getType().cast<ShapedType>();
|
||||
(*op.output().begin()).getType().cast<ShapedType>();
|
||||
if (!output_type.hasStaticShape()) {
|
||||
return mlir::success();
|
||||
}
|
||||
@ -609,9 +609,9 @@ LogicalResult Verify(FullyConnectedOp op) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void BuildGatherOp(Builder *builder, OperationState &result,
|
||||
Value *params, Value *indices, IntegerAttr axis) {
|
||||
auto params_type = params->getType().cast<TensorType>();
|
||||
auto indices_type = indices->getType().cast<TensorType>();
|
||||
Value params, Value indices, IntegerAttr axis) {
|
||||
auto params_type = params.getType().cast<TensorType>();
|
||||
auto indices_type = indices.getType().cast<TensorType>();
|
||||
|
||||
// If params/indices is unranked, then output is unranked.
|
||||
if (!params_type.hasRank() || !indices_type.hasRank())
|
||||
@ -704,8 +704,8 @@ static LogicalResult Verify(PackOp op) {
|
||||
if (op.getOperation()->getNumOperands() != op.values_count())
|
||||
return op.emitOpError("input count should match 'values_count' attribute");
|
||||
|
||||
Value *operand0 = op.getOperand(0);
|
||||
auto input_type = operand0->getType().cast<ShapedType>();
|
||||
Value operand0 = op.getOperand(0);
|
||||
auto input_type = operand0.getType().cast<ShapedType>();
|
||||
|
||||
// Check axis bounds.
|
||||
if (input_type.hasRank()) {
|
||||
@ -717,8 +717,8 @@ static LogicalResult Verify(PackOp op) {
|
||||
|
||||
// Make sure all inputs have the same shape and element type.
|
||||
// TODO(rahulsp): Simplify once b/135032064 is fixed.
|
||||
for (Value *operand : op.getOperands()) {
|
||||
auto other_type = operand->getType().cast<ShapedType>();
|
||||
for (Value operand : op.getOperands()) {
|
||||
auto other_type = operand.getType().cast<ShapedType>();
|
||||
if (input_type != other_type)
|
||||
return op.emitOpError("operands should be of the same type. got ")
|
||||
<< input_type << ", " << other_type;
|
||||
@ -732,9 +732,9 @@ static LogicalResult Verify(PackOp op) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(PReluOp op) {
|
||||
auto input_type = op.input()->getType().cast<ShapedType>();
|
||||
auto alpha_type = op.alpha()->getType().cast<ShapedType>();
|
||||
auto output_type = op.output()->getType().cast<ShapedType>();
|
||||
auto input_type = op.input().getType().cast<ShapedType>();
|
||||
auto alpha_type = op.alpha().getType().cast<ShapedType>();
|
||||
auto output_type = op.output().getType().cast<ShapedType>();
|
||||
|
||||
if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
|
||||
if (input_type.getRank() != alpha_type.getRank() + 1) {
|
||||
@ -783,13 +783,13 @@ struct RemoveAdjacentReshape : public RewritePattern {
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
auto thisOp = cast<ReshapeOp>(op);
|
||||
auto prevOp = thisOp.getOperand(0)->getDefiningOp();
|
||||
auto prevOp = thisOp.getOperand(0).getDefiningOp();
|
||||
return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
auto thisOp = cast<ReshapeOp>(op);
|
||||
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0)->getDefiningOp());
|
||||
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
|
||||
|
||||
// Replace
|
||||
// %1 = "tfl.reshape"(%0, %shape0)
|
||||
@ -807,7 +807,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
|
||||
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Remove identity reshape with both static result and input shape.
|
||||
auto result_type = getType().cast<ShapedType>();
|
||||
auto input_type = getOperand(0)->getType().cast<ShapedType>();
|
||||
auto input_type = getOperand(0).getType().cast<ShapedType>();
|
||||
if (result_type.hasStaticShape() && result_type == input_type) {
|
||||
return getOperand(0);
|
||||
}
|
||||
@ -865,7 +865,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TFL::PackOp pack_op = cast<TFL::PackOp>(op);
|
||||
Operation *first_input = pack_op.getOperand(0)->getDefiningOp();
|
||||
Operation *first_input = pack_op.getOperand(0).getDefiningOp();
|
||||
if (!first_input) return matchFailure();
|
||||
auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
|
||||
if (!input_unpack_op) return matchFailure();
|
||||
@ -880,8 +880,8 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
|
||||
return matchFailure();
|
||||
for (auto input_output :
|
||||
llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
|
||||
Value *pack_input = std::get<0>(input_output);
|
||||
Value *unpack_output = std::get<1>(input_output);
|
||||
Value pack_input = std::get<0>(input_output);
|
||||
Value unpack_output = std::get<1>(input_output);
|
||||
// Make sure the ordering is the same for the pack op & unpack op.
|
||||
if (pack_input != unpack_output) return matchFailure();
|
||||
}
|
||||
@ -905,9 +905,9 @@ void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(SliceOp op) {
|
||||
auto input_type = op.input()->getType().cast<ShapedType>();
|
||||
auto begin_type = op.begin()->getType().cast<ShapedType>();
|
||||
auto size_type = op.size()->getType().cast<ShapedType>();
|
||||
auto input_type = op.input().getType().cast<ShapedType>();
|
||||
auto begin_type = op.begin().getType().cast<ShapedType>();
|
||||
auto size_type = op.size().getType().cast<ShapedType>();
|
||||
if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
|
||||
size_type.hasStaticShape()) {
|
||||
if (input_type.getRank() != begin_type.getNumElements()) {
|
||||
@ -984,8 +984,8 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
|
||||
// TopKOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void BuildTopKOp(Builder *builder, OperationState &result, Value *input,
|
||||
Value *k) {
|
||||
static void BuildTopKOp(Builder *builder, OperationState &result, Value input,
|
||||
Value k) {
|
||||
// Output size is only known if k is constant value. A negative dimension is
|
||||
// considered dynamic so use -1 here if k is not a constant value.
|
||||
int const_k = -1;
|
||||
@ -995,7 +995,7 @@ static void BuildTopKOp(Builder *builder, OperationState &result, Value *input,
|
||||
// TODO(jpienaar): This should use a helper function.
|
||||
const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
|
||||
|
||||
auto val_type = input->getType().cast<TensorType>();
|
||||
auto val_type = input.getType().cast<TensorType>();
|
||||
// If the value is unranked, then so are the results.
|
||||
if (!val_type.hasRank())
|
||||
return TFL::TopKV2Op::build(
|
||||
@ -1035,7 +1035,7 @@ struct DropFakeQuant : public RewritePattern {
|
||||
// If all the users of this op have valid "minmax" attributes, it is matched
|
||||
// and can be removed.
|
||||
auto fakeQuantOp = cast<FakeQuantOp>(op);
|
||||
for (auto *operand : fakeQuantOp.getResult()->getUsers())
|
||||
for (auto *operand : fakeQuantOp.getResult().getUsers())
|
||||
if (!HasValidMinMaxAttribute(operand)) return matchFailure();
|
||||
|
||||
return matchSuccess();
|
||||
@ -1075,7 +1075,7 @@ static LogicalResult Verify(UnpackOp op) {
|
||||
|
||||
// Extracts and returns the signed integer constant in a 0-rank integer tensor
|
||||
// or 1-element 1-rank integer tensor if 'value' is a constant.
|
||||
static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value *value) {
|
||||
static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value value) {
|
||||
ElementsAttr attr;
|
||||
if (!matchPattern(value, m_Constant(&attr))) return {};
|
||||
if (attr.getNumElements() != 1) return {};
|
||||
@ -1101,8 +1101,8 @@ static LogicalResult VerifySplitOpOutputTypes(
|
||||
ExpectedOutputTypeGetter get_expected_output_type) {
|
||||
for (int64_t i = 0; i < num_splits; ++i) {
|
||||
auto expected_output_type = get_expected_output_type(i);
|
||||
Value *output = op->getResult(i);
|
||||
auto output_type = output->getType().dyn_cast<RankedTensorType>();
|
||||
Value output = op->getResult(i);
|
||||
auto output_type = output.getType().dyn_cast<RankedTensorType>();
|
||||
if (!output_type || output_type != expected_output_type)
|
||||
return op->emitOpError()
|
||||
<< "output #" << i << " should be " << expected_output_type;
|
||||
@ -1121,7 +1121,7 @@ static LogicalResult Verify(SplitOp op) {
|
||||
if (!split_dim_opt) return success();
|
||||
|
||||
// If 'input' is not a ranked tensor, there are no other checks.
|
||||
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
|
||||
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!input_type) return success();
|
||||
|
||||
int64_t split_dim = split_dim_opt.getValue();
|
||||
@ -1157,7 +1157,7 @@ static LogicalResult Verify(SplitVOp op) {
|
||||
if (!split_dim_opt) return success();
|
||||
|
||||
// If 'input' is not a ranked tensor, there are no other checks.
|
||||
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
|
||||
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!input_type) return success();
|
||||
|
||||
int64_t split_dim = split_dim_opt.getValue();
|
||||
@ -1177,8 +1177,7 @@ static LogicalResult Verify(SplitVOp op) {
|
||||
return success();
|
||||
|
||||
if (size_splits_attr.getNumElements() != num_splits) {
|
||||
auto size_splits_type =
|
||||
op.size_splits()->getType().cast<RankedTensorType>();
|
||||
auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
|
||||
RankedTensorType expected_size_splits_type =
|
||||
RankedTensorType::get({num_splits}, size_splits_type.getElementType());
|
||||
return op.emitOpError("'size_splits' should be ")
|
||||
@ -1414,7 +1413,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||
}
|
||||
|
||||
// Also fold if `input` has a known rank.
|
||||
auto input_type = input()->getType().cast<ShapedType>();
|
||||
auto input_type = input().getType().cast<ShapedType>();
|
||||
// Do not fold if rank is zero because the TFLite converter doesn't
|
||||
// distinguish between unranked input and scalar input due to b/138865275.
|
||||
// TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
|
||||
@ -1438,6 +1437,56 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
|
||||
return value();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SelectV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Populates `result` for a SelectV2 op built from `cond`, `x` and `y`.
//
// `x` and `y` are first broadcast against each other; the resulting operand
// type is then broadcast against `cond`. The result type is a ranked tensor
// only when both `cond` and the broadcasted operand type have static shapes
// that broadcast together; in every other case it is an unranked tensor. The
// element type is always taken from `x`. Broadcast failures are reported at
// `result.location`, and the operands and result type are still added.
static void BuildSelectV2Op(Builder *builder, OperationState &result,
                            Value cond, Value x, Value y) {
  auto operand_type =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());

  if (!operand_type)
    emitError(result.location) << "non-broadcastable operands: " << x.getType()
                               << " and " << y.getType();

  bool has_static_cond_shape = false;
  bool has_static_operand_shape = false;
  ArrayRef<int64_t> cond_shape;
  ArrayRef<int64_t> operand_shape;

  if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
    if (shaped_type.hasStaticShape()) {
      has_static_cond_shape = true;
      cond_shape = shaped_type.getShape();
    }
  }
  // `operand_type` is null when `x` and `y` are not broadcastable (reported
  // above). A null type must not be cast, so only inspect it when it is set.
  if (operand_type) {
    if (auto shaped_type = operand_type.dyn_cast<ShapedType>()) {
      if (shaped_type.hasStaticShape()) {
        has_static_operand_shape = true;
        operand_shape = shaped_type.getShape();
      }
    }
  }

  // The broadcasted shape is only usable when both input shapes are static
  // and the broadcast itself succeeded.
  SmallVector<int64_t, 4> broadcastedShape;
  bool has_broadcasted_shape = false;
  if (has_static_cond_shape && has_static_operand_shape) {
    has_broadcasted_shape = OpTrait::util::getBroadcastedShape(
        cond_shape, operand_shape, broadcastedShape);
    if (!has_broadcasted_shape) {
      emitError(result.location)
          << "non-broadcastable operands: " << operand_type << " and "
          << cond.getType();
    }
  }

  result.addOperands({cond, x, y});

  // Use cast<> rather than dyn_cast<>: the result is dereferenced
  // unconditionally, so a non-shaped `x` should trip the cast assertion
  // instead of silently yielding a null ShapedType.
  auto elementType = x.getType().cast<ShapedType>().getElementType();
  if (has_broadcasted_shape) {
    result.types.push_back(
        RankedTensorType::get(broadcastedShape, elementType));
  } else {
    // Unknown or non-broadcastable shapes: fall back to an unranked result
    // rather than building a ranked type from an unusable shape.
    result.types.push_back(UnrankedTensorType::get(elementType));
  }
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1521,9 +1570,8 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(TransposeConvOp op) {
|
||||
ShapedType output_type = op.output()->getType().cast<ShapedType>();
|
||||
ShapedType output_shape_type =
|
||||
op.output_shape()->getType().cast<ShapedType>();
|
||||
ShapedType output_type = op.output().getType().cast<ShapedType>();
|
||||
ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
|
||||
if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
|
||||
if (output_type.getRank() != output_shape_type.getDimSize(0)) {
|
||||
return op.emitOpError(llvm::formatv(
|
||||
@ -1629,9 +1677,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||
}
|
||||
|
||||
static LogicalResult Verify(TransposeOp op) {
|
||||
auto input_type = op.x()->getType().cast<ShapedType>();
|
||||
auto perm_type = op.perm()->getType().cast<ShapedType>();
|
||||
auto output_type = op.y()->getType().cast<ShapedType>();
|
||||
auto input_type = op.x().getType().cast<ShapedType>();
|
||||
auto perm_type = op.perm().getType().cast<ShapedType>();
|
||||
auto output_type = op.y().getType().cast<ShapedType>();
|
||||
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
|
||||
if (perm_type.getNumElements() != input_type.getRank()) {
|
||||
return op.emitOpError(
|
||||
|
@ -18,15 +18,15 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
|
||||
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/Traits.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Dialect.h" // TF:llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
@ -135,7 +135,7 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class TFL_OperandIsUnrankedPred<int n> :
|
||||
CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;
|
||||
CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
|
||||
|
||||
// TODO: Some of these could be generalized and/or moved to more general
|
||||
// location.
|
||||
@ -144,38 +144,38 @@ class TFL_OperandHasRank<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is " # m # "-D",
|
||||
Or<[TFL_OperandIsUnrankedPred<n>,
|
||||
CPred<"$_op.getOperand(" # n #
|
||||
")->getType().cast<ShapedType>().getRank() == " # m>]>>;
|
||||
").getType().cast<ShapedType>().getRank() == " # m>]>>;
|
||||
|
||||
// Returns true if the n-th operand is ranked and has rank dim.
|
||||
class TFL_OperandHasKnownRank<int n, int dim> : And<[
|
||||
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
|
||||
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() == "
|
||||
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
|
||||
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() == "
|
||||
# dim>]>;
|
||||
|
||||
// True if operand n is ranked and has a rank > dim.
|
||||
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
|
||||
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
|
||||
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
|
||||
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
|
||||
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
|
||||
# dim>]>;
|
||||
|
||||
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
|
||||
TFL_OperandIsRankedAndHasDimPred<n, dim>,
|
||||
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
|
||||
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
|
||||
".getShape()[" # dim # " ] == " # size>]>;
|
||||
|
||||
// Returns true if the n-th operand has unknown rank or at least rank m.
|
||||
class TFL_OperandHasAtleastRank<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is " # m # "-D",
|
||||
Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
|
||||
Or<[CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">,
|
||||
CPred<"$_op.getOperand(" # n #
|
||||
")->getType().cast<ShapedType>().getRank() >= " # m>]>>;
|
||||
").getType().cast<ShapedType>().getRank() >= " # m>]>>;
|
||||
|
||||
class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
|
||||
PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
|
||||
CPred<"$_op.getOperand(" # x #
|
||||
")->getType().cast<ShapedType>().getRank() == "
|
||||
").getType().cast<ShapedType>().getRank() == "
|
||||
"$_op.getOperand(" # y #
|
||||
")->getType().cast<ShapedType>().getShape()[0]">>;
|
||||
").getType().cast<ShapedType>().getShape()[0]">>;
|
||||
|
||||
class TFL_Operand0DOr1ElementTensor<int x> :
|
||||
PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element",
|
||||
@ -195,7 +195,7 @@ class TFL_OperandHasRankLessThan<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is maximum " # m # "-D",
|
||||
Or<[TFL_OperandIsUnrankedPred<n>,
|
||||
CPred<"$_op.getOperand(" # n #
|
||||
")->getType().cast<ShapedType>().getRank() <= " # m>]>>;
|
||||
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
|
||||
|
||||
// This is a quantization-aware version of TCresVTEtIsSameAsOp
|
||||
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
|
||||
@ -224,10 +224,10 @@ def BinaryOpSameElementTypeConstraint :
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TFL_BroadcastableBinaryBuilder : OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value *lhs, Value *rhs",
|
||||
"Builder *builder, OperationState &result, Value lhs, Value rhs",
|
||||
[{
|
||||
auto resultType =
|
||||
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
|
||||
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
|
||||
if (!resultType)
|
||||
mlir::emitError(result.location, "non-broadcastable operands");
|
||||
result.addOperands({lhs, rhs});
|
||||
@ -235,7 +235,7 @@ def TFL_BroadcastableBinaryBuilder : OpBuilder<
|
||||
}]>;
|
||||
|
||||
def TFL_FusedBroadcastableBinaryBuilder : OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value *lhs, Value *rhs, "
|
||||
"Builder *builder, OperationState &result, Value lhs, Value rhs, "
|
||||
"StringAttr fusedActivationFunction",
|
||||
[{
|
||||
buildFusedBroadcastableBinOp(
|
||||
@ -243,7 +243,7 @@ def TFL_FusedBroadcastableBinaryBuilder : OpBuilder<
|
||||
}]>;
|
||||
|
||||
def TFL_ComparisonBinaryBuilder : OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value *lhs, Value *rhs",
|
||||
"Builder *builder, OperationState &result, Value lhs, Value rhs",
|
||||
[{
|
||||
buildComparisonBinOp(builder, result, lhs, rhs);
|
||||
}]>;
|
||||
@ -427,6 +427,33 @@ def TFL_TransposeConvOp:
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
// Transpose convolution with an optional bias operand. Per its description
// this is a custom op that the standard runtime does not support. It takes an
// input tensor, a filter tensor and a tensor-or-none bias, plus padding and
// the two stride attributes, and produces a single output tensor.
def TFL_Convolution2DTransposeBiasOp :
  Op<TFL_Dialect, "convolution_2d_transpose_bias", [NoSideEffect]> {
  let summary = "Transpose convolution with bias operator";

  let description = [{
Performs transpose convolution operation on inputs,
with the option of adding a bias.
Note this is a custom op that is not supported in the standard runtime.

    Inputs:
      `inputs[0]`: required: the input activation tensor
      `inputs[1]`: required: the filter weight tensor
      `inputs[2]`: optional: the bias tensor
  }];

  let arguments = (
    ins AnyTensor:$input,
    AnyTensor:$filter,
    TFL_TensorOfOrNone<[AnyType]>:$bias,
    TFL_PaddingAttr:$padding,
    I32Attr:$stride_h,
    I32Attr:$stride_w
  );

  let results = (outs AnyTensor:$output);
}
|
||||
|
||||
def TFL_AveragePool2DOp:
|
||||
TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
let summary = "Average_pool_2d operator";
|
||||
@ -471,7 +498,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
|
||||
let hasOptions = 1;
|
||||
|
||||
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
|
||||
return getResult()->getType().cast<TensorType>().getElementType().
|
||||
return getResult().getType().cast<TensorType>().getElementType().
|
||||
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
||||
tflite::TensorType_INT32;
|
||||
}]>;
|
||||
@ -500,7 +527,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
|
||||
let hasOptions = 1;
|
||||
|
||||
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
|
||||
return getResult()->getType().cast<TensorType>().getElementType().
|
||||
return getResult().getType().cast<TensorType>().getElementType().
|
||||
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
||||
tflite::TensorType_INT32;
|
||||
}]>;
|
||||
@ -669,7 +696,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
|
||||
let builders =
|
||||
[
|
||||
OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"Value *params, Value *indices, IntegerAttr axis",
|
||||
"Value params, Value indices, IntegerAttr axis",
|
||||
[{ BuildGatherOp(builder, result, params, indices, axis); }]>
|
||||
];
|
||||
|
||||
@ -932,7 +959,7 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
|
||||
let builders =
|
||||
[
|
||||
OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value *lhs, Value *rhs",
|
||||
"Builder *builder, OperationState &result, Value lhs, Value rhs",
|
||||
[{
|
||||
buildComparisonBinOp(builder, result, lhs, rhs);
|
||||
}]>
|
||||
@ -1427,6 +1454,63 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
|
||||
let customOption = "Pool2DOptions";
|
||||
}
|
||||
|
||||
def TFL_MaxPoolingWithArgMax2DOp :
|
||||
Op<TFL_Dialect, "max_pooling_with_argmax_2d", [NoSideEffect]> {
|
||||
let summary = "Max Pool 2D with argmax op";
|
||||
|
||||
let description = [{
|
||||
Performs max pooling on the input and outputs both max values and indices.
|
||||
Each index is a flattened index in a sub-array of "filter_w" x "filter_h" size.
|
||||
Note this is a custom op that is not supported in the standard runtime.
|
||||
|
||||
Inputs:
|
||||
`inputs[0]`: required: the input activation tensor
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$input,
|
||||
TFL_PaddingAttr:$padding,
|
||||
I32Attr:$stride_w,
|
||||
I32Attr:$stride_h,
|
||||
I32Attr:$filter_w,
|
||||
I32Attr:$filter_h
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
AnyTensor:$value,
|
||||
AnyTensor:$indices
|
||||
);
|
||||
}
|
||||
|
||||
def TFL_MaxUnpooling2DOp :
|
||||
Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect]> {
|
||||
let summary = "Max Unpool 2D";
|
||||
|
||||
let description = [{
|
||||
Performs max unpool operation.
|
||||
To some extent this is the reverse operation of max pooling:
|
||||
the elements in the input activation tensor are stored into the positions
|
||||
specified by the input indices.
|
||||
Note this is a custom op that is not supported in the standard runtime.
|
||||
|
||||
Inputs:
|
||||
`inputs[0]`: required: the input activation tensor
|
||||
`inputs[1]`: required: the input indices
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$input,
|
||||
AnyTensor:$indices,
|
||||
TFL_PaddingAttr:$padding,
|
||||
I32Attr:$stride_w,
|
||||
I32Attr:$stride_h,
|
||||
I32Attr:$filter_w,
|
||||
I32Attr:$filter_h
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor:$outputs);
|
||||
}
|
||||
|
||||
def TFL_MaximumOp : TFL_Op<"maximum", [
|
||||
Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
|
||||
TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
|
||||
@ -1996,7 +2080,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
|
||||
let results = (outs AnyTensor:$output);
|
||||
|
||||
DerivedTypeAttr out_type = DerivedTypeAttr<[{
|
||||
return getResult()->getType().cast<TensorType>().getElementType();
|
||||
return getResult().getType().cast<TensorType>().getElementType();
|
||||
}]>;
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -2081,9 +2165,9 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
|
||||
|
||||
// TODO(jpienaar): autogenerate this.
|
||||
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"Value *condition, Value *x, Value *y",
|
||||
"Value condition, Value x, Value y",
|
||||
[{
|
||||
auto resultType = x->getType();
|
||||
auto resultType = x.getType();
|
||||
result.addOperands({condition, x, y});
|
||||
result.types.push_back(resultType);
|
||||
}]>];
|
||||
@ -2091,6 +2175,32 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
// Declares the "select_v2" TFL op, tagged NoSideEffect. It takes a boolean
// $condition tensor plus two value tensors $x and $y and produces one result,
// $output.
def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> {
|
||||
let summary = "SelectV2 operator";
|
||||
|
||||
let description = [{
|
||||
Select values of 'x' if the corresponding value of 'condition' is true or
|
||||
the value of 'y' if false. There are valid condition input sizes:
|
||||
|
||||
1. Either the same shape (in which case the select is elementwise), or
|
||||
2. Broadcastable shapes between 'condition', 'x' and 'y'.
|
||||
}];
|
||||
|
||||
// $x and $y accept the same list of element types.
let arguments = (ins
|
||||
TFL_BoolTensor:$condition,
|
||||
TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x,
|
||||
TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y);
|
||||
let results = (outs AnyTensor:$output);
|
||||
|
||||
// Custom builder taking (cond, x, y); it only forwards to BuildSelectV2Op,
// which is defined outside this view.
// NOTE(review): presumably that helper adds the operands and computes the
// result type -- confirm.
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"Value cond, Value x, Value y",
|
||||
[{
|
||||
BuildSelectV2Op(builder, result, cond, x, y);
|
||||
}]>];
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_SinOp: TFL_Op<"sin", [
|
||||
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
|
||||
let summary = "Sine operator";
|
||||
@ -2277,7 +2387,7 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
|
||||
I32Tensor:$indices);
|
||||
|
||||
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"Value *input, Value *k",
|
||||
"Value input, Value k",
|
||||
[{ BuildTopKOp(builder, result, input, k); }]>];
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -2333,14 +2443,14 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I8, I32, QI8, QUI8]>:$input,
|
||||
TensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$input,
|
||||
|
||||
I32Attr:$num,
|
||||
I32Attr:$axis
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TensorOf<[F32, I8, I32, QI8, QUI8]>>:$outputs
|
||||
Variadic<TensorOf<[F32, I1, I8, I32, QI8, QUI8]>>:$outputs
|
||||
);
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
@ -2707,7 +2817,7 @@ in the unique output `y`. In other words:
|
||||
);
|
||||
|
||||
DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
|
||||
return getResult(1)->getType().cast<TensorType>().getElementType().
|
||||
return getResult(1).getType().cast<TensorType>().getElementType().
|
||||
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
||||
tflite::TensorType_INT32;
|
||||
}]>;
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
@ -30,10 +30,10 @@ limitations under the License.
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Parser.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Parser.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "llvm/TableGen/Main.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
#include "mlir/TableGen/Attribute.h" // TF:local_config_mlir
|
||||
#include "mlir/TableGen/Attribute.h" // TF:llvm-project
|
||||
|
||||
using llvm::DefInit;
|
||||
using llvm::dyn_cast;
|
||||
|
@ -28,10 +28,10 @@ cc_library(
|
||||
"//tensorflow/lite/toco:toco_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:types_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:ViewOpGraph",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:ViewOpGraph",
|
||||
],
|
||||
)
|
||||
|
@ -19,11 +19,11 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/ViewOpGraph.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
@ -151,10 +151,9 @@ Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
|
||||
return errors::InvalidArgument("fail to parse extra OpDef");
|
||||
}
|
||||
// Make sure the op is not already registered. If registered continue.
|
||||
const OpRegistrationData* op_reg = nullptr;
|
||||
auto status =
|
||||
tensorflow::OpRegistry::Global()->LookUp(opdef.name(), &op_reg);
|
||||
if (status.ok()) continue;
|
||||
const OpRegistrationData* op_reg =
|
||||
tensorflow::OpRegistry::Global()->LookUp(opdef.name());
|
||||
if (op_reg) continue;
|
||||
|
||||
tensorflow::OpRegistry::Global()->Register(
|
||||
[opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
|
||||
@ -278,7 +277,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
|
||||
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||
// rename once we enable the new converter feature flag.
|
||||
|
@ -13,7 +13,7 @@ package(
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = ["@local_config_mlir//:subpackages"],
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = ["//tensorflow/compiler/mlir/..."],
|
||||
)
|
||||
|
||||
@ -26,8 +26,8 @@ filegroup(
|
||||
name = "quantization_td_files",
|
||||
srcs = [
|
||||
"quantization.td",
|
||||
"@local_config_mlir//:OpBaseTdFiles",
|
||||
"@local_config_mlir//:QuantizationOpsTdFiles",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:QuantizationOpsTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
@ -53,13 +53,13 @@ cc_library(
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -75,11 +75,11 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
# TODO(fengliuai): remove this dependence.
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
@ -97,7 +97,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -107,8 +107,8 @@ tf_native_cc_binary(
|
||||
"tools/op_quant_spec_getters_gen.cc",
|
||||
],
|
||||
deps = [
|
||||
"@llvm//:support",
|
||||
"@llvm//:tablegen",
|
||||
"@local_config_mlir//:TableGen",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:tablegen",
|
||||
"@llvm-project//mlir:TableGen",
|
||||
],
|
||||
)
|
||||
|
@ -23,18 +23,18 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/AffineExpr.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/AffineMap.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
|
||||
#include "mlir/IR/AffineMap.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||
@ -70,16 +70,16 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
|
||||
void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
|
||||
const QuantParamsEntry &info);
|
||||
|
||||
void InsertStatsOpAtResult(OpBuilder b, Value *res, ElementsAttr layer_stats,
|
||||
void InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats,
|
||||
ElementsAttr axis_stats, IntegerAttr axis);
|
||||
|
||||
// If the index is out of range, this method returns false. Otherwise it
|
||||
// returns true if the value is a float tensor.
|
||||
bool IsQuantizableResult(Operation *op, int index) {
|
||||
if (index < 0 || index >= op->getNumResults()) return false;
|
||||
Value *res = op->getResult(index);
|
||||
return res->getType().isa<ShapedType>() &&
|
||||
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
|
||||
Value res = op->getResult(index);
|
||||
return res.getType().isa<ShapedType>() &&
|
||||
res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
|
||||
}
|
||||
|
||||
// A method to retrieve the name for the given op.
|
||||
@ -117,13 +117,13 @@ bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) {
|
||||
return false;
|
||||
}
|
||||
|
||||
void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value *res,
|
||||
void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
|
||||
ElementsAttr layer_stats,
|
||||
ElementsAttr axis_stats,
|
||||
IntegerAttr axis) {
|
||||
auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
|
||||
layer_stats, axis_stats, axis);
|
||||
res->replaceAllUsesWith(stats_op);
|
||||
res.replaceAllUsesWith(stats_op);
|
||||
stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,7 @@ package(
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = ["@local_config_mlir//:subpackages"],
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//learning/brain/experimental/mlir/...",
|
||||
"//tensorflow/lite/...",
|
||||
@ -36,9 +36,9 @@ cc_library(
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
],
|
||||
)
|
||||
|
||||
@ -53,6 +53,6 @@ tf_cc_binary(
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
)
|
||||
|
@ -17,11 +17,11 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
||||
|
@ -23,17 +23,17 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
@ -146,14 +146,14 @@ class QuantizationDriver {
|
||||
|
||||
// Adds all the users of index-th result of op to the work list.
|
||||
void AddUserToList(Operation *op, int index) {
|
||||
for (auto *user : op->getResult(index)->getUsers()) {
|
||||
for (auto *user : op->getResult(index).getUsers()) {
|
||||
work_list_.push_back(user);
|
||||
}
|
||||
}
|
||||
|
||||
// Adds the defining op of index-th operand of op to the work list.
|
||||
void AddOperandToList(Operation *op, int index) {
|
||||
if (auto *inst = op->getOperand(index)->getDefiningOp()) {
|
||||
if (auto *inst = op->getOperand(index).getDefiningOp()) {
|
||||
work_list_.push_back(inst);
|
||||
}
|
||||
}
|
||||
@ -183,20 +183,20 @@ class QuantizationDriver {
|
||||
// of the op.
|
||||
void QuantizeOpResult(Operation *op, int index, QuantParams params);
|
||||
|
||||
void QuantizeArg(BlockArgument *arg, QuantParams params);
|
||||
void QuantizeArg(BlockArgument arg, QuantParams params);
|
||||
|
||||
// Inserts the Quantize and Dequantize ops to quantize the value. The ops are
|
||||
// created in place; nothing is returned.
|
||||
void QuantizeValue(Value *value, QuantParams params, Location loc);
|
||||
void QuantizeValue(Value value, QuantParams params, Location loc);
|
||||
|
||||
// Inserts the Quantize ops for requantizing the index-th result of the op.
|
||||
void RequantizeOpResult(Operation *op, int index, RequantizeState *state);
|
||||
|
||||
void RequantizeArg(BlockArgument *arg, RequantizeState *state);
|
||||
void RequantizeArg(BlockArgument arg, RequantizeState *state);
|
||||
|
||||
// Inserts the op(s) that requantize the value as described by `state`.
|
||||
// Nothing is returned.
|
||||
void RequantizeValue(Value *value, RequantizeState *state, Location loc);
|
||||
void RequantizeValue(Value value, RequantizeState *state, Location loc);
|
||||
|
||||
// A heuristic to get the quantization parameter that satisfies the same scale
|
||||
// constraints for the op. Returns an empty option if this quantization
|
||||
@ -213,7 +213,7 @@ class QuantizationDriver {
|
||||
return states_[result_states_[{op, index}]];
|
||||
}
|
||||
|
||||
QuantState &GetArgQuantState(BlockArgument *arg) {
|
||||
QuantState &GetArgQuantState(BlockArgument arg) {
|
||||
return states_[arg_states_[arg]];
|
||||
}
|
||||
|
||||
@ -227,7 +227,7 @@ class QuantizationDriver {
|
||||
return rescale_states_[result_states_[{op, index}]];
|
||||
}
|
||||
|
||||
RequantizeState &GetArgRequantizeState(BlockArgument *arg) {
|
||||
RequantizeState &GetArgRequantizeState(BlockArgument arg) {
|
||||
return rescale_states_[arg_states_[arg]];
|
||||
}
|
||||
|
||||
@ -235,32 +235,45 @@ class QuantizationDriver {
|
||||
// `as_result` is true or index-th operand if `as_result` is false. The state
|
||||
// is immutable if the type is a quantized type. Returns the index of this
|
||||
// new state in the state vector.
|
||||
int InitializeState(Operation *op, int index, Value *val, bool as_result);
|
||||
int InitializeState(Operation *op, int index, Value val, bool as_result);
|
||||
|
||||
// Sets the state of an argument. If this value is cached, uses the cached
|
||||
// result without creating a new entry in the state vector. Otherwise,
|
||||
// allocates a new entry in the state vector.
|
||||
// Associates block argument `arg` with a quantization state selected by the
// value `in`; `cache` maps each value seen so far to its index in `states_`.
//   arg   - block argument whose entry in `arg_states_` is written.
//   in    - value used as the cache key; its type supplies the initial params.
//   cache - value -> state-index map, updated in place on a cache miss.
void InitializeArgState(BlockArgument arg, Value in,
|
||||
llvm::DenseMap<Value, int> *cache) {
|
||||
// Try to insert `in` with a placeholder index of 0; `cached.second` is false
// when `in` was already present in the cache.
auto cached = cache->insert({in, 0});
|
||||
if (!cached.second) {
|
||||
// Cache hit: reuse the previously recorded state index, no new state.
arg_states_[arg] = cached.first->second;
|
||||
return;
|
||||
}
|
||||
// Cache miss: derive the quantization params from the type of `in`.
QuantParams params =
|
||||
quant::QuantizedType::getQuantizedElementType(in.getType());
|
||||
// The state is marked immutable whenever the derived params are non-empty.
bool immutable = !EmptyParams(params);
|
||||
// The new state is appended, so its index is the current size of `states_`.
int next_state_index = states_.size();
|
||||
states_.push_back({params, immutable});
|
||||
arg_states_[arg] = next_state_index;
|
||||
// Overwrite the placeholder so later lookups of `in` find this state.
cached.first->second = next_state_index;
|
||||
}
|
||||
|
||||
// Sets the state of the index-th operand of the op. If this operand is
|
||||
// cached, uses the cached result without creating new entry in the state
|
||||
// vector. Otherwise, allocate a new entry in the state vector.
|
||||
void InitializeOperandState(Operation *op, int index, Value *in,
|
||||
llvm::DenseMap<Value *, int> *cache,
|
||||
bool is_argument) {
|
||||
void InitializeOperandState(Operation *op, int index, Value in,
|
||||
llvm::DenseMap<Value, int> *cache) {
|
||||
auto cached = cache->insert({in, 0});
|
||||
if (!cached.second) {
|
||||
operand_states_.insert({{op, index}, cached.first->second});
|
||||
return;
|
||||
}
|
||||
cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
|
||||
if (is_argument) {
|
||||
auto *arg = llvm::cast<BlockArgument>(in);
|
||||
arg_states_[arg] = cached.first->second;
|
||||
args_.push_back(arg);
|
||||
}
|
||||
}
|
||||
|
||||
// Sets the state of the index-th result of the op. If this result is cached,
|
||||
// uses the cached result without creating new entry in the state vector.
|
||||
// Otherwise, allocate a new entry in the state vector.
|
||||
void InitializeResultState(Operation *op, int index, Value *res,
|
||||
llvm::DenseMap<Value *, int> *cache) {
|
||||
void InitializeResultState(Operation *op, int index, Value res,
|
||||
llvm::DenseMap<Value, int> *cache) {
|
||||
auto cached = cache->insert({res, 0});
|
||||
if (!cached.second) {
|
||||
result_states_.insert({{op, index}, cached.first->second});
|
||||
@ -279,7 +292,8 @@ class QuantizationDriver {
|
||||
// rest are weights.
|
||||
llvm::DenseSet<Operation *> weights_;
|
||||
|
||||
// The weights require narrow_range quantization. If the value of this map is
|
||||
// The weights require narrow_range quantization. This map collects all the
|
||||
// weight operands defined by the op quant spec. If the value of the entry is
|
||||
// positive, per-channel quantization is required.
|
||||
llvm::DenseMap<Operation *, int> optimized_weights_;
|
||||
|
||||
@ -300,11 +314,11 @@ class QuantizationDriver {
|
||||
// results and arguments.
|
||||
llvm::DenseMap<OpValue, int> operand_states_;
|
||||
llvm::DenseMap<OpValue, int> result_states_;
|
||||
llvm::DenseMap<BlockArgument *, int> arg_states_;
|
||||
llvm::DenseMap<BlockArgument, int> arg_states_;
|
||||
|
||||
// This vector is to preserve the arguments order, so the newly inserted
|
||||
// quantized ops for the arguments are deterministically ordered.
|
||||
llvm::SmallVector<BlockArgument *, 4> args_;
|
||||
llvm::SmallVector<BlockArgument, 4> args_;
|
||||
|
||||
OpQuantSpecGetter op_quant_spec_getter_;
|
||||
};
|
||||
@ -321,10 +335,10 @@ bool QuantizationDriver::IsQuantized(Operation *op) {
|
||||
return true;
|
||||
}
|
||||
|
||||
int QuantizationDriver::InitializeState(Operation *op, int index, Value *val,
|
||||
int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
|
||||
bool as_result) {
|
||||
QuantParams params =
|
||||
quant::QuantizedType::getQuantizedElementType(val->getType());
|
||||
quant::QuantizedType::getQuantizedElementType(val.getType());
|
||||
bool immutable = !EmptyParams(params);
|
||||
int next_state_index = states_.size();
|
||||
states_.push_back({params, immutable});
|
||||
@ -338,7 +352,7 @@ int QuantizationDriver::InitializeState(Operation *op, int index, Value *val,
|
||||
|
||||
bool QuantizationDriver::SetConstantResultParams(Operation *op) {
|
||||
ElementsAttr attr;
|
||||
Value *res = op->getResult(0);
|
||||
Value res = op->getResult(0);
|
||||
if (!matchPattern(res, m_Constant(&attr))) {
|
||||
return false;
|
||||
}
|
||||
@ -362,7 +376,7 @@ bool QuantizationDriver::SetConstantResultParams(Operation *op) {
|
||||
} else {
|
||||
// per-tensor quantization weight
|
||||
final_type = GetUniformQuantizedTypeForWeight(
|
||||
attr, /*symmetric=*/is_weight_with_per_channel_support,
|
||||
attr, /*symmetric=*/is_weight && is_signed_,
|
||||
/*num_bits=*/8, is_signed_,
|
||||
/*narrow_range_=*/is_weight);
|
||||
}
|
||||
@ -428,18 +442,18 @@ bool QuantizationDriver::SetOperandParams(Operation *op, int index,
|
||||
void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
|
||||
QuantParams params) {
|
||||
builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
|
||||
Value *original_result = op->getResult(index);
|
||||
Value original_result = op->getResult(index);
|
||||
QuantizeValue(original_result, params, op->getLoc());
|
||||
}
|
||||
|
||||
void QuantizationDriver::QuantizeArg(BlockArgument *arg, QuantParams params) {
|
||||
builder_.setInsertionPointToStart(arg->getOwner());
|
||||
void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
|
||||
builder_.setInsertionPointToStart(arg.getOwner());
|
||||
QuantizeValue(arg, params, builder_.getUnknownLoc());
|
||||
}
|
||||
|
||||
void QuantizationDriver::QuantizeValue(Value *value, QuantParams params,
|
||||
void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
|
||||
Location loc) {
|
||||
Type expressed_type = value->getType();
|
||||
Type expressed_type = value.getType();
|
||||
Type new_type = params.castFromExpressedType(expressed_type);
|
||||
// This value isn't an expressed type (float), skip.
|
||||
if (!new_type) return;
|
||||
@ -451,7 +465,7 @@ void QuantizationDriver::QuantizeValue(Value *value, QuantParams params,
|
||||
quantize.output());
|
||||
// `value` has a use in `quantize`, so this will also replace that use
|
||||
// by the result of `dequantize`. Remember to reset that use afterwards
|
||||
value->replaceAllUsesWith(dequantize);
|
||||
value.replaceAllUsesWith(dequantize);
|
||||
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
|
||||
}
|
||||
|
||||
@ -459,9 +473,9 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
|
||||
RequantizeState *state) {
|
||||
if (state->pos == RequantizeState::NO_REQUANTIZE) return;
|
||||
builder_.setInsertionPointAfter(op);
|
||||
Value *value = op->getResult(index);
|
||||
Value value = op->getResult(index);
|
||||
if (state->pos == RequantizeState::ON_OUTPUT) {
|
||||
Operation *user = value->getUses().begin().getUser();
|
||||
Operation *user = value.getUses().begin().getUser();
|
||||
if (llvm::isa<TFL::QuantizeOp>(user)) {
|
||||
// The requantize op is inserted between `quantize` and `dequantize` ops.
|
||||
value = user->getResult(0);
|
||||
@ -471,31 +485,31 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
|
||||
RequantizeValue(value, state, op->getLoc());
|
||||
}
|
||||
|
||||
void QuantizationDriver::RequantizeArg(BlockArgument *arg,
|
||||
void QuantizationDriver::RequantizeArg(BlockArgument arg,
|
||||
RequantizeState *state) {
|
||||
Value *value = arg;
|
||||
builder_.setInsertionPointToStart(arg->getOwner());
|
||||
if (value->hasOneUse()) {
|
||||
auto user = value->use_begin().getUser();
|
||||
Value value = arg;
|
||||
builder_.setInsertionPointToStart(arg.getOwner());
|
||||
if (value.hasOneUse()) {
|
||||
auto user = value.use_begin().getUser();
|
||||
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
||||
value = q.output();
|
||||
builder_.setInsertionPoint(arg->getOwner(), ++Block::iterator(user));
|
||||
builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
|
||||
}
|
||||
}
|
||||
RequantizeValue(value, state, builder_.getUnknownLoc());
|
||||
}
|
||||
|
||||
void QuantizationDriver::RequantizeValue(Value *value, RequantizeState *state,
|
||||
void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
|
||||
Location loc) {
|
||||
Type new_type;
|
||||
if (state->pos == RequantizeState::ON_INPUT) {
|
||||
Type expressed_type = value->getType();
|
||||
Type expressed_type = value.getType();
|
||||
// The value needs to be requantized. A Quantize op will be created to use
|
||||
// it as the operand and replace its uses.
|
||||
new_type = state->params.castFromExpressedType(expressed_type);
|
||||
} else {
|
||||
Type expressed_type =
|
||||
quant::QuantizedType::castToExpressedType(value->getType());
|
||||
quant::QuantizedType::castToExpressedType(value.getType());
|
||||
if (!expressed_type) return;
|
||||
|
||||
// The value needs to be requantized. A Quantize op will be created to use
|
||||
@ -508,7 +522,7 @@ void QuantizationDriver::RequantizeValue(Value *value, RequantizeState *state,
|
||||
TypeAttr type_attr = TypeAttr::get(new_type);
|
||||
auto requantize_op =
|
||||
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
|
||||
value->replaceAllUsesWith(requantize_op);
|
||||
value.replaceAllUsesWith(requantize_op);
|
||||
requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
|
||||
}
|
||||
|
||||
@ -586,10 +600,10 @@ void QuantizationDriver::PreprocessConstantOps() {
|
||||
auto type = cst.getType().dyn_cast<ShapedType>();
|
||||
if (!type || !type.getElementType().isa<FloatType>()) return;
|
||||
|
||||
Value *value = cst.getResult();
|
||||
Value value = cst.getResult();
|
||||
SmallVector<std::pair<Operation *, int>, 4> bias_users;
|
||||
bool used_as_weight = false;
|
||||
for (auto &use : value->getUses()) {
|
||||
for (auto &use : value.getUses()) {
|
||||
auto spec = GetQuantSpec(use.getOwner());
|
||||
auto biases = spec->biases_params;
|
||||
Operation *user = use.getOwner();
|
||||
@ -629,7 +643,20 @@ void QuantizationDriver::PreprocessConstantOps() {
|
||||
}
|
||||
|
||||
void QuantizationDriver::SetupAllStates() {
|
||||
llvm::DenseMap<Value *, int> value_to_state;
|
||||
llvm::DenseMap<Value, int> value_to_state;
|
||||
|
||||
for (auto arg : fn_.getArguments()) {
|
||||
args_.push_back(arg);
|
||||
Value value = arg;
|
||||
// If the argument is quantized, it should only have one user.
|
||||
if (arg.hasOneUse()) {
|
||||
auto user = value.use_begin().getUser();
|
||||
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
||||
value = q.output();
|
||||
}
|
||||
}
|
||||
InitializeArgState(arg, value, &value_to_state);
|
||||
}
|
||||
|
||||
fn_.walk([&](Operation *op) {
|
||||
if (op->isKnownTerminator() ||
|
||||
@ -638,26 +665,24 @@ void QuantizationDriver::SetupAllStates() {
|
||||
work_list_.push_back(op);
|
||||
|
||||
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
|
||||
auto *operand = op->getOperand(i);
|
||||
bool is_argument = true;
|
||||
if (auto *inst = operand->getDefiningOp()) {
|
||||
auto operand = op->getOperand(i);
|
||||
if (auto *inst = operand.getDefiningOp()) {
|
||||
// If the operand comes from a tfl.dequantize op, we use the quantized
|
||||
// input of this tfl.dequantize op to set the state.
|
||||
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
|
||||
operand = dq.input();
|
||||
}
|
||||
is_argument = false;
|
||||
}
|
||||
InitializeOperandState(op, i, operand, &value_to_state, is_argument);
|
||||
InitializeOperandState(op, i, operand, &value_to_state);
|
||||
}
|
||||
|
||||
for (int res = 0, e = op->getNumResults(); res != e; ++res) {
|
||||
auto *result = op->getResult(res);
|
||||
Value result = op->getResult(res);
|
||||
// If the result has been quantized, it should only be used by a
|
||||
// tfl.quantize op. For this case, we use the quantized result to
|
||||
// create the state and mark it immutable.
|
||||
if (result->hasOneUse()) {
|
||||
auto user = result->use_begin().getUser();
|
||||
if (result.hasOneUse()) {
|
||||
auto user = result.use_begin().getUser();
|
||||
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
||||
result = q.output();
|
||||
}
|
||||
@ -746,7 +771,7 @@ bool QuantizationDriver::PropagateParams() {
|
||||
}
|
||||
|
||||
void QuantizationDriver::Finalize() {
|
||||
for (auto *arg : args_) {
|
||||
for (auto arg : args_) {
|
||||
auto &state = GetArgQuantState(arg);
|
||||
auto &requantize = GetArgRequantizeState(arg);
|
||||
if (state.IsEmpty() ||
|
||||
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace quant {
|
||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
|
||||
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
@ -70,7 +70,7 @@ class FixedResultUniformScale {
|
||||
QuantizedType GetResultQuantizedType(int index) {
|
||||
auto op = this->getOperation();
|
||||
auto result_type =
|
||||
op->getResult(index)->getType().template cast<TensorType>();
|
||||
op->getResult(index).getType().template cast<TensorType>();
|
||||
Builder builder(op->getContext());
|
||||
IntegerType storage_type = builder.getIntegerType(BitWidth);
|
||||
const double scale = static_cast<double>(ScaleMantissa) *
|
||||
|
@ -21,15 +21,15 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -367,7 +367,7 @@ ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
|
||||
static bool PreferResultScale(Operation* op) {
|
||||
int float_operands = 0;
|
||||
for (auto operand : op->getOperands()) {
|
||||
if (auto operand_type = operand->getType().dyn_cast<ShapedType>()) {
|
||||
if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
|
||||
if (operand_type.getElementType().isa<FloatType>()) {
|
||||
if (float_operands++ > 1) return true;
|
||||
}
|
||||
@ -400,22 +400,22 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
|
||||
quant::StatisticsOp stats_op = all_stats_ops.back();
|
||||
all_stats_ops.pop_back();
|
||||
|
||||
if (auto def = stats_op.arg()->getDefiningOp()) {
|
||||
if (auto def = stats_op.arg().getDefiningOp()) {
|
||||
if (IsStatsRedundant(def, op_quant_spec_getter)) {
|
||||
redundant_stats_ops.insert(stats_op);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto user : stats_op.getResult()->getUsers()) {
|
||||
for (auto user : stats_op.getResult().getUsers()) {
|
||||
// We don't propagate this parameter down if it has multiple operands.
|
||||
// We want to use the result parameter scales instead.
|
||||
|
||||
if (user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
|
||||
!PreferResultScale(user)) {
|
||||
for (Value* res : user->getResults()) {
|
||||
if (res->hasOneUse()) {
|
||||
for (Value res : user->getResults()) {
|
||||
if (res.hasOneUse()) {
|
||||
if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
|
||||
*res->getUsers().begin())) {
|
||||
*res.getUsers().begin())) {
|
||||
// quantization parameters can be propagated to next_stats
|
||||
redundant_stats_ops.insert(next_stats);
|
||||
// add next_stats to the work list so propagation can
|
||||
@ -440,12 +440,12 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
|
||||
quant::StatisticsOp stats_op = all_stats_ops.back();
|
||||
all_stats_ops.pop_back();
|
||||
|
||||
if (auto def = stats_op.arg()->getDefiningOp()) {
|
||||
if (auto def = stats_op.arg().getDefiningOp()) {
|
||||
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
|
||||
PreferResultScale(def)) {
|
||||
for (auto input : def->getOperands()) {
|
||||
if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
|
||||
input->getDefiningOp())) {
|
||||
input.getDefiningOp())) {
|
||||
redundant_stats_ops.insert(next_stats);
|
||||
all_stats_ops.push_back(next_stats);
|
||||
}
|
||||
@ -458,7 +458,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
|
||||
for (auto it : redundant_stats_ops) {
|
||||
if (!llvm::isa<quant::StatisticsOp>(it)) return true;
|
||||
auto stats_op = llvm::cast<quant::StatisticsOp>(it);
|
||||
stats_op.getResult()->replaceAllUsesWith(stats_op.arg());
|
||||
stats_op.getResult().replaceAllUsesWith(stats_op.arg());
|
||||
stats_op.erase();
|
||||
}
|
||||
|
||||
|
@ -23,18 +23,18 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -116,7 +116,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
|
||||
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
|
||||
TypeAttr::get(result_type));
|
||||
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
|
||||
op.getResult()->replaceAllUsesWith(dq);
|
||||
op.getResult().replaceAllUsesWith(dq);
|
||||
q.getOperation()->replaceUsesOfWith(dq, op.arg());
|
||||
op.erase();
|
||||
|
||||
@ -161,8 +161,8 @@ struct QuantizationPattern : public RewritePattern {
|
||||
if (op->getNumResults() != 1) {
|
||||
return matchFailure();
|
||||
}
|
||||
Value* quantized_value = op->getResult(0);
|
||||
for (Operation* quantized_op : quantized_value->getUsers()) {
|
||||
Value quantized_value = op->getResult(0);
|
||||
for (Operation* quantized_op : quantized_value.getUsers()) {
|
||||
// If it is requantize op, we shouldn't rewrite this op.
|
||||
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
|
||||
return matchFailure();
|
||||
@ -176,17 +176,17 @@ struct QuantizationPattern : public RewritePattern {
|
||||
|
||||
// Collect all the quantized inputs and "clone" the matched op by these
|
||||
// inputs.
|
||||
SmallVector<Value*, 4> inputs;
|
||||
SmallVector<Value, 4> inputs;
|
||||
inputs.reserve(quantized_op->getNumOperands());
|
||||
for (auto operand : quantized_op->getOperands()) {
|
||||
Type operand_type = operand->getType();
|
||||
Type operand_type = operand.getType();
|
||||
if (operand_type.isa<NoneType>()) {
|
||||
inputs.push_back(operand);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto ele_type = operand->getType().cast<TensorType>().getElementType();
|
||||
if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
|
||||
auto ele_type = operand.getType().cast<TensorType>().getElementType();
|
||||
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
|
||||
inputs.push_back(op_inst.input());
|
||||
} else if (ele_type.isa<IntegerType>()) {
|
||||
// If the operand is an integer tensor, then it doesn't require the
|
||||
@ -201,13 +201,13 @@ struct QuantizationPattern : public RewritePattern {
|
||||
|
||||
// Collect all the quantized outputs and replace them by the results of
|
||||
// the new quantized op.
|
||||
llvm::SmallDenseMap<Value*, int> outputs_replaced;
|
||||
llvm::SmallDenseMap<Value, int> outputs_replaced;
|
||||
SmallVector<Type, 4> output_types;
|
||||
output_types.reserve(quantized_op->getNumResults());
|
||||
for (auto enumerated_result :
|
||||
llvm::enumerate(quantized_op->getResults())) {
|
||||
Value* result = enumerated_result.value();
|
||||
Type result_type = result->getType();
|
||||
Value result = enumerated_result.value();
|
||||
Type result_type = result.getType();
|
||||
// Add this to the test coverage once we create test ops with none type
|
||||
// results.
|
||||
if (result_type.isa<NoneType>()) {
|
||||
@ -216,20 +216,20 @@ struct QuantizationPattern : public RewritePattern {
|
||||
continue;
|
||||
}
|
||||
Type result_ele_type =
|
||||
result->getType().cast<TensorType>().getElementType();
|
||||
result.getType().cast<TensorType>().getElementType();
|
||||
// If the user is the Quantize op, it must be the only user.
|
||||
if (result->hasOneUse() && llvm::isa<Q>(*result->user_begin())) {
|
||||
auto user = llvm::cast<Q>(*result->user_begin());
|
||||
if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
|
||||
auto user = llvm::cast<Q>(*result.user_begin());
|
||||
outputs_replaced.insert({user.output(), enumerated_result.index()});
|
||||
output_types.push_back(user.getType());
|
||||
} else if (result_ele_type.template isa<IntegerType>()) {
|
||||
// If the result is an integer tensor, then it doesn't require the
|
||||
// D op in the pattern.
|
||||
outputs_replaced.insert({result, enumerated_result.index()});
|
||||
output_types.push_back(result->getType());
|
||||
output_types.push_back(result.getType());
|
||||
} else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
|
||||
outputs_replaced.insert({result, enumerated_result.index()});
|
||||
output_types.push_back(result->getType());
|
||||
output_types.push_back(result.getType());
|
||||
} else {
|
||||
return matchFailure();
|
||||
}
|
||||
@ -241,7 +241,7 @@ struct QuantizationPattern : public RewritePattern {
|
||||
output_types, quantized_op->getAttrs());
|
||||
Operation* new_op = rewriter.createOperation(new_state);
|
||||
for (auto output : outputs_replaced) {
|
||||
output.getFirst()->replaceAllUsesWith(
|
||||
output.getFirst().replaceAllUsesWith(
|
||||
new_op->getResult(output.getSecond()));
|
||||
}
|
||||
|
||||
@ -252,7 +252,7 @@ struct QuantizationPattern : public RewritePattern {
|
||||
// For constant operands, the floating-point constant is duplicated in
|
||||
// case it is quantized.
|
||||
for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
|
||||
auto def = new_op->getOperand(i)->getDefiningOp();
|
||||
auto def = new_op->getOperand(i).getDefiningOp();
|
||||
if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
|
||||
DenseFPElementsAttr attr;
|
||||
if (!matchPattern(q.input(), m_Constant(&attr))) {
|
||||
@ -265,7 +265,7 @@ struct QuantizationPattern : public RewritePattern {
|
||||
|
||||
for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
|
||||
if (!quantized_op->getResult(i)
|
||||
->getType()
|
||||
.getType()
|
||||
.cast<ShapedType>()
|
||||
.getElementType()
|
||||
.isa<FloatType>()) {
|
||||
@ -283,13 +283,13 @@ struct QuantizationPattern : public RewritePattern {
|
||||
// Find the Dequantize/Dequantize users of the new op results, and
|
||||
// replace the usage. Then all the floating-point ops are connected.
|
||||
// N.B. the return op will use this floating-point result.
|
||||
for (auto user : new_op->getResult(i)->getUsers()) {
|
||||
for (auto user : new_op->getResult(i).getUsers()) {
|
||||
// Skip the Requantize op, and we know it has a single user.
|
||||
if (llvm::isa<Q>(user)) {
|
||||
user = *user->getResult(0)->getUsers().begin();
|
||||
user = *user->getResult(0).getUsers().begin();
|
||||
}
|
||||
if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
|
||||
dequantize.getResult()->replaceAllUsesWith(
|
||||
dequantize.getResult().replaceAllUsesWith(
|
||||
quantized_op->getResult(i));
|
||||
}
|
||||
}
|
||||
@ -316,7 +316,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
|
||||
|
||||
PatternMatchResult matchAndRewrite(Q op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Type output_type = op.output()->getType();
|
||||
Type output_type = op.output().getType();
|
||||
auto qtype = QType::getQuantizedElementType(output_type);
|
||||
if (!qtype || qtype.isSigned()) return this->matchFailure();
|
||||
|
||||
|
@ -4,7 +4,7 @@ package(licenses = ["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
@ -14,6 +14,6 @@ filegroup(
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-opt",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
],
|
||||
)
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "llvm/TableGen/Main.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
#include "mlir/TableGen/Operator.h" // TF:local_config_mlir
|
||||
#include "mlir/TableGen/Operator.h" // TF:llvm-project
|
||||
|
||||
using llvm::LessRecord;
|
||||
using llvm::raw_ostream;
|
||||
@ -36,7 +36,7 @@ using mlir::tblgen::Operator;
|
||||
// NOLINTNEXTLINE
|
||||
static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
|
||||
llvm::Regex acc_uniform_trait_regex{"AccumulatorUniformScale<([0-9]*),"};
|
||||
llvm::Regex coeff_index_trait_regex{"AffineOpCoefficient<([0-9]*),"};
|
||||
llvm::Regex coeff_index_trait_regex{"AffineOpCoefficient<(-?[0-9]*),"};
|
||||
llvm::Regex fixed_uniform_trait_regex{
|
||||
"FixedResultUniformScale<([0-9]+).*(true|false)>"};
|
||||
emitSourceFileHeader("Generated Ops Quant Spec Getters", os);
|
||||
|
@ -4,7 +4,7 @@ package(licenses = ["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
@ -14,6 +14,6 @@ filegroup(
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-opt",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
],
|
||||
)
|
||||
|
@ -7,10 +7,12 @@ glob_lit_tests(
|
||||
":debug_info_files",
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = [
|
||||
"pbtxt",
|
||||
"py",
|
||||
# TODO(fengliuai): reenable these tests after the fused loc is
|
||||
# supported in the diagnostic handler.
|
||||
# "py",
|
||||
],
|
||||
)
|
||||
|
||||
@ -31,8 +33,8 @@ filegroup(
|
||||
":saved_model_error",
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm//:not",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from absl import app
|
||||
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
@ -21,6 +21,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from absl import app
|
||||
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
@ -7,7 +7,7 @@ glob_lit_tests(
|
||||
":quant_stats_files",
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = [
|
||||
"pbtxt",
|
||||
],
|
||||
@ -20,7 +20,7 @@ filegroup(
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -38,6 +38,6 @@ versions {
|
||||
|
||||
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
|
||||
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-NEXT: return %0 : tensor<*xi32>
|
||||
# CHECK-NEXT: }
|
||||
|
@ -8,7 +8,7 @@ glob_lit_tests(
|
||||
":extra_files",
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = [
|
||||
"mlir",
|
||||
"cc",
|
||||
@ -24,7 +24,7 @@ filegroup(
|
||||
":importer_test_min_max",
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
],
|
||||
)
|
||||
|
||||
@ -51,7 +51,7 @@ tf_native_cc_binary(
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -67,6 +67,6 @@ tf_native_cc_binary(
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
)
|
||||
|
@ -11,6 +11,8 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
%3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
|
||||
// CHECK: %[[EXP:.*]] = "tfl.exp"
|
||||
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
|
||||
// tfl.neg should not be pruned
|
||||
// CHECK: %[[NEG:.*]] = "tfl.neg"
|
||||
%5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
|
||||
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
|
||||
return %5 : tensor<4xf32>
|
||||
|
@ -0,0 +1,19 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -output-arrays=mul,exp,div --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
|
||||
// Confirm graph pruning.
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>):
|
||||
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
|
||||
// CHECK: %[[MUL:.*]] = tfl.mul
|
||||
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
|
||||
// CHECK: %[[DIV:.*]] = tfl.div
|
||||
%3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
|
||||
// CHECK: %[[EXP:.*]] = "tfl.exp"
|
||||
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
|
||||
// tfl.neg should be pruned
|
||||
// CHECK-NOT: "tfl.neg"
|
||||
%5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
|
||||
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
|
||||
return %5 : tensor<4xf32>
|
||||
}
|
@ -521,21 +521,30 @@ func @select_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
|
||||
func @select_v2_same_shape(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
|
||||
return %0: tensor<8xf32>
|
||||
|
||||
// CHECK-LABEL: select_v2
|
||||
// CHECK-LABEL: select_v2_same_shape
|
||||
// CHECK: "tfl.select"(%arg0, %arg1, %arg2)
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @select_v2_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
|
||||
func @select_v2_multidim(%arg0: tensor<3xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
|
||||
return %0: tensor<8x3xf32>
|
||||
|
||||
// CHECK-LABEL: select_v2_multidim
|
||||
// CHECK: "tfl.select"(%arg0, %arg1, %arg2)
|
||||
// CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2)
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @select_v2_broadcast(%arg0: tensor<4xi1>, %arg1: tensor<3x4xf32>, %arg2: tensor<8x3x4xf32>) -> tensor<8x3x4xf32> {
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor<3x4xf32>, tensor<8x3x4xf32>) -> tensor<8x3x4xf32>
|
||||
return %0: tensor<8x3x4xf32>
|
||||
|
||||
// CHECK-LABEL: select_v2_broadcast
|
||||
// CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2)
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
|
@ -4,7 +4,7 @@ licenses(["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
@ -15,7 +15,7 @@ filegroup(
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm//:not",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,40 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
// CHECK: {
|
||||
// CHECK: version: 3,
|
||||
// CHECK: operator_codes: [ {
|
||||
// CHECK: builtin_code: CUSTOM,
|
||||
// CHECK: custom_code: "SomeOperation"
|
||||
// CHECK: } ],
|
||||
// CHECK: subgraphs: [ {
|
||||
// CHECK: tensors: [ {
|
||||
// CHECK: shape: [ ],
|
||||
// CHECK: type: INT32,
|
||||
// CHECK: buffer: 1,
|
||||
// CHECK: name: "tf.SomeOperation",
|
||||
// CHECK: quantization: {
|
||||
// CHECK-EMPTY
|
||||
// CHECK: }
|
||||
// CHECK: } ],
|
||||
// CHECK: inputs: [ ],
|
||||
// CHECK: outputs: [ 0 ],
|
||||
// CHECK: operators: [ {
|
||||
// CHECK: inputs: [ ],
|
||||
// CHECK: outputs: [ 0 ],
|
||||
// CHECK: custom_options: [ 100, 116, 121, 112, 101, 0, 1, 7, 1, 1, 1, 2, 4, 2, 36, 1 ]
|
||||
// CHECK: } ],
|
||||
// CHECK: name: "main"
|
||||
// CHECK: } ],
|
||||
// CHECK: description: "MLIR Converted.",
|
||||
// CHECK: buffers: [ {
|
||||
// CHECK-EMPTY
|
||||
// CHECK: }, {
|
||||
// CHECK-EMPTY
|
||||
// CHECK: } ]
|
||||
// CHECK: }
|
||||
|
||||
func @main() -> tensor<*xi32> {
|
||||
// Tests that the below type attribute is convertible into the corresponding custom option in flatbuffer.
|
||||
%0 = "tf.SomeOperation"() {dtype = i32 } : () -> tensor<*xi32>
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
@ -518,6 +518,20 @@ func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform<i9:f32
|
||||
|
||||
// -----
|
||||
|
||||
func @testMaxPoolingWithArgMax2D(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
|
||||
%0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
|
||||
return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testMaxUnpooling2D(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
|
||||
%0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
|
||||
return %0 : tensor<1x8x8x128xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testLogistic
|
||||
func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
|
||||
^bb0(%arg0: tensor<1x2x3x4x5xbf16>):
|
||||
@ -1942,6 +1956,13 @@ func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %ar
|
||||
|
||||
// -----
|
||||
|
||||
func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
|
||||
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
|
||||
return %0 : tensor<1x64x84x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> {
|
||||
// expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}}
|
||||
%0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32>
|
||||
|
@ -140,22 +140,6 @@ func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1:
|
||||
// CHECK-SAME: fused_activation_function = "RELU6"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: intermOpUsedTwice
|
||||
func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
|
||||
%cst = constant dense<1.5> : tensor<16xf32>
|
||||
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %0, %1 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00,
|
||||
// CHECK: %cst_0 = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00,
|
||||
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
|
||||
// CHECK: %1 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
|
||||
// CHECK: return %0, %1
|
||||
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseMulIntoFullyConnected
|
||||
func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
|
||||
%cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
|
||||
@ -167,8 +151,8 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
|
||||
|
||||
return %1 : tensor<4x2xf32>
|
||||
|
||||
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
|
||||
// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
|
||||
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
|
||||
// CHECK: return %[[RES]] : tensor<4x2xf32>
|
||||
}
|
||||
@ -233,8 +217,8 @@ func @fuseMulIntoFullyConnectedBroadcast(%arg0: tensor<1x3xf32>) -> tensor<1x2xf
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32>
|
||||
return %1 : tensor<1x2xf32>
|
||||
|
||||
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
|
||||
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
|
||||
// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
|
||||
// CHECK: %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
|
||||
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
|
||||
// CHECK: return %[[RES]] : tensor<1x2xf32>
|
||||
}
|
||||
@ -249,7 +233,7 @@ func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> te
|
||||
|
||||
return %1 : tensor<4x2xf32>
|
||||
|
||||
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
|
||||
// CHECK: return %[[RES]] : tensor<4x2xf32>
|
||||
}
|
||||
@ -631,3 +615,18 @@ func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor
|
||||
// CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"}
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: NotfuseAddIntoConv2d_MultipleUsers
|
||||
func @NotfuseAddIntoConv2d_MultipleUsers(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
|
||||
%cst = constant dense<1.5> : tensor<16xf32>
|
||||
%cst_1 = constant dense<3.5> : tensor<16xf32>
|
||||
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
%2 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
|
||||
return %1, %2 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
|
||||
|
||||
// CHECK: %[[tfl_conv2d:[0-9].*]] = "tfl.conv_2d"
|
||||
// CHECK: tfl.add
|
||||
// CHECK-NEXT: tfl.add
|
||||
}
|
||||
|
@ -125,3 +125,21 @@ func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112
|
||||
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
|
||||
// PerTensor: %[[conv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeFullyConnected
|
||||
func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
|
||||
%w = constant dense<127.0> : tensor<32x12xf32>
|
||||
%b = constant dense<0.0> : tensor<32xf32>
|
||||
%fc = "tfl.fully_connected"(%arg0, %w, %b) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<32x12xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
return %fc : tensor<1x112x112x32xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<32x12xf32>
|
||||
// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>} : (tensor<32x12xf32>) -> tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<32x12xf32>
|
||||
// CHECK: "tfl.fully_connected"(%arg0, %[[dq]]
|
||||
|
||||
// PerTensor: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<32x12xf32>
|
||||
// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>} : (tensor<32x12xf32>) -> tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>
|
||||
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<32x12xf32>
|
||||
// PerTensor: "tfl.fully_connected"(%arg0, %[[dq]]
|
||||
}
|
||||
|
@ -379,26 +379,26 @@ func @QuantizeConcatResToAllNoRequantize(tensor<1x2x!quant.uniform<u8:f32, 0.1:1
|
||||
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
|
||||
// CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHeCK: return %4 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK: return %4 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeConcatResToAllRequantize
|
||||
func @QuantizeConcatResToAllRequantize(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1:128>> {
|
||||
^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>):
|
||||
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 2.0:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>
|
||||
%0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>
|
||||
%1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>) -> tensor<1x2xf32>
|
||||
%2 = "tfl.concatenation"(%1, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
|
||||
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
|
||||
// CHECK %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 2.000000e+00:128>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 2.000000e+00:128>>
|
||||
// CHECK %1 = "tfl.quantize"(%0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK %2 = "tfl.dequantize"(%1) : (tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<2xf32>
|
||||
// CHECK %3 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK %4 = "tfl.dequantize"(%3) : (tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<2xf32>
|
||||
// CHECK %5 = "tfl.concatenation"(%2, %4) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK %6 = "tfl.quantize"(%5) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK return %6 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
|
||||
// CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>
|
||||
// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
|
||||
// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeConcatResToAllRequantizeArg
|
||||
@ -409,13 +409,13 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform<u8:f32, 2.0:
|
||||
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
|
||||
// CHECK %1 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK %2 = "tfl.dequantize"(%1) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
|
||||
// CHECK %3 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK %4 = "tfl.dequantize"(%3) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
|
||||
// CHECK %5 = "tfl.concatenation"(%2, %4) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK %6 = "tfl.quantize"(%5) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK return %6 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
|
||||
// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
|
||||
// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: RequantizeAlreadyQuantizedModel
|
||||
|
@ -204,8 +204,9 @@ func @QuantizeConcatRequantize(tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>, tens
|
||||
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
|
||||
// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
|
||||
// CHECK: %[[cc:.*]] = "tfl.concatenation"(%arg0, %[[q]]) {axis = 0 : i32, fused_activation_function = "NONE"}
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
|
||||
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
// CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q0]], %[[q1]]) {axis = 0 : i32, fused_activation_function = "NONE"}
|
||||
// CHECK: return %[[cc]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
}
|
||||
|
||||
|
@ -15,11 +15,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
|
||||
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -20,11 +20,11 @@ limitations under the License.
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
||||
@ -103,7 +103,7 @@ static int PrintFunctionResultMapping(const std::string &result,
|
||||
i = 0;
|
||||
for (auto output : *subgraph->outputs()) {
|
||||
print_buffer(*subgraph, i, output, [&](int i) {
|
||||
return terminator ? terminator->getOperand(i)->getLoc() : unknown_loc;
|
||||
return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -15,12 +15,12 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Parser.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Parser.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
|
@ -17,9 +17,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
|
||||
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
@ -21,26 +21,26 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Block.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
||||
#include "mlir/IR/Types.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
@ -188,10 +188,10 @@ struct OphintCompositeOp {
|
||||
|
||||
// This function will process the aggregated inputs based on different
|
||||
// strategies like "first", "last", "stack".
|
||||
std::map<int, Value*> GetAggregatedInputs(OpBuilder* builder) {
|
||||
std::map<int, Value*> aggregated_inputs;
|
||||
std::map<int, Value> GetAggregatedInputs(OpBuilder* builder) {
|
||||
std::map<int, Value> aggregated_inputs;
|
||||
for (const auto& kv : inputs) {
|
||||
Value* op_input = nullptr;
|
||||
Value op_input = nullptr;
|
||||
const AggregatedOperand& operand = kv.second;
|
||||
// Dealing with "stack" strategy:
|
||||
// This breaks into two parts:
|
||||
@ -203,9 +203,9 @@ struct OphintCompositeOp {
|
||||
if (operand.ops.size() == 1) {
|
||||
// If ops size is 1, it will be simply expanding dimensions at dim 0.
|
||||
Operation* current_identity_op = operand.ops.begin()->second;
|
||||
Value* input = current_identity_op->getOperand(0);
|
||||
Value input = current_identity_op->getOperand(0);
|
||||
RankedTensorType input_type =
|
||||
input->getType().cast<RankedTensorType>();
|
||||
input.getType().cast<RankedTensorType>();
|
||||
// The Reshape will be {1, (original_shape)}
|
||||
SmallVector<int64_t, 4> reshape_op_shape;
|
||||
reshape_op_shape.push_back(1);
|
||||
@ -234,21 +234,21 @@ struct OphintCompositeOp {
|
||||
|
||||
} else {
|
||||
// Insert a pack op to pack all the inputs together.
|
||||
std::vector<Value*> pack_input_operands;
|
||||
std::vector<Value*> packed_input_consumers;
|
||||
std::vector<Value> pack_input_operands;
|
||||
std::vector<Value> packed_input_consumers;
|
||||
for (int i = 0, e = operand.ops.size(); i < e; ++i) {
|
||||
pack_input_operands.push_back(operand.ops.at(i)->getOperand(0));
|
||||
packed_input_consumers.push_back(operand.ops.at(i)->getResult(0));
|
||||
}
|
||||
// Find the first op that consumes the last value of the aggregated
|
||||
// inputs.
|
||||
Operation* first_use = *(packed_input_consumers.back()->user_begin());
|
||||
Operation* first_use = *(packed_input_consumers.back().user_begin());
|
||||
// The pack reshape will be {N, (original_shape)}
|
||||
SmallVector<int64_t, 4> pack_shape;
|
||||
pack_shape.push_back(pack_input_operands.size());
|
||||
RankedTensorType type = operand.ops.at(0)
|
||||
->getResult(0)
|
||||
->getType()
|
||||
.getType()
|
||||
.cast<RankedTensorType>();
|
||||
for (const auto& dim : type.getShape()) {
|
||||
pack_shape.push_back(dim);
|
||||
@ -288,9 +288,9 @@ struct OphintCompositeOp {
|
||||
const AggregatedOperand& operand = kv.second;
|
||||
if (operand.aggregation == kStrategyStack) {
|
||||
const int output_numer = operand.ops.size();
|
||||
Value* first_output = operand.ops.at(0)->getOperand(0);
|
||||
Value first_output = operand.ops.at(0)->getOperand(0);
|
||||
RankedTensorType first_output_type =
|
||||
first_output->getType().cast<RankedTensorType>();
|
||||
first_output.getType().cast<RankedTensorType>();
|
||||
// The aggregated output shape will be {N, original_shape}.
|
||||
SmallVector<int64_t, 4> shape;
|
||||
shape.push_back(output_numer);
|
||||
@ -300,12 +300,12 @@ struct OphintCompositeOp {
|
||||
aggregated_output_types[kv.first] =
|
||||
RankedTensorType::get(shape, first_output_type.getElementType());
|
||||
} else if (operand.aggregation == kStrategyLast) {
|
||||
Value* last_output =
|
||||
Value last_output =
|
||||
operand.ops.at(operand.ops.size() - 1)->getOperand(0);
|
||||
aggregated_output_types[kv.first] = last_output->getType();
|
||||
aggregated_output_types[kv.first] = last_output.getType();
|
||||
} else {
|
||||
Value* first_output = operand.ops.at(0)->getOperand(0);
|
||||
aggregated_output_types[kv.first] = first_output->getType();
|
||||
Value first_output = operand.ops.at(0)->getOperand(0);
|
||||
aggregated_output_types[kv.first] = first_output.getType();
|
||||
}
|
||||
}
|
||||
return aggregated_output_types;
|
||||
@ -329,7 +329,7 @@ struct OphintCompositeOp {
|
||||
Operation* first_output = operand.ops.at(0);
|
||||
Location insert_loc = first_output->getLoc();
|
||||
SmallVector<Type, 4> unpack_output_types(
|
||||
output_number, first_output->getOperand(0)->getType());
|
||||
output_number, first_output->getOperand(0).getType());
|
||||
|
||||
builder->setInsertionPoint(first_output);
|
||||
Operation* unpack_op = builder->create<TFL::UnpackOp>(
|
||||
@ -404,7 +404,7 @@ void PreprocessTopoSortGraph(
|
||||
// should only count as one.
|
||||
llvm::DenseSet<Operation*> input_ops;
|
||||
for (int i = 0; i < op.getNumOperands(); ++i) {
|
||||
Operation* input_op = op.getOperand(i)->getDefiningOp();
|
||||
Operation* input_op = op.getOperand(i).getDefiningOp();
|
||||
if (input_op) input_ops.insert(input_op);
|
||||
}
|
||||
if (input_ops.empty()) {
|
||||
@ -507,15 +507,15 @@ LogicalResult TopoSortOperations(OpBuilder* builder) {
|
||||
|
||||
Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
|
||||
Operation* insert_before_op,
|
||||
const std::map<int, Value*>& inputs,
|
||||
const std::map<int, Value>& inputs,
|
||||
const std::map<int, Type>& output_types,
|
||||
OpBuilder* builder, ModuleOp* module_op) {
|
||||
SmallVector<Type, 4> input_types;
|
||||
SmallVector<Value*, 4> input_values;
|
||||
SmallVector<Value, 4> input_values;
|
||||
SmallVector<int, 4> input_indexes;
|
||||
for (const auto& kv : inputs) {
|
||||
Value* input = kv.second;
|
||||
input_types.push_back(input->getType());
|
||||
Value input = kv.second;
|
||||
input_types.push_back(input.getType());
|
||||
input_values.push_back(input);
|
||||
input_indexes.push_back(kv.first);
|
||||
}
|
||||
@ -588,8 +588,8 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
|
||||
llvm::DenseSet<Operation*> reachable_ops;
|
||||
std::queue<Operation*> ops_queue;
|
||||
for (auto& input_op : input_ops) {
|
||||
for (Value* value : input_op->getOperands()) {
|
||||
Operation* op = value->getDefiningOp();
|
||||
for (Value value : input_op->getOperands()) {
|
||||
Operation* op = value.getDefiningOp();
|
||||
if (op != nullptr) ops_queue.push(op);
|
||||
}
|
||||
}
|
||||
@ -598,8 +598,8 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
|
||||
Operation* current_op = ops_queue.front();
|
||||
ops_queue.pop();
|
||||
reachable_ops.insert(current_op);
|
||||
for (Value* value : current_op->getOperands()) {
|
||||
Operation* upstream_op = value->getDefiningOp();
|
||||
for (Value value : current_op->getOperands()) {
|
||||
Operation* upstream_op = value.getDefiningOp();
|
||||
// Not visited, put it into the queue.
|
||||
if (upstream_op != nullptr &&
|
||||
!llvm::is_contained(reachable_ops, upstream_op)) {
|
||||
@ -625,7 +625,7 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
|
||||
BfsForReachableOps(ophint_composite_op.GetAllOutputOps());
|
||||
|
||||
// Step 3, deal with inputs aggregation strategies.
|
||||
const std::map<int, Value*>& aggregated_inputs =
|
||||
const std::map<int, Value>& aggregated_inputs =
|
||||
ophint_composite_op.GetAggregatedInputs(builder);
|
||||
|
||||
// Step 4, get aggregated output types.
|
||||
@ -642,7 +642,7 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
|
||||
aggregated_inputs, aggregated_output_types, builder, module_op);
|
||||
|
||||
for (const auto& kv : aggregated_inputs) {
|
||||
Operation* op = kv.second->getDefiningOp();
|
||||
Operation* op = kv.second.getDefiningOp();
|
||||
if (op == nullptr) return failure();
|
||||
op->moveBefore(fused_op);
|
||||
}
|
||||
|
@ -15,23 +15,23 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Block.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
||||
#include "mlir/IR/Types.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -92,18 +92,18 @@ LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op,
|
||||
if (call_op.getNumResults() != 1) return failure();
|
||||
|
||||
// Inputs is indexed at 0.
|
||||
Value* input = call_op.getOperand(0);
|
||||
Value input = call_op.getOperand(0);
|
||||
// Input_weight is indexed at 1.
|
||||
Value* weight = call_op.getOperand(1);
|
||||
Value weight = call_op.getOperand(1);
|
||||
// Recurrent_weight is indexed at 2.
|
||||
Value* recurrent_weight = call_op.getOperand(2);
|
||||
Value recurrent_weight = call_op.getOperand(2);
|
||||
// Bias is indexed at 3.
|
||||
Value* bias = call_op.getOperand(3);
|
||||
Value bias = call_op.getOperand(3);
|
||||
// Hidden_state is indexed at 4.
|
||||
Value* hidden_state = call_op.getOperand(4);
|
||||
Value hidden_state = call_op.getOperand(4);
|
||||
|
||||
// Build Output.
|
||||
auto output_type = call_op.getResult(0)->getType();
|
||||
auto output_type = call_op.getResult(0).getType();
|
||||
|
||||
// Currently, ophinted RNN only supports time_major = True.
|
||||
const bool time_major = true;
|
||||
@ -127,7 +127,7 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
|
||||
auto input_index_attr = composite_func_op.getAttr(kTfLiteFunctionInputIndex)
|
||||
.cast<ArrayAttr>()
|
||||
.getValue();
|
||||
llvm::DenseMap<int, Value*> fused_ops_index_to_call_op_args;
|
||||
llvm::DenseMap<int, Value> fused_ops_index_to_call_op_args;
|
||||
|
||||
for (int i = 0; i < call_op.getNumOperands(); ++i) {
|
||||
int input_index = input_index_attr[i].cast<IntegerAttr>().getInt();
|
||||
@ -139,7 +139,7 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
|
||||
|
||||
// We encounter some optional arguments not filled, so we need to create an
|
||||
// empty Value.
|
||||
Value* none_value;
|
||||
Value none_value;
|
||||
if (call_op.getNumOperands() <
|
||||
kUnidirectionalSequenceLSTMOpTotalIArgumentNum) {
|
||||
builder->setInsertionPoint(call_op.getOperation());
|
||||
@ -148,7 +148,7 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
|
||||
}
|
||||
|
||||
// Prepare all operands for the UnidirectionalSequenceLSTMOp.
|
||||
SmallVector<Value*, kUnidirectionalSequenceLSTMOpTotalIArgumentNum> operands;
|
||||
SmallVector<Value, kUnidirectionalSequenceLSTMOpTotalIArgumentNum> operands;
|
||||
for (int i = 0; i < kUnidirectionalSequenceLSTMOpTotalIArgumentNum; ++i) {
|
||||
auto operand_it = fused_ops_index_to_call_op_args.find(i);
|
||||
if (operand_it == fused_ops_index_to_call_op_args.end()) {
|
||||
@ -169,12 +169,12 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
|
||||
if (call_op.getNumResults() > 1) {
|
||||
for (int i = 0; i < call_op.getNumResults() - 1; ++i) {
|
||||
// This one should not be used.
|
||||
Value* unused_output = call_op.getResult(i);
|
||||
if (!unused_output->use_empty()) return failure();
|
||||
Value unused_output = call_op.getResult(i);
|
||||
if (!unused_output.use_empty()) return failure();
|
||||
}
|
||||
}
|
||||
output_types.push_back(
|
||||
call_op.getResult(call_op.getNumResults() - 1)->getType());
|
||||
call_op.getResult(call_op.getNumResults() - 1).getType());
|
||||
|
||||
// Prepare attributes.
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
@ -206,11 +206,11 @@ LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name,
|
||||
LogicalResult build_fused_op_result = BuildUnidirectionalSequenceLSTMOp(
|
||||
composite_func_op, call_op, builder, &fused_op);
|
||||
if (failed(build_fused_op_result)) return build_fused_op_result;
|
||||
Value* call_output = call_op.getResult(call_op.getNumResults() - 1);
|
||||
if (call_output->getType() != fused_op->getResult(0)->getType()) {
|
||||
Value call_output = call_op.getResult(call_op.getNumResults() - 1);
|
||||
if (call_output.getType() != fused_op->getResult(0).getType()) {
|
||||
return failure();
|
||||
}
|
||||
call_output->replaceAllUsesWith(fused_op->getResult(0));
|
||||
call_output.replaceAllUsesWith(fused_op->getResult(0));
|
||||
} else { // If we support more fused op, we should add the conversion here.
|
||||
return failure();
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
|
||||
// Use the tensor type information from $0 and convert min $1, max $2 and
|
||||
// numBits $3 and narrowRange $4 to a QuantizedType.
|
||||
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
|
||||
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
|
||||
"GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
|
||||
|
||||
// Converts an integer attribute $0 to 32-bit with builder.
|
||||
def convertIntAttrTo32Bit : NativeCodeCall<
|
||||
@ -49,6 +49,11 @@ def convertIntAttrTo32Bit : NativeCodeCall<
|
||||
def ExtractSingleElementAsInteger : NativeCodeCall<
|
||||
"ExtractSingleElementAsInteger($_self.cast<ElementsAttr>())">;
|
||||
|
||||
// Checks whether all inputs of the given operation have static shapes and those shapes are identical.
|
||||
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
|
||||
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
|
||||
def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Nullary ops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -145,10 +150,9 @@ def : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
|
||||
def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
|
||||
def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
|
||||
def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
|
||||
// TODO(jpienaar): this is not true for all selects, TF's select supports rank 0
|
||||
// condition
|
||||
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
|
||||
def : Pat<(TF_SelectV2Op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
|
||||
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>;
|
||||
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>;
|
||||
def : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>;
|
||||
def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>;
|
||||
def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
|
||||
|
@ -28,15 +28,15 @@ limitations under the License.
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
@ -66,6 +66,28 @@ struct LegalizeTF : public FunctionPass<LegalizeTF> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Returns true if all tensor value in `values` has static shape and same shape.
|
||||
bool HasSameStaticShapes(Operation* op) {
|
||||
auto values = op->getOperands();
|
||||
int index = 0;
|
||||
ArrayRef<int64_t> shape;
|
||||
for (Value value : values) {
|
||||
auto shaped_type = value.getType().dyn_cast<ShapedType>();
|
||||
if (!shaped_type && !shaped_type.hasStaticShape()) {
|
||||
return false;
|
||||
}
|
||||
if (index == 0) {
|
||||
shape = shaped_type.getShape();
|
||||
} else {
|
||||
if (shape != shaped_type.getShape()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
++index;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
|
||||
|
||||
#define DECL_CONVERT_OP(tf_op) \
|
||||
@ -100,7 +122,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
|
||||
auto tf_concat_op = cast<TF::ConcatOp>(op);
|
||||
|
||||
auto values = tf_concat_op.values();
|
||||
auto output_type = tf_concat_op.output()->getType();
|
||||
auto output_type = tf_concat_op.output().getType();
|
||||
// Extract axis attribute from constant concat_dims tensor
|
||||
ElementsAttr axis;
|
||||
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
|
||||
@ -119,7 +141,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
|
||||
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
|
||||
|
||||
auto values = tf_concat_op.values();
|
||||
auto output_type = tf_concat_op.output()->getType();
|
||||
auto output_type = tf_concat_op.output().getType();
|
||||
// Extract axis attribute from constant axis tensor
|
||||
ElementsAttr axis;
|
||||
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis)))
|
||||
@ -145,7 +167,7 @@ PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
|
||||
if (tf_matmul_op.transpose_a()) return matchFailure();
|
||||
if (!tf_matmul_op.transpose_b()) return matchFailure();
|
||||
|
||||
Type output_type = tf_matmul_op.getResult()->getType();
|
||||
Type output_type = tf_matmul_op.getResult().getType();
|
||||
// TODO(jpienaar): Follow up post shuffle discussion.
|
||||
auto no_input = rewriter.create<ConstantOp>(
|
||||
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
|
||||
@ -161,8 +183,8 @@ PatternMatchResult ConvertTFPackOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_pack_op = cast<TF::PackOp>(op);
|
||||
|
||||
SmallVector<Value*, 4> values(tf_pack_op.values());
|
||||
auto output_type = tf_pack_op.output()->getType();
|
||||
SmallVector<Value, 4> values(tf_pack_op.values());
|
||||
auto output_type = tf_pack_op.output().getType();
|
||||
auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
|
||||
// Axis can be negative.
|
||||
auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue());
|
||||
@ -176,10 +198,10 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_reshape_op = cast<TF::ReshapeOp>(op);
|
||||
|
||||
auto* input = tf_reshape_op.tensor();
|
||||
auto* shape = tf_reshape_op.shape();
|
||||
auto input = tf_reshape_op.tensor();
|
||||
auto shape = tf_reshape_op.shape();
|
||||
|
||||
ShapedType shape_type = shape->getType().cast<ShapedType>();
|
||||
ShapedType shape_type = shape.getType().cast<ShapedType>();
|
||||
// The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
|
||||
if (!shape_type.getElementType().isInteger(32)) {
|
||||
auto new_shape = shape_type.getShape();
|
||||
@ -191,7 +213,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
|
||||
rewriter.getBoolAttr(false))
|
||||
.y();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output()->getType(),
|
||||
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
|
||||
input, shape);
|
||||
return matchSuccess();
|
||||
}
|
||||
@ -200,7 +222,7 @@ PatternMatchResult ConvertTFSplitOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_split_op = cast<TF::SplitOp>(op);
|
||||
|
||||
auto output_types = functional::map([](Value* v) { return v->getType(); },
|
||||
auto output_types = functional::map([](Value v) { return v.getType(); },
|
||||
tf_split_op.output());
|
||||
// Number of splits cannot be negative.
|
||||
auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());
|
||||
@ -215,7 +237,7 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_splitv_op = cast<TF::SplitVOp>(op);
|
||||
|
||||
auto output_types = functional::map([](Value* v) { return v->getType(); },
|
||||
auto output_types = functional::map([](Value v) { return v.getType(); },
|
||||
tf_splitv_op.output());
|
||||
// Number of splits cannot be negative.
|
||||
auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());
|
||||
@ -226,13 +248,13 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
Value* PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
|
||||
Value* attribute,
|
||||
Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
|
||||
Value attribute,
|
||||
ArrayRef<int32_t> padding_val, int* mask) {
|
||||
DenseIntElementsAttr dense_elem_attr;
|
||||
SmallVector<int32_t, 8> padded_val;
|
||||
|
||||
auto ranked_attr_type = attribute->getType().dyn_cast<RankedTensorType>();
|
||||
auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
|
||||
if (!ranked_attr_type ||
|
||||
!matchPattern(attribute, m_Constant(&dense_elem_attr))) {
|
||||
// If the input attribute is neither ranked type nor constant, we
|
||||
@ -258,14 +280,14 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
|
||||
auto ranked_input_type =
|
||||
tf_strided_slice_op.input()->getType().dyn_cast<RankedTensorType>();
|
||||
tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
|
||||
if (!ranked_input_type) {
|
||||
// If input is not a ranked tensor, we can't deduce the padding dimensions
|
||||
// from it, so we just do a plain conversion here.
|
||||
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
|
||||
op, tf_strided_slice_op.output()->getType(),
|
||||
tf_strided_slice_op.input(), tf_strided_slice_op.begin(),
|
||||
tf_strided_slice_op.end(), tf_strided_slice_op.strides(),
|
||||
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
|
||||
tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
|
||||
tf_strided_slice_op.strides(),
|
||||
rewriter.getI32IntegerAttr(
|
||||
tf_strided_slice_op.begin_mask().getSExtValue()),
|
||||
rewriter.getI32IntegerAttr(
|
||||
@ -283,20 +305,20 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
|
||||
// Pad `begin` array with zero values and update the `begin_mask`.
|
||||
SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0);
|
||||
int begin_mask = tf_strided_slice_op.begin_mask().getSExtValue();
|
||||
Value* padded_begin = PadStridedSliceAttributeArray(
|
||||
Value padded_begin = PadStridedSliceAttributeArray(
|
||||
op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask);
|
||||
// Pad `end` array with `input_shape` and update the `end_mask`.
|
||||
int end_mask = tf_strided_slice_op.end_mask().getSExtValue();
|
||||
auto input_shape = ranked_input_type.getShape();
|
||||
SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end());
|
||||
Value* padded_end = PadStridedSliceAttributeArray(
|
||||
Value padded_end = PadStridedSliceAttributeArray(
|
||||
op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask);
|
||||
// Pad `strides` array with ones.
|
||||
SmallVector<int32_t, 8> strides_pad_val(num_input_dims, 1);
|
||||
Value* padded_strides = PadStridedSliceAttributeArray(
|
||||
Value padded_strides = PadStridedSliceAttributeArray(
|
||||
op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
|
||||
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
|
||||
op, tf_strided_slice_op.output()->getType(), tf_strided_slice_op.input(),
|
||||
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
|
||||
padded_begin, padded_end, padded_strides,
|
||||
rewriter.getI32IntegerAttr(begin_mask),
|
||||
rewriter.getI32IntegerAttr(end_mask),
|
||||
@ -313,8 +335,8 @@ PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_unpack_op = cast<TF::UnpackOp>(op);
|
||||
|
||||
auto* input = tf_unpack_op.value();
|
||||
auto output_types = functional::map([](Value* v) { return v->getType(); },
|
||||
auto input = tf_unpack_op.value();
|
||||
auto output_types = functional::map([](Value v) { return v.getType(); },
|
||||
tf_unpack_op.output());
|
||||
auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
|
||||
// Axis can be negative.
|
||||
@ -338,7 +360,7 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
|
||||
if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;
|
||||
|
||||
auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
|
||||
auto output_type = tf_matrix_diag_v2_or_v3_op.output()->getType();
|
||||
auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType();
|
||||
|
||||
// Extract k constant tensor and check value = 0.
|
||||
ElementsAttr k;
|
||||
@ -478,7 +500,7 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
|
||||
|
||||
auto status_or_const_op = CreateConstOpWithSingleValue(
|
||||
&rewriter, op->getLoc(),
|
||||
tf_reciprocal_op.x()->getType().cast<ShapedType>(), 1);
|
||||
tf_reciprocal_op.x().getType().cast<ShapedType>(), 1);
|
||||
if (!status_or_const_op.ok()) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
@ -19,11 +19,11 @@ limitations under the License.
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
@ -50,13 +50,13 @@ struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
|
||||
|
||||
// Create LSTM gates with different weights for input, recurrent and
|
||||
// cell state, and also the layer normalization parameters.
|
||||
Operation* CreateGate(Location loc, Value* in, Value* in_w, Value* rec,
|
||||
Value* rec_w,
|
||||
llvm::Optional<std::pair<Value*, Value*>> cell,
|
||||
Value* ln_w, Value* ln_bias, OpBuilder* builder);
|
||||
Operation* CreateGate(Location loc, Value in, Value in_w, Value rec,
|
||||
Value rec_w,
|
||||
llvm::Optional<std::pair<Value, Value>> cell,
|
||||
Value ln_w, Value ln_bias, OpBuilder* builder);
|
||||
|
||||
Operation* CreateLayerNorm(Location loc, Value* in, Value* ln_w,
|
||||
Value* ln_bias, OpBuilder* builder);
|
||||
Operation* CreateLayerNorm(Location loc, Value in, Value ln_w, Value ln_bias,
|
||||
OpBuilder* builder);
|
||||
|
||||
// Add the internal implementation of the LSTM to its regions.
|
||||
void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder);
|
||||
@ -71,7 +71,7 @@ struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
|
||||
|
||||
void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
|
||||
Type expressed_type =
|
||||
lstm.input()->getType().cast<ShapedType>().getElementType();
|
||||
lstm.input().getType().cast<ShapedType>().getElementType();
|
||||
Type int8_storage_type = builder->getIntegerType(8);
|
||||
Type int16_storage_type = builder->getIntegerType(16);
|
||||
auto flag = quant::QuantizationFlags::FlagValue::Signed;
|
||||
@ -88,12 +88,12 @@ void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
|
||||
auto any_int16 = quant::AnyQuantizedType::get(
|
||||
flag, int16_storage_type, expressed_type, int16_min, int16_max);
|
||||
|
||||
int8 = any_int8.castFromExpressedType(lstm.input()->getType());
|
||||
int16 = any_int16.castFromExpressedType(lstm.input()->getType());
|
||||
int8 = any_int8.castFromExpressedType(lstm.input().getType());
|
||||
int16 = any_int16.castFromExpressedType(lstm.input().getType());
|
||||
}
|
||||
|
||||
Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in,
|
||||
Value* ln_w, Value* ln_bias,
|
||||
Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in,
|
||||
Value ln_w, Value ln_bias,
|
||||
OpBuilder* builder) {
|
||||
// Note that l2_normalization and add ops here are not the execution kernel
|
||||
// implementation for layer_normalization and we just want to use them to
|
||||
@ -105,8 +105,8 @@ Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in,
|
||||
}
|
||||
|
||||
Operation* LoadQuantizationRecipe::CreateGate(
|
||||
Location loc, Value* in, Value* in_w, Value* rec, Value* rec_w,
|
||||
llvm::Optional<std::pair<Value*, Value*>> cell, Value* ln_w, Value* ln_bias,
|
||||
Location loc, Value in, Value in_w, Value rec, Value rec_w,
|
||||
llvm::Optional<std::pair<Value, Value>> cell, Value ln_w, Value ln_bias,
|
||||
OpBuilder* builder) {
|
||||
auto s1 = builder->create<FullyConnectedOp>(loc, int16, in, in_w, none_cst,
|
||||
none_af, fc_format, keep_dims);
|
||||
@ -119,13 +119,13 @@ Operation* LoadQuantizationRecipe::CreateGate(
|
||||
cell.getValue().second, none_af);
|
||||
s4 = builder->create<AddNOp>(
|
||||
loc, int16,
|
||||
llvm::ArrayRef<Value*>(
|
||||
llvm::ArrayRef<Value>(
|
||||
{*s1.output().begin(), *s2.output().begin(), s3.output()}));
|
||||
|
||||
} else {
|
||||
s4 = builder->create<AddNOp>(
|
||||
loc, int16,
|
||||
llvm::ArrayRef<Value*>({*s1.output().begin(), *s2.output().begin()}));
|
||||
llvm::ArrayRef<Value>({*s1.output().begin(), *s2.output().begin()}));
|
||||
}
|
||||
|
||||
auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder);
|
||||
@ -144,22 +144,20 @@ void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
|
||||
region.push_back(new Block);
|
||||
builder->setInsertionPointToEnd(®ion.front());
|
||||
Location loc = lstm.getLoc();
|
||||
Type int32_type = builder->getIntegerType(32);
|
||||
Type int32_tensor = UnrankedTensorType::get(int32_type);
|
||||
none_cst = builder->create<ConstantOp>(loc, builder->getNoneType(),
|
||||
builder->getUnitAttr());
|
||||
|
||||
auto input_gate = CreateGate(
|
||||
loc, lstm.input(), lstm.input_to_input_weights(),
|
||||
lstm.input_activation_state(), lstm.recurrent_to_input_weights(),
|
||||
llvm::Optional<std::pair<Value*, Value*>>(
|
||||
llvm::Optional<std::pair<Value, Value>>(
|
||||
{lstm.input_cell_state(), lstm.cell_to_input_weights()}),
|
||||
lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);
|
||||
|
||||
auto forget_gate = CreateGate(
|
||||
loc, lstm.input(), lstm.input_to_forget_weights(),
|
||||
lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
|
||||
llvm::Optional<std::pair<Value*, Value*>>(
|
||||
llvm::Optional<std::pair<Value, Value>>(
|
||||
{lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
|
||||
lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);
|
||||
|
||||
@ -179,7 +177,7 @@ void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
|
||||
auto output_gate = CreateGate(
|
||||
loc, lstm.input(), lstm.input_to_output_weights(),
|
||||
lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
|
||||
llvm::Optional<std::pair<Value*, Value*>>(
|
||||
llvm::Optional<std::pair<Value, Value>>(
|
||||
{new_cell, lstm.cell_to_output_weights()}),
|
||||
lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);
|
||||
|
||||
|
@ -29,28 +29,28 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Block.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
#include "mlir/IR/Types.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
@ -84,7 +84,7 @@ struct LowerStaticTensorListPass
|
||||
TensorListPatternRewriter *rewriter);
|
||||
};
|
||||
|
||||
Value *CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
|
||||
Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
|
||||
ArrayRef<int64_t> shape, int32_t val) {
|
||||
RankedTensorType type =
|
||||
RankedTensorType::get(shape, rewriter->getIntegerType(32));
|
||||
@ -93,9 +93,9 @@ Value *CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
|
||||
return rewriter->create<ConstantOp>(loc, type, attr);
|
||||
}
|
||||
|
||||
Value *CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
|
||||
Value *shape_tensor, int32_t val) {
|
||||
Value *scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
|
||||
Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
|
||||
Value shape_tensor, int32_t val) {
|
||||
Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
|
||||
return rewriter->create<TF::FillOp>(
|
||||
loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
|
||||
shape_tensor, scalar_val);
|
||||
@ -131,32 +131,32 @@ Type GetTensorTypeForTensorList(Type element_type, TF::VariantType handle_dtype,
|
||||
// Requires that `start_index` and `size` are scalar tensors and
|
||||
// `item_position_shape` is a 1-D tensor with only one element equal to the rank
|
||||
// of an item in the tensorlist.
|
||||
TF::SliceOp CreateSliceOpForTensorList(Location loc, Value *input_list,
|
||||
Value *start_index, Value *size,
|
||||
Value *item_rank, Type result_type,
|
||||
TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
|
||||
Value start_index, Value size,
|
||||
Value item_rank, Type result_type,
|
||||
PatternRewriter *rewriter) {
|
||||
// Create the start position of slice. This is done by concatenating
|
||||
// `start_index` and `partial_start_position` together.
|
||||
IntegerType shape_dtype = rewriter->getIntegerType(32);
|
||||
RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
|
||||
Value *partial_start_position =
|
||||
Value partial_start_position =
|
||||
CreateI32SplatTensor(loc, rewriter, item_rank, 0);
|
||||
Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
|
||||
Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
|
||||
RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
|
||||
auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
|
||||
loc, vector_type, start_index, scalar_zero);
|
||||
auto start_position = rewriter->create<TF::ConcatOp>(
|
||||
loc, position_type, scalar_zero,
|
||||
ArrayRef<Value *>({expanded_start_index, partial_start_position}));
|
||||
ArrayRef<Value>({expanded_start_index, partial_start_position}));
|
||||
|
||||
// Create the slice size tensor. This is done by concatenating `size` and
|
||||
// `partial_size`.
|
||||
auto size_leading_dim =
|
||||
rewriter->create<TF::ExpandDimsOp>(loc, vector_type, size, scalar_zero);
|
||||
Value *partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
|
||||
Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
|
||||
auto slice_size = rewriter->create<TF::ConcatOp>(
|
||||
loc, position_type, scalar_zero,
|
||||
ArrayRef<Value *>({size_leading_dim, partial_size}));
|
||||
ArrayRef<Value>({size_leading_dim, partial_size}));
|
||||
|
||||
return rewriter->create<TF::SliceOp>(loc, result_type, input_list,
|
||||
start_position, slice_size);
|
||||
@ -180,31 +180,31 @@ struct ConvertTensorListSetItem : public ConversionPattern {
|
||||
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
|
||||
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListSetItemOp>(operation);
|
||||
Location loc = op.getLoc();
|
||||
Value *input = operands[0];
|
||||
Value *index = operands[1];
|
||||
Value *item = operands[2];
|
||||
Value input = operands[0];
|
||||
Value index = operands[1];
|
||||
Value item = operands[2];
|
||||
|
||||
IntegerType shape_dtype = rewriter.getIntegerType(32);
|
||||
auto item_rank = rewriter.create<TF::RankOp>(
|
||||
loc, RankedTensorType::get({}, shape_dtype), item);
|
||||
Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
|
||||
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
|
||||
|
||||
// Calculate `index` + 1, which is used to generate the start position for
|
||||
// the second slice op.
|
||||
auto suffix_start =
|
||||
rewriter.create<TF::AddOp>(loc, index->getType(), index,
|
||||
rewriter.create<TF::AddOp>(loc, index.getType(), index,
|
||||
CreateI32SplatConst(loc, &rewriter, {}, 1));
|
||||
|
||||
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
|
||||
loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
|
||||
// Create two slice ops.
|
||||
Type element_type = input->getType().cast<TensorType>().getElementType();
|
||||
Type element_type = input.getType().cast<TensorType>().getElementType();
|
||||
UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
|
||||
Value *scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
|
||||
Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
|
||||
TF::SliceOp slice1 =
|
||||
CreateSliceOpForTensorList(loc, /*input_list=*/input,
|
||||
/*start_index=*/scalar_zero,
|
||||
@ -225,8 +225,8 @@ struct ConvertTensorListSetItem : public ConversionPattern {
|
||||
|
||||
// Concatenate three parts together to generate the final result.
|
||||
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
|
||||
op, input->getType(), scalar_zero,
|
||||
ArrayRef<Value *>({slice1, expanded_item, slice2}));
|
||||
op, input.getType(), scalar_zero,
|
||||
ArrayRef<Value>({slice1, expanded_item, slice2}));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
@ -241,14 +241,14 @@ struct ConvertTensorListInitOp : public ConversionPattern {
|
||||
|
||||
// Create and return a 1-d tensor with exactly one element equal to the number
|
||||
// of list elements to initialize the output tensor list with.
|
||||
virtual Value *GetNumElements(OpT op, ArrayRef<Value *> operands,
|
||||
virtual Value GetNumElements(OpT op, ArrayRef<Value> operands,
|
||||
PatternRewriter *rewriter) const = 0;
|
||||
|
||||
// Rewrites the original op into `tf.fill`. The result tensor shape is
|
||||
// [num_element, element_shape]. All the values in the result tensor will be
|
||||
// initialized to 0.
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
OpT op = llvm::cast<OpT>(operation);
|
||||
|
||||
@ -263,8 +263,8 @@ struct ConvertTensorListInitOp : public ConversionPattern {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
Value *element_shape = operands[0];
|
||||
Type shape_dtype = getElementTypeOrSelf(element_shape->getType());
|
||||
Value element_shape = operands[0];
|
||||
Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
|
||||
|
||||
DenseIntElementsAttr dense_elem_attr;
|
||||
if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
|
||||
@ -297,11 +297,10 @@ struct ConvertTensorListInitOp : public ConversionPattern {
|
||||
new_element_shape_values.push_back(dim_value);
|
||||
}
|
||||
|
||||
auto attr =
|
||||
DenseIntElementsAttr::get(element_shape->getType().cast<ShapedType>(),
|
||||
new_element_shape_values);
|
||||
auto attr = DenseIntElementsAttr::get(
|
||||
element_shape.getType().cast<ShapedType>(), new_element_shape_values);
|
||||
auto new_element_shape = rewriter.create<ConstantOp>(
|
||||
op.getLoc(), element_shape->getType(), attr);
|
||||
op.getLoc(), element_shape.getType(), attr);
|
||||
element_shape = new_element_shape;
|
||||
}
|
||||
|
||||
@ -330,11 +329,11 @@ struct ConvertTensorListInitOp : public ConversionPattern {
|
||||
Location loc = op.getLoc();
|
||||
// Add number of elements as the prefix to the element shape to get shape of
|
||||
// the output tensor.
|
||||
Value *leading_dim = GetNumElements(op, operands, &rewriter);
|
||||
Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
|
||||
Value leading_dim = GetNumElements(op, operands, &rewriter);
|
||||
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
|
||||
auto list_shape = rewriter.create<TF::ConcatOp>(
|
||||
loc, shape_type, scalar_zero,
|
||||
ArrayRef<Value *>({leading_dim, element_shape}));
|
||||
ArrayRef<Value>({leading_dim, element_shape}));
|
||||
|
||||
// Create a zero-initialized constant tensor that has the same type
|
||||
// as specified by element_dtype.
|
||||
@ -352,11 +351,11 @@ struct ConvertTensorListReserve
|
||||
explicit ConvertTensorListReserve(MLIRContext *context)
|
||||
: ConvertTensorListInitOp(context) {}
|
||||
|
||||
Value *GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value *> operands,
|
||||
Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
|
||||
PatternRewriter *rewriter) const override {
|
||||
Value *scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
|
||||
Type shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
|
||||
Value *num_elements = operands[1];
|
||||
Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
|
||||
Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
|
||||
Value num_elements = operands[1];
|
||||
return rewriter->create<TF::ExpandDimsOp>(
|
||||
op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
|
||||
scalar_zero);
|
||||
@ -371,7 +370,7 @@ struct ConvertEmptyTensorList
|
||||
explicit ConvertEmptyTensorList(MLIRContext *context)
|
||||
: ConvertTensorListInitOp(context) {}
|
||||
|
||||
Value *GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value *> operands,
|
||||
Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands,
|
||||
PatternRewriter *rewriter) const override {
|
||||
return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0);
|
||||
}
|
||||
@ -383,23 +382,23 @@ struct ConvertTensorListPushBack : public ConversionPattern {
|
||||
context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *op, ArrayRef<Value *> operands,
|
||||
Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
|
||||
Value *input_handle = operands[0];
|
||||
Value *item = operands[1];
|
||||
Value input_handle = operands[0];
|
||||
Value item = operands[1];
|
||||
|
||||
// Expand the shape of the item so that it will have rank same as the input
|
||||
// tensor and it is compatible for the Concat Op.
|
||||
Type expanded_item_type =
|
||||
PrependLeadingDimIfRanked(1, item->getType(), &rewriter);
|
||||
Value *scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
|
||||
PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
|
||||
Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
|
||||
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), expanded_item_type, item, scalar_zero);
|
||||
|
||||
Type elem_type = getElementTypeOrSelf(item);
|
||||
auto handle_dtype =
|
||||
getElementTypeOrSelf(push_back_op.output_handle()->getType())
|
||||
getElementTypeOrSelf(push_back_op.output_handle().getType())
|
||||
.cast<TF::VariantType>();
|
||||
Type result_type =
|
||||
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
|
||||
@ -408,7 +407,7 @@ struct ConvertTensorListPushBack : public ConversionPattern {
|
||||
// get a tensor equivalent to the TensorList generated by this op.
|
||||
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
|
||||
push_back_op, result_type, scalar_zero,
|
||||
ArrayRef<Value *>({input_handle, expanded_item}));
|
||||
ArrayRef<Value>({input_handle, expanded_item}));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
@ -429,14 +428,14 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *op, ArrayRef<Value *> operands,
|
||||
Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
TF::TensorListResizeOp resize_op = cast<TF::TensorListResizeOp>(op);
|
||||
Value *input_handle = operands[0];
|
||||
Value *size = operands[1];
|
||||
Value input_handle = operands[0];
|
||||
Value size = operands[1];
|
||||
|
||||
Location loc = resize_op.getLoc();
|
||||
Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
|
||||
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
|
||||
|
||||
// Compute the input tensorlist's length and store it in `input_size`.
|
||||
IntegerType shape_dtype = rewriter.getIntegerType(32);
|
||||
@ -446,7 +445,7 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
// Infer result type of this op based on TF's shape inference result.
|
||||
Type elem_type = getElementTypeOrSelf(input_handle);
|
||||
auto handle_dtype =
|
||||
getElementTypeOrSelf(resize_op.output_handle()->getType())
|
||||
getElementTypeOrSelf(resize_op.output_handle().getType())
|
||||
.cast<TF::VariantType>();
|
||||
Type result_type =
|
||||
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
|
||||
@ -463,8 +462,8 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
auto input_shape = rewriter.create<TF::ShapeOp>(
|
||||
loc, RankedTensorType::get({-1}, shape_dtype), input_handle);
|
||||
|
||||
Type branch_args_type[] = {input_handle->getType(), input_shape.getType(),
|
||||
size_diff.getType(), size->getType()};
|
||||
Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
|
||||
size_diff.getType(), size.getType()};
|
||||
Type branch_result_type[] = {result_type};
|
||||
auto func_type = FunctionType::get(branch_args_type, branch_result_type,
|
||||
rewriter.getContext());
|
||||
@ -491,7 +490,7 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
rewriter.replaceOpWithNewOp<TF::IfOp>(
|
||||
op, result_type, if_cond,
|
||||
/*input=*/
|
||||
ArrayRef<Value *>({input_handle, input_shape, size_diff, size}),
|
||||
ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
|
||||
/*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op),
|
||||
/*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
|
||||
/*output_shapes=*/rewriter.getStrArrayAttr({"{}"}),
|
||||
@ -517,14 +516,14 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
|
||||
Location loc = resize_op.getLoc();
|
||||
// Get the element shape by slicing from index 1 in the input shape.
|
||||
Value *slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
|
||||
Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
|
||||
Value *slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
|
||||
Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
|
||||
Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
|
||||
Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
|
||||
auto elem_shape = rewriter->create<TF::SliceOp>(
|
||||
loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
|
||||
slice_size);
|
||||
auto extended_part = rewriter->create<TF::TensorListReserveOp>(
|
||||
loc, resize_op.output_handle()->getType(), elem_shape, size_diff);
|
||||
loc, resize_op.output_handle().getType(), elem_shape, size_diff);
|
||||
// `ConcatOp` expects non-variant-typed input. Insert a
|
||||
// `TensorListStackOp` here to convert type from variant to non-variant.
|
||||
// Note that we are using the same `result_type` for both the
|
||||
@ -536,8 +535,8 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
/*num_elements=*/rewriter->getI32IntegerAttr(-1));
|
||||
auto concat_op = rewriter->create<TF::ConcatOp>(
|
||||
loc, result_type, scalar_zero,
|
||||
ArrayRef<Value *>({input, stacked_extended_part}));
|
||||
rewriter->create<ReturnOp>(loc, ArrayRef<Value *>({concat_op}));
|
||||
ArrayRef<Value>({input, stacked_extended_part}));
|
||||
rewriter->create<ReturnOp>(loc, ArrayRef<Value>({concat_op}));
|
||||
}
|
||||
|
||||
void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type,
|
||||
@ -550,8 +549,8 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
Block *block = branch_func.addEntryBlock();
|
||||
rewriter->setInsertionPointToStart(block);
|
||||
|
||||
Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
|
||||
Value *vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
|
||||
Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
|
||||
Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
|
||||
auto input = block->getArgument(0);
|
||||
auto size = block->getArgument(3);
|
||||
|
||||
@ -566,7 +565,7 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
/*start_index=*/scalar_zero, /*size=*/size,
|
||||
/*item_rank=*/partial_position_shape,
|
||||
/*result_type=*/result_type, rewriter);
|
||||
rewriter->create<ReturnOp>(loc, ArrayRef<Value *>({slice_op}));
|
||||
rewriter->create<ReturnOp>(loc, ArrayRef<Value>({slice_op}));
|
||||
}
|
||||
};
|
||||
|
||||
@ -576,11 +575,11 @@ struct ConvertTensorListGetItem : public ConversionPattern {
|
||||
context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListGetItemOp>(operation);
|
||||
Value *input = operands[0];
|
||||
Value *index = operands[1];
|
||||
Value input = operands[0];
|
||||
Value index = operands[1];
|
||||
rewriter.replaceOpWithNewOp<TF::GatherOp>(
|
||||
operation, op.getType(), input, index, rewriter.getBoolAttr(true));
|
||||
return matchSuccess();
|
||||
@ -593,11 +592,11 @@ struct ConvertTensorListLength : public ConversionPattern {
|
||||
context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListLengthOp>(operation);
|
||||
Location loc = op.getLoc();
|
||||
Value *input_handle = operands[0];
|
||||
Value input_handle = operands[0];
|
||||
|
||||
BoolAttr true_attr = rewriter.getBoolAttr(true);
|
||||
auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
|
||||
@ -615,19 +614,19 @@ struct ConvertTensorListStack : public ConversionPattern {
|
||||
context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListStackOp>(operation);
|
||||
Location loc = op.getLoc();
|
||||
Value *input = operands[0];
|
||||
Value *element_shape = operands[1];
|
||||
Value input = operands[0];
|
||||
Value element_shape = operands[1];
|
||||
|
||||
// If the `element_shape` is a known constant (which is defined when calling
|
||||
// `tensor_list_stack`) and also valid (not scalar), we rewrite this op to a
|
||||
// trivial Reshape op (that doesn't actually change the input's shape) and
|
||||
// also populate the shape info to the op result. The shape of the
|
||||
// tensorlist is inferred from `num_elements` and `element_shape`.
|
||||
auto ranked_type = element_shape->getType().dyn_cast<RankedTensorType>();
|
||||
auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
|
||||
DenseIntElementsAttr dense_elem_attr;
|
||||
if ((ranked_type && ranked_type.getRank() == 0) ||
|
||||
!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
|
||||
@ -655,11 +654,11 @@ struct ConvertIdentity : public ConversionPattern {
|
||||
: ConversionPattern(TF::IdentityOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::IdentityOp>(operation);
|
||||
Value *input = operands[0];
|
||||
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input->getType(), operands,
|
||||
Value input = operands[0];
|
||||
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
|
||||
op.getAttrs());
|
||||
return matchSuccess();
|
||||
}
|
||||
@ -687,7 +686,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
|
||||
Type arg_type = func_type.getInput(i);
|
||||
if (getElementTypeOrSelf(arg_type).isa<TF::VariantType>()) {
|
||||
arg_type = UnrankedTensorType::get(
|
||||
getElementTypeOrSelf(op.getOperand(i)->getType()));
|
||||
getElementTypeOrSelf(op.getOperand(i).getType()));
|
||||
}
|
||||
updated_argument_types.push_back(arg_type);
|
||||
}
|
||||
@ -703,7 +702,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
|
||||
// from the corresponding input operand. This is correct because while
|
||||
// body's inputs and results have the same type.
|
||||
result_type = UnrankedTensorType::get(
|
||||
getElementTypeOrSelf(op.getOperand(i)->getType()));
|
||||
getElementTypeOrSelf(op.getOperand(i).getType()));
|
||||
}
|
||||
updated_result_types.push_back(result_type);
|
||||
}
|
||||
@ -717,7 +716,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
|
||||
// Change the argument type for the first block.
|
||||
Block &body_first_bb = func.front();
|
||||
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
|
||||
body_first_bb.getArgument(i)->setType(updated_argument_types[i]);
|
||||
body_first_bb.getArgument(i).setType(updated_argument_types[i]);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
@ -728,19 +727,19 @@ struct ConvertWhile : public ConversionPattern {
|
||||
: ConversionPattern(TF::WhileOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::WhileOp>(operation);
|
||||
|
||||
llvm::SmallVector<Type, 8> result_types;
|
||||
result_types.reserve(op.getNumOperands());
|
||||
for (int i = 0, e = operands.size(); i != e; ++i) {
|
||||
Type result_ty = op.getResult(i)->getType();
|
||||
Type result_ty = op.getResult(i).getType();
|
||||
|
||||
// If we notice the result type is a DT_VARIANT, we change the
|
||||
// corresponding result type to unranked tensor type.
|
||||
if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
|
||||
Type element_ty = getElementTypeOrSelf(operands[i]->getType());
|
||||
Type element_ty = getElementTypeOrSelf(operands[i].getType());
|
||||
result_ty = UnrankedTensorType::get(element_ty);
|
||||
}
|
||||
result_types.push_back(result_ty);
|
||||
|
@ -30,14 +30,14 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
@ -50,16 +50,16 @@ namespace TFL {
|
||||
// The actual Optimize Pass.
|
||||
namespace {
|
||||
|
||||
bool L2NormalizeReduceAxis(Value *sq_op, DenseElementsAttr axis) {
|
||||
if (sq_op->getType().cast<ShapedType>().getRank() - 1 ==
|
||||
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
|
||||
if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
|
||||
*axis.getValues<int>().begin() ||
|
||||
*axis.getValues<int>().begin() == -1) {
|
||||
return true;
|
||||
}
|
||||
if (sq_op->getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
|
||||
if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
|
||||
return false;
|
||||
}
|
||||
auto shape = sq_op->getType().cast<ShapedType>();
|
||||
auto shape = sq_op.getType().cast<ShapedType>();
|
||||
SmallVector<int, 4> elems{axis.getValues<int>().begin(),
|
||||
axis.getValues<int>().end()};
|
||||
for (int i = 0; i < shape.getRank(); ++i) {
|
||||
@ -142,8 +142,8 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
|
||||
|
||||
// Returns shape of a ranked tensor.
|
||||
// Precondition: output_val's is ranked tensor.
|
||||
DenseElementsAttr GetShape(Value *output_val) {
|
||||
auto output_type = output_val->getType().cast<RankedTensorType>();
|
||||
DenseElementsAttr GetShape(Value output_val) {
|
||||
auto output_type = output_val.getType().cast<RankedTensorType>();
|
||||
auto shape_vector = output_type.getShape();
|
||||
std::vector<int32_t> shape(shape_vector.size());
|
||||
for (int i = 0; i < shape_vector.size(); ++i) {
|
||||
@ -152,7 +152,7 @@ DenseElementsAttr GetShape(Value *output_val) {
|
||||
return mlir::DenseElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int>(shape.size())},
|
||||
mlir::IntegerType::get(32, output_val->getContext())),
|
||||
mlir::IntegerType::get(32, output_val.getContext())),
|
||||
llvm::makeArrayRef(shape));
|
||||
}
|
||||
|
||||
@ -167,19 +167,19 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Add.
|
||||
DenseElementsAttr added_value;
|
||||
Value *constant_val = add_op.rhs();
|
||||
Value constant_val = add_op.rhs();
|
||||
if (!matchPattern(constant_val, m_Constant(&added_value)))
|
||||
return matchFailure();
|
||||
|
||||
// Fully Connected.
|
||||
auto fc_op =
|
||||
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs()->getDefiningOp());
|
||||
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
|
||||
if (!fc_op) return matchFailure();
|
||||
|
||||
Value *filter = fc_op.filter();
|
||||
Value *bias = fc_op.bias();
|
||||
Value filter = fc_op.filter();
|
||||
Value bias = fc_op.bias();
|
||||
ElementsAttr bias_value;
|
||||
const bool is_none_bias = bias->getType().isa<NoneType>();
|
||||
const bool is_none_bias = bias.getType().isa<NoneType>();
|
||||
if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
|
||||
return matchFailure();
|
||||
if (fc_op.fused_activation_function() != "NONE") return matchFailure();
|
||||
@ -213,7 +213,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
|
||||
|
||||
PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *input = relu_op.getOperand()->getDefiningOp();
|
||||
Operation *input = relu_op.getOperand().getDefiningOp();
|
||||
if (!isa_and_nonnull<FullyConnectedOp>(input)) return matchFailure();
|
||||
auto fully_connected_op = cast<FullyConnectedOp>(input);
|
||||
if (fully_connected_op.fused_activation_function() != "NONE")
|
||||
@ -242,18 +242,18 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Mul.
|
||||
DenseElementsAttr cst;
|
||||
Value *constant_val = mul_op.rhs();
|
||||
Value constant_val = mul_op.rhs();
|
||||
if (!matchPattern(constant_val, m_Constant(&cst))) return matchFailure();
|
||||
|
||||
// Fully Connected.
|
||||
auto fc_op =
|
||||
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
|
||||
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
|
||||
if (!fc_op) return matchFailure();
|
||||
Value *filter = fc_op.filter();
|
||||
Value *bias = fc_op.bias();
|
||||
Value filter = fc_op.filter();
|
||||
Value bias = fc_op.bias();
|
||||
ElementsAttr cst_tmp;
|
||||
if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
|
||||
if (!bias->getType().isa<NoneType>() &&
|
||||
if (!bias.getType().isa<NoneType>() &&
|
||||
!matchPattern(bias, m_Constant(&cst_tmp)))
|
||||
return matchFailure();
|
||||
if (fc_op.fused_activation_function().equals("None")) return matchFailure();
|
||||
@ -261,8 +261,8 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
|
||||
// Broadcast the constant operand of Mul if it isn't compatible to the
|
||||
// filter input. We only support broadcasting the operand along the depth
|
||||
// dimension, when the operand's depth is 1.
|
||||
Value *new_const_val = constant_val;
|
||||
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
|
||||
Value new_const_val = constant_val;
|
||||
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
|
||||
auto original_shape = cst.getType().getShape();
|
||||
llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
|
||||
original_shape.end());
|
||||
@ -270,7 +270,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
|
||||
auto new_cst = cst.reshape(RankedTensorType::get(
|
||||
normalized_shape, cst.getType().getElementType()));
|
||||
Type new_type = new_cst.getType();
|
||||
if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
|
||||
if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
|
||||
return matchFailure();
|
||||
}
|
||||
auto new_op =
|
||||
@ -285,7 +285,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
|
||||
auto new_filter =
|
||||
rewriter.create<TF::MulOp>(loc, filter, new_const_val).z();
|
||||
// If bias isn't None, it needs to be multiplied as well.
|
||||
if (!bias->getType().isa<NoneType>()) {
|
||||
if (!bias.getType().isa<NoneType>()) {
|
||||
bias = rewriter.create<TF::MulOp>(loc, bias, constant_val).z();
|
||||
}
|
||||
|
||||
@ -311,7 +311,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
PatternMatchResult matchAndRewrite(AffineOpType fc_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Binary op.
|
||||
Operation *binary_op = fc_op.input()->getDefiningOp();
|
||||
Operation *binary_op = fc_op.input().getDefiningOp();
|
||||
if (!binary_op || binary_op->getNumOperands() != 2)
|
||||
return this->matchFailure();
|
||||
// We only handle the cases the RHS is a scalar.
|
||||
@ -325,20 +325,20 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
APFloat cst_value = *cst.float_value_begin();
|
||||
|
||||
// Affine op.
|
||||
Value *filter = fc_op.filter();
|
||||
Value *bias = fc_op.bias();
|
||||
Value filter = fc_op.filter();
|
||||
Value bias = fc_op.bias();
|
||||
DenseFPElementsAttr filter_cst, bias_cst;
|
||||
if (!matchPattern(filter, m_Constant(&filter_cst))) {
|
||||
// The filter maybe quantized, then we should set it to the real constant.
|
||||
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter->getDefiningOp());
|
||||
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
|
||||
if (!dq) return this->matchFailure();
|
||||
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input()->getDefiningOp());
|
||||
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
|
||||
if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
|
||||
return this->matchFailure();
|
||||
}
|
||||
filter = q.input();
|
||||
}
|
||||
if (!bias->getType().isa<NoneType>() &&
|
||||
if (!bias.getType().isa<NoneType>() &&
|
||||
!matchPattern(bias, m_Constant(&bias_cst)))
|
||||
return this->matchFailure();
|
||||
ShapedType filter_type = filter_cst.getType();
|
||||
@ -362,7 +362,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
// The new bias should be a 1-D tensor with length equals to the bias
|
||||
// dimension of the weight.
|
||||
SmallVector<APFloat, 4> new_bias_values;
|
||||
if (bias->getType().isa<NoneType>()) { // none bias, a list of zeros
|
||||
if (bias.getType().isa<NoneType>()) { // none bias, a list of zeros
|
||||
new_bias_values.resize(bias_size, APFloat(0.0));
|
||||
} else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it
|
||||
new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
|
||||
@ -401,12 +401,12 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
// We recreate the constant op in case it is shared by the other ops. This
|
||||
// might increase the model size.
|
||||
auto new_filter_op = rewriter.create<ConstOp>(
|
||||
fc_op.getLoc(), filter->getType(), new_filter);
|
||||
fc_op.getLoc(), filter.getType(), new_filter);
|
||||
fc_op.setOperand(0, binary_op->getOperand(0));
|
||||
if (fc_op.filter() != filter) {
|
||||
// This filter goes through quantize and dequantize ops. Then we just
|
||||
// need to update the weight to the quantize op.
|
||||
filter->replaceAllUsesWith(new_filter_op);
|
||||
filter.replaceAllUsesWith(new_filter_op);
|
||||
} else {
|
||||
// This filter doesn't go through quantize and dequantize ops, Then
|
||||
// we update the weight of the affine op directly.
|
||||
|
@ -17,15 +17,15 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -98,13 +98,13 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
for (int i = 0, e = func.getNumArguments(); i != e; ++i)
|
||||
mapper.map(func.getArgument(i), op.getOperand(i + 1));
|
||||
|
||||
llvm::SmallVector<Value*, 4> updated_results;
|
||||
llvm::SmallVector<Value, 4> updated_results;
|
||||
for (auto& op_to_inline : func.getBody().front()) {
|
||||
// If this is a terminator, identify the values to use to replace the
|
||||
// original If op.
|
||||
if (op_to_inline.isKnownTerminator()) {
|
||||
updated_results.reserve(op_to_inline.getNumOperands());
|
||||
for (Value* operand : op_to_inline.getOperands())
|
||||
for (Value operand : op_to_inline.getOperands())
|
||||
updated_results.push_back(mapper.lookup(operand));
|
||||
break;
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
def F32ElementsAttr : ElementsAttrBase<
|
||||
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
|
||||
@ -53,13 +54,15 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
||||
[TFL_Relu1Op, TFL_AF_Relu1]] in
|
||||
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
|
||||
|
||||
// Checks if the value has only one user.
|
||||
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||
|
||||
// If we see a binary op (add, sub) op adding a constant value to a convolution
|
||||
// op with constant bias, we can fuse the binary op into the convolution op by
|
||||
// constant folding the bias and the binary op's constant operand. The following
|
||||
// pattern restricts to float constant values for now.
|
||||
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
|
||||
def : Pat<(binaryOp (TFL_Conv2DOp $input, $filter,
|
||||
def : Pat<(binaryOp (TFL_Conv2DOp:$output $input, $filter,
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w),
|
||||
@ -68,8 +71,9 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
|
||||
(binaryOp (ConstantOp $bias),
|
||||
(ConstantOp $value), TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w)>;
|
||||
def : Pat<(binaryOp (TFL_DepthwiseConv2DOp $input, $filter,
|
||||
$padding, $stride_h, $stride_w),
|
||||
[(HasOneUse $output)]>;
|
||||
def : Pat<(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w,
|
||||
@ -81,7 +85,8 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
|
||||
TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier)>;
|
||||
$multiplier),
|
||||
[(HasOneUse $output)]>;
|
||||
}
|
||||
foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
|
||||
defm : FuseBinaryOpToPrecedingAffine<binaryOp>;
|
||||
@ -101,7 +106,7 @@ def ExpandTo4DForDepthwiseConv: NativeCodeCall<
|
||||
// The following pattern restricts to float constant values for now.
|
||||
|
||||
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
|
||||
def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp $input,
|
||||
def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
|
||||
(ConstantOp F32ElementsAttr:$filter),
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
@ -119,8 +124,9 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier),
|
||||
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value)]>;
|
||||
def : Pat<(BinaryOp (TFL_Conv2DOp $input,
|
||||
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
|
||||
(HasOneUse $output)]>;
|
||||
def : Pat<(BinaryOp (TFL_Conv2DOp:$conv_output $input,
|
||||
(ConstantOp F32ElementsAttr:$filter),
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
@ -135,7 +141,8 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
|
||||
TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w),
|
||||
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value)]>;
|
||||
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
|
||||
(HasOneUse $conv_output)]>;
|
||||
}
|
||||
|
||||
foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in
|
||||
@ -154,7 +161,7 @@ def EqualOperands : Constraint<CPred<"$0 == $1">>;
|
||||
|
||||
// Checks if the operand has rank == n
|
||||
class OperandHasRank<int n> : Constraint<
|
||||
CPred<"$0->getType().cast<ShapedType>().getRank() == " # n>>;
|
||||
CPred<"$0.getType().cast<ShapedType>().getRank() == " # n>>;
|
||||
|
||||
// Matching HardSwish
|
||||
def : Pat<
|
||||
@ -249,7 +256,7 @@ foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
|
||||
in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;
|
||||
|
||||
def AreBroadcastableTypes : Constraint<CPred<
|
||||
"TFL::IsBroadcastableElementsAttrAndType($0->getType(), $1->getType())">>;
|
||||
"TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;
|
||||
|
||||
// Pattern for skipping Tile if it is mainly for broadcasting and the
|
||||
// Op is already supporting broadcasting.
|
||||
@ -307,3 +314,7 @@ multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
|
||||
foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
|
||||
TFL_MulOp, TFL_SubOp] in
|
||||
defm : FusedBinaryActivationFuncOpPat<BinaryOps>;
|
||||
|
||||
// The constant folding in this pass might produce constant in the tf dialect.
|
||||
// This rule is to legalize these constant to the tfl dialect.
|
||||
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
|
||||
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
// This transformation pass applies some clean up steps after quantization.
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
@ -67,33 +67,33 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
|
||||
// In each iteration, a new argument is appended to the end of the list
|
||||
// and the current argument is erased, so here we always process the first
|
||||
// argument in the list.
|
||||
auto* arg = bb.getArgument(0);
|
||||
auto arg = bb.getArgument(0);
|
||||
|
||||
auto remove_quantize_op = [&](QuantizeOp quantize_op) {
|
||||
auto quantize_output = quantize_op.output();
|
||||
auto quantize_type = quantize_output->getType();
|
||||
auto quantize_type = quantize_output.getType();
|
||||
input_types.push_back(quantize_type);
|
||||
auto* new_arg = bb.addArgument(quantize_type);
|
||||
quantize_output->replaceAllUsesWith(new_arg);
|
||||
auto new_arg = bb.addArgument(quantize_type);
|
||||
quantize_output.replaceAllUsesWith(new_arg);
|
||||
quantize_op.erase();
|
||||
arg->dropAllUses();
|
||||
arg.dropAllUses();
|
||||
bb.eraseArgument(0);
|
||||
};
|
||||
|
||||
// This is looking for a pattern: arg -> tfl.quantize
|
||||
if (arg->hasOneUse() && llvm::isa<QuantizeOp>(*arg->user_begin())) {
|
||||
auto quantize_op = llvm::cast<QuantizeOp>(*arg->user_begin());
|
||||
if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
|
||||
auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
|
||||
remove_quantize_op(quantize_op);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Make a copy of current argument and append it to the end of the list if
|
||||
// the pattern isn't found.
|
||||
Type arg_type = arg->getType();
|
||||
Type arg_type = arg.getType();
|
||||
input_types.push_back(arg_type);
|
||||
auto* new_arg = bb.addArgument(arg_type);
|
||||
arg->replaceAllUsesWith(new_arg);
|
||||
arg->dropAllUses();
|
||||
auto new_arg = bb.addArgument(arg_type);
|
||||
arg.replaceAllUsesWith(new_arg);
|
||||
arg.dropAllUses();
|
||||
bb.eraseArgument(0);
|
||||
}
|
||||
|
||||
@ -102,16 +102,16 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
|
||||
llvm::SmallVector<Type, 4> output_types;
|
||||
output_types.reserve(num_return_operands);
|
||||
for (int i = 0; i != num_return_operands; ++i) {
|
||||
auto* returned_value = terminator->getOperand(i);
|
||||
Operation* returned_op = returned_value->getDefiningOp();
|
||||
auto returned_value = terminator->getOperand(i);
|
||||
Operation* returned_op = returned_value.getDefiningOp();
|
||||
if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
|
||||
auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
|
||||
Value* dequantized_result = dequantize_op.input();
|
||||
output_types.push_back(dequantized_result->getType());
|
||||
Value dequantized_result = dequantize_op.input();
|
||||
output_types.push_back(dequantized_result.getType());
|
||||
terminator->setOperand(i, dequantized_result);
|
||||
returned_op->erase();
|
||||
} else {
|
||||
output_types.push_back(returned_value->getType());
|
||||
output_types.push_back(returned_value.getType());
|
||||
}
|
||||
}
|
||||
auto new_func_type = builder.getFunctionType(input_types, output_types);
|
||||
|
@ -22,19 +22,19 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
|
||||
@ -53,8 +53,8 @@ class ConvertEmbeddedLookupFunc {
|
||||
void RewriteFunc() {
|
||||
func_.setAttr(kTFImplements,
|
||||
StringAttr::get("embedding_lookup", func_.getContext()));
|
||||
Value* lookup = func_.getArgument(1);
|
||||
Value* value = func_.getArgument(0);
|
||||
Value lookup = func_.getArgument(1);
|
||||
Value value = func_.getArgument(0);
|
||||
auto output_type = func_.getType().getResult(0);
|
||||
|
||||
OpBuilder builder(func_.getBody());
|
||||
|
@ -135,10 +135,10 @@ def : Pat<(TF_ReshapeOp
|
||||
// Casts result type of $1 to a quantized type by using the quantization
|
||||
// parameters from the type in $0.
|
||||
class UpdateShapeWithAxis<int i> : NativeCodeCall<
|
||||
"CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1->getType(), " # i # ")">;
|
||||
"CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">;
|
||||
|
||||
class UsedBy<string op> : Constraint<
|
||||
CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0->getUsers().begin())">>;
|
||||
CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0.getUsers().begin())">>;
|
||||
|
||||
// When the op is passing-through, the output types of the quantized ops need
|
||||
// to be updated as well. Since the quantize op manages its own type by the
|
||||
|
@ -21,10 +21,10 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
@ -139,7 +139,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
|
||||
BoolAttr narrow_range = builder.getBoolAttr(false);
|
||||
|
||||
auto add_quantize_op = [&](Location loc, Type input_type, Block* block,
|
||||
Block::iterator insertion_point, Value* arg,
|
||||
Block::iterator insertion_point, Value arg,
|
||||
int i) {
|
||||
if (auto shaped = input_type.dyn_cast<ShapedType>()) {
|
||||
if (shaped.getElementType().isa<FloatType>()) {
|
||||
@ -153,16 +153,16 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
|
||||
params);
|
||||
auto dq_op =
|
||||
builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
|
||||
arg->replaceAllUsesWith(dq_op.output());
|
||||
arg.replaceAllUsesWith(dq_op.output());
|
||||
q_op.setOperand(arg);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
|
||||
BlockArgument* arg = func.getArgument(i);
|
||||
auto* arg_block = arg->getOwner();
|
||||
add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
|
||||
BlockArgument arg = func.getArgument(i);
|
||||
auto* arg_block = arg.getOwner();
|
||||
add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
|
||||
std::next(arg_block->begin(), i), arg, i);
|
||||
}
|
||||
|
||||
|
@ -38,17 +38,17 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
@ -115,17 +115,17 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
|
||||
PatternRewriter &rewriter) const override {
|
||||
// We don't want to insert quantize/dequantize if the quantize op exists.
|
||||
auto res = tf_op.outputs();
|
||||
if (!res->hasOneUse() || isa<QuantizeOp>(*res->user_begin()))
|
||||
if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin()))
|
||||
return this->matchFailure();
|
||||
|
||||
// Extract the min/max constant values from the operands. We also consider
|
||||
// a special case that there are tf.Identity ops between the min/max
|
||||
// constants and the tf.FakeQuantWithMinMaxVarsOp.
|
||||
Value *min = tf_op.min(), *max = tf_op.max();
|
||||
Value min = tf_op.min(), max = tf_op.max();
|
||||
DenseFPElementsAttr min_value, max_value;
|
||||
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min->getDefiningOp()))
|
||||
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
|
||||
min = id1.input();
|
||||
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max->getDefiningOp()))
|
||||
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
|
||||
max = id2.input();
|
||||
if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
|
||||
if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
|
||||
@ -133,7 +133,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
|
||||
int quant_dim = -1;
|
||||
if (PerAxis) {
|
||||
// This is a special case that the quant_dim is the last dimensions.
|
||||
quant_dim = res->getType().template cast<ShapedType>().getRank() - 1;
|
||||
quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
|
||||
}
|
||||
// Use the min/max from the operands and the num_bits and narrow_range
|
||||
// attribute to create the quantization parameter for the new quantize op.
|
||||
@ -150,12 +150,12 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
|
||||
// Finally, use the quantization parameter to create the quantize and
|
||||
// dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
|
||||
// and its users.
|
||||
Value *value = tf_op.outputs();
|
||||
Value value = tf_op.outputs();
|
||||
auto quantize = rewriter.create<TFL::QuantizeOp>(
|
||||
tf_op.getLoc(), qtype.getValue(), value, qtype);
|
||||
auto dequantize = rewriter.create<TFL::DequantizeOp>(
|
||||
tf_op.getLoc(), res_type, quantize.output());
|
||||
value->replaceAllUsesWith(dequantize);
|
||||
value.replaceAllUsesWith(dequantize);
|
||||
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
|
||||
|
||||
return this->matchSuccess();
|
||||
@ -177,8 +177,8 @@ using PreparePerChannelFakeQuant =
|
||||
//
|
||||
// TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
|
||||
// PatternRewriter &rewriter, Location loc,
|
||||
// Type result_type, Value *input,
|
||||
// Value *filter, Value *bias) const;
|
||||
// Type result_type, Value input,
|
||||
// Value filter, Value bias) const;
|
||||
//
|
||||
// And also the following method for getting the dimension for bias tensor:
|
||||
//
|
||||
@ -240,7 +240,7 @@ struct ConvertTFConvOp : public RewritePattern {
|
||||
// that we can extract info from the shape (e.g., for constructing bias
|
||||
// tensor, for setting depth_multiplier attribute, etc.).
|
||||
auto filter_type =
|
||||
tf_op.filter()->getType().template dyn_cast<RankedTensorType>();
|
||||
tf_op.filter().getType().template dyn_cast<RankedTensorType>();
|
||||
if (filter_type && filter_type.getRank() == 4)
|
||||
return matchSuccess(std::move(state));
|
||||
|
||||
@ -262,7 +262,7 @@ struct ConvertTFConvOp : public RewritePattern {
|
||||
|
||||
// Get a splat zero tensor with the expected dimension for the bias tensor
|
||||
auto filter = tf_op.filter();
|
||||
auto filter_type = filter->getType().template cast<RankedTensorType>();
|
||||
auto filter_type = filter.getType().template cast<RankedTensorType>();
|
||||
auto elem_type = filter_type.getElementType();
|
||||
auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
|
||||
filter_type.getShape());
|
||||
@ -294,8 +294,8 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
|
||||
|
||||
TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
|
||||
PatternRewriter &rewriter, Location loc,
|
||||
Type result_type, Value *input, Value *filter,
|
||||
Value *bias) const {
|
||||
Type result_type, Value input, Value filter,
|
||||
Value bias) const {
|
||||
filter = legalizeFilter(rewriter, loc, filter);
|
||||
return rewriter.create<TFL::Conv2DOp>(
|
||||
loc, result_type, input, filter, bias,
|
||||
@ -312,8 +312,8 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
|
||||
// format HWIO to TFLite Conv2D op filter data format OHWI and return Value
|
||||
// for the converted filter. Requires that filter is verified by the match
|
||||
// method that it is a 4-D RankedTensorType.
|
||||
Value *legalizeFilter(PatternRewriter &rewriter, Location loc,
|
||||
Value *filter) const {
|
||||
Value legalizeFilter(PatternRewriter &rewriter, Location loc,
|
||||
Value filter) const {
|
||||
// Create a constant op for HWIO to OHWI transpose permutation.
|
||||
SmallVector<int, 4> perm = {3, 0, 1, 2};
|
||||
auto perm_type = RankedTensorType::get({static_cast<int>(perm.size())},
|
||||
@ -323,7 +323,7 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
|
||||
auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
|
||||
|
||||
// Create tensor type for the transpose result.
|
||||
auto filter_type = filter->getType().cast<RankedTensorType>();
|
||||
auto filter_type = filter.getType().cast<RankedTensorType>();
|
||||
auto result_shape = functional::map(
|
||||
[filter_type](int64_t dim) { return filter_type.getDimSize(dim); },
|
||||
perm);
|
||||
@ -349,14 +349,14 @@ class ConvertTFDepthwiseConv2dNative
|
||||
|
||||
TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
|
||||
PatternRewriter &rewriter, Location loc,
|
||||
Type result_type, Value *input,
|
||||
Value *filter, Value *bias) const {
|
||||
Type result_type, Value input,
|
||||
Value filter, Value bias) const {
|
||||
// Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional
|
||||
// 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not
|
||||
// have a corresponding 'depth_multiplier' attribute; the multiplier is the
|
||||
// fourth dimension in the 4-D filter tensor. We query the multiplier from
|
||||
// tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
|
||||
auto multiplier = filter->getType().cast<RankedTensorType>().getDimSize(3);
|
||||
auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);
|
||||
|
||||
filter = legalizeFilter(rewriter, loc, filter);
|
||||
return rewriter.create<TFL::DepthwiseConv2DOp>(
|
||||
@ -378,9 +378,9 @@ class ConvertTFDepthwiseConv2dNative
|
||||
/// filter data format is [1, filter_height, filter_width, out_channels].
|
||||
/// Requires that filter is verified by the match method that it is a 4-D
|
||||
/// RankedTensorType.
|
||||
Value *legalizeFilter(PatternRewriter &rewriter, Location loc,
|
||||
Value *filter) const {
|
||||
auto filter_type = filter->getType().cast<RankedTensorType>();
|
||||
Value legalizeFilter(PatternRewriter &rewriter, Location loc,
|
||||
Value filter) const {
|
||||
auto filter_type = filter.getType().cast<RankedTensorType>();
|
||||
auto filterShape = filter_type.getShape();
|
||||
SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
|
||||
filterShape[2] * filterShape[3]};
|
||||
@ -430,13 +430,13 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
if (new_axis_mask == 0) return matchFailure();
|
||||
|
||||
// Insert a new reshape op.
|
||||
Value *original_input = strided_slice_op.input();
|
||||
Value original_input = strided_slice_op.input();
|
||||
RankedTensorType original_input_type =
|
||||
original_input->getType().cast<RankedTensorType>();
|
||||
original_input.getType().cast<RankedTensorType>();
|
||||
const ArrayRef<int64_t> &original_input_shape =
|
||||
original_input_type.getShape();
|
||||
RankedTensorType begin_type =
|
||||
strided_slice_op.begin()->getType().cast<RankedTensorType>();
|
||||
strided_slice_op.begin().getType().cast<RankedTensorType>();
|
||||
const int dim_size = begin_type.getShape()[0];
|
||||
SmallVector<int64_t, 4> new_shape;
|
||||
int mask = 1;
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user