Merge branch 'master' into master

Rasul Karimov 2020-01-05 22:25:23 +03:00 committed by GitHub
commit fc8a94750e
2162 changed files with 50099 additions and 182852 deletions

View File

@ -238,6 +238,10 @@ build:linux --copt=-w
build:macos --copt=-w
build:windows --copt=/w
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
# _USE_MATH_DEFINES is defined.
build:windows --copt=/D_USE_MATH_DEFINES
# Default paths for TF_SYSTEM_LIBS
build:linux --define=PREFIX=/usr
build:linux --define=LIBDIR=$(PREFIX)/lib
@ -258,9 +262,8 @@ build:windows --host_cxxopt=/std:c++14
# On windows, we still link everything into a single DLL.
build:windows --config=monolithic
# On linux and macos, we dynamically link small amount of kernels
# On linux, we dynamically link small amount of kernels
build:linux --config=dynamic_kernels
build:macos --config=dynamic_kernels
# Make sure to include as little of windows.h as possible
build:windows --copt=-DWIN32_LEAN_AND_MEAN
@ -378,9 +381,9 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
build:rbe_win --config=rbe
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:toolchain"
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:cc-toolchain-x64_windows"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"

View File

@ -72,7 +72,7 @@ TensorFlow coding style.
[tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core)
and
[tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python).
TensorFlow has reached version 1 and hence cannot make
TensorFlow has passed version 1.0 and hence cannot make
non-backward-compatible API changes without a major release. Reviewers of
your pull request will comment on any API compatibility issues.
* When you contribute a new feature to TensorFlow, the maintenance burden is

View File

@ -37,18 +37,18 @@ See the [TensorFlow install guide](https://www.tensorflow.org/install) for the
[Docker container](https://www.tensorflow.org/install/docker), and
[build from source](https://www.tensorflow.org/install/source).
To install the current release for CPU-only:
To install the current release, which includes support for
[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu) *(Ubuntu and
Windows)*:
```
$ pip install tensorflow
```
Use the GPU package for
[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu) *(Ubuntu and
Windows)*:
A smaller CPU-only package is also available:
```
$ pip install tensorflow-gpu
$ pip install tensorflow-cpu
```
To update TensorFlow to the latest version, add the `--upgrade` flag to the above
@ -56,7 +56,7 @@ commands.
*Nightly binaries are available for testing using the
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) packages on PyPi.*
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPi.*
#### *Try your first TensorFlow program*
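The snippet under this heading is not shown in this hunk; as a rough illustration only (not the README's exact text), a first program of this kind would be:

```python
import tensorflow as tf

# Eager execution is the default in TF 2.x, so operations run immediately.
print(tf.add(1, 2).numpy())                # 3

hello = tf.constant('Hello, TensorFlow!')
print(hello.numpy())                       # b'Hello, TensorFlow!'
```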
@ -150,17 +150,3 @@ Learn more about the
## License
[Apache License 2.0](LICENSE)
## Feature Prioritization Survey
The TensorFlow team is working on building/improving features, and understands
that it is very important to prioritize these efforts based on what TF users
need.
The goal of this short, < 5min
[survey](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad), is to help
the TensorFlow team better understand what features to prioritize based on your
feedback. Participation is of course optional.
Take the survey
[HERE](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad).

View File

@ -147,14 +147,16 @@ def write_action_env_to_bazelrc(var_name, var):
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
def run_shell(cmd, allow_non_zero=False):
def run_shell(cmd, allow_non_zero=False, stderr=None):
if stderr is None:
stderr = sys.stdout
if allow_non_zero:
try:
output = subprocess.check_output(cmd)
output = subprocess.check_output(cmd, stderr=stderr)
except subprocess.CalledProcessError as e:
output = e.output
else:
output = subprocess.check_output(cmd)
output = subprocess.check_output(cmd, stderr=stderr)
return output.decode('UTF-8').strip()
@ -169,10 +171,12 @@ def get_python_path(environ_cp, python_bin_path):
if environ_cp.get('PYTHONPATH'):
python_paths = environ_cp.get('PYTHONPATH').split(':')
try:
stderr = open(os.devnull, 'wb')
library_paths = run_shell([
python_bin_path, '-c',
'import site; print("\\n".join(site.getsitepackages()))'
]).split('\n')
],
stderr=stderr).split('\n')
except subprocess.CalledProcessError:
library_paths = [
run_shell([
@ -1179,10 +1183,17 @@ def system_specific_test_config(env):
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
else:
test_and_build_filters.append('-gpu')
write_to_bazelrc('test --test_tag_filters=%s' %
# Disable tests with "v1only" tag in "v2" Bazel config, but not in "v1" config
write_to_bazelrc('test:v1 --test_tag_filters=%s' %
','.join(test_and_build_filters + test_only_filters))
write_to_bazelrc('test --build_tag_filters=%s' %
write_to_bazelrc('test:v1 --build_tag_filters=%s' %
','.join(test_and_build_filters))
write_to_bazelrc(
'test:v2 --test_tag_filters=%s' %
','.join(test_and_build_filters + test_only_filters + ['-v1only']))
write_to_bazelrc('test:v2 --build_tag_filters=%s' %
','.join(test_and_build_filters + ['-v1only']))
def set_system_libs_flag(environ_cp):
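In effect, the `v2` config receives the same filters as `v1` plus `-v1only`. A rough sketch of the strings these calls emit, using made-up filter lists in place of the real ones (`write_to_bazelrc`, as used here, appends each line to the generated bazelrc):

```python
# Illustrative only: hypothetical filter lists; the real ones are built above.
test_and_build_filters = ['-benchmark-test', '-no_oss', '-gpu']
test_only_filters = ['-oss_serial']

lines = [
    'test:v1 --test_tag_filters=%s'
    % ','.join(test_and_build_filters + test_only_filters),
    'test:v1 --build_tag_filters=%s' % ','.join(test_and_build_filters),
    'test:v2 --test_tag_filters=%s'
    % ','.join(test_and_build_filters + test_only_filters + ['-v1only']),
    'test:v2 --build_tag_filters=%s'
    % ','.join(test_and_build_filters + ['-v1only']),
]
for line in lines:
  print(line)
# test:v1 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial
# test:v1 --build_tag_filters=-benchmark-test,-no_oss,-gpu
# test:v2 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial,-v1only
# test:v2 --build_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial,-v1only
```

So `bazel test --config=v2 ...` picks up the extra `-v1only` filter while `--config=v1` does not.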

View File

@ -860,7 +860,7 @@ gen_api_init_files(
output_files = TENSORFLOW_API_INIT_FILES_V1,
output_package = "tensorflow._api.v1",
root_file_name = "v1.py",
root_init_template = "api_template_v1.__init__.py",
root_init_template = "$(location api_template_v1.__init__.py)",
)
gen_api_init_files(
@ -883,7 +883,7 @@ gen_api_init_files(
output_files = TENSORFLOW_API_INIT_FILES_V2,
output_package = "tensorflow._api.v2",
root_file_name = "v2.py",
root_init_template = "api_template.__init__.py",
root_init_template = "$(location api_template.__init__.py)",
)
py_library(

View File

@ -89,6 +89,7 @@ except ImportError:
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
_compat.enable_v2_behavior()
_major_api_version = 2
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
@ -119,8 +120,14 @@ def _running_from_pip_package():
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
if _running_from_pip_package():
for _s in _site_packages_dirs:
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir)
# Load third party dynamic kernels.
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _fi.file_exists(_plugin_dir):
_ll.load_library(_plugin_dir)
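As a side note, a standalone sketch of the directories this logic probes (directory names taken from the code above; nothing is actually loaded here, and whether they exist depends on how the pip package was installed):

```python
# Illustrative only: print which candidate kernel/plugin directories exist.
import os
import site

site_dirs = site.getsitepackages() if hasattr(site, 'getsitepackages') else []
for d in site_dirs:
  for sub in ('tensorflow_core/core/kernels',   # first-party dynamic kernels
              'tensorflow-plugins'):            # third-party plugins
    path = os.path.join(d, sub)
    print(path, '->', 'present' if os.path.isdir(path) else 'absent')
```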

View File

@ -104,6 +104,8 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-
_current_module.app.flags = flags # pylint: disable=undefined-variable
setattr(_current_module, "flags", flags)
_major_api_version = 1
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
@ -132,8 +134,14 @@ def _running_from_pip_package():
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
if _running_from_pip_package():
for _s in _site_packages_dirs:
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir)
# Load third party dynamic kernels.
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _fi.file_exists(_plugin_dir):
_ll.load_library(_plugin_dir)

View File

@ -53,6 +53,20 @@ filegroup(
visibility = ["//visibility:public"],
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"c_api_internal.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
tf_cuda_library(
name = "c_api_internal",
hdrs = [

View File

@ -88,6 +88,18 @@ tf_cuda_library(
alwayslink = 1,
)
filegroup(
name = "pywrap_eager_hdrs",
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api_experimental.h"],

View File

@ -464,7 +464,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
&new_remote_device_mgr));
remote_device_mgr = new_remote_device_mgr.get();
} else {
ctx->context->ClearCaches();
ctx->context->ClearCachesAndDefaultExecutor();
// TODO(b/143914772): Potential memory leak if rendezvous has pending
// tensors for removed / replaced workers.
@ -754,7 +754,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
return list;
}
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); }
void TFE_ContextClearCaches(TFE_Context* ctx) {
ctx->context->ClearCachesAndThreadExecutors();
}
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
@ -26,29 +27,22 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
if (!status->status.ok()) {
return nullptr;
}
auto create_or_reset =
[&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
if (op_to_reset) {
status->status = op_to_reset->Reset(ctx, name, is_function, types,
raw_device_name, inference_ctx);
return op_to_reset;
} else {
TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}
};
if (op_to_reset && op_to_reset->ctx != ctx) {
status->status = tensorflow::errors::Internal(
"Cannot reset a TFE_Op from another TFE_Context");
return nullptr;
}
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
return create_or_reset(false, new TFE_OpInferenceContext(op_def));
}
if (!ctx->context->FindFunctionByName(name)) {
inference_ctx.reset(new TFE_OpInferenceContext(op_def));
} else if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
@ -58,5 +52,15 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
"registered in the binary running in this process.");
return nullptr;
}
return create_or_reset(true, nullptr);
if (op_to_reset) {
status->status = op_to_reset->Reset(
name, is_function, types, raw_device_name, std::move(inference_ctx));
return op_to_reset;
}
TFE_Op* new_op =
new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
status->status = new_op->operation.SetDeviceName(raw_device_name);
return new_op;
}

View File

@ -125,24 +125,26 @@ struct TFE_OpInferenceContext {
struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* inference_ctx)
: operation(ctx->context, op, is_function, t),
inference_ctx(inference_ctx) {}
std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
: ctx(ctx),
operation(ctx->context, op, is_function, t),
inference_ctx(std::move(inference_ctx)) {}
void Clear() {
operation.Clear();
inference_ctx.reset();
}
tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function,
tensorflow::Status Reset(const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
const char* raw_device_name,
TFE_OpInferenceContext* infer_ctx) {
inference_ctx.reset(infer_ctx);
std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
inference_ctx = std::move(infer_ctx);
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
nullptr);
}
TFE_Context* ctx;
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};

View File

@ -233,6 +233,7 @@ cc_library_with_android_deps(
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:lib_experimental",
"//tensorflow/core:protos_all_cc",
],
)

View File

@ -127,6 +127,33 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
target_node_names, outputs, run_metadata);
}
Status ClientSession::Run(
const RunOptions& run_options, const FeedType& inputs,
const std::vector<Output>& fetch_outputs,
const std::vector<Operation>& run_outputs, std::vector<Tensor>* outputs,
RunMetadata* run_metadata,
const thread::ThreadPoolOptions& threadpool_options) const {
std::vector<std::pair<string, Tensor>> feeds;
for (auto const& feed : inputs) {
TF_RETURN_IF_ERROR(feed.second.status);
feeds.emplace_back(feed.first.name(), feed.second.tensor);
}
std::vector<string> output_tensor_names;
output_tensor_names.reserve(fetch_outputs.size());
for (auto const& output : fetch_outputs) {
output_tensor_names.push_back(output.name());
}
std::vector<string> target_node_names;
target_node_names.reserve(run_outputs.size());
for (auto const& output : run_outputs) {
target_node_names.push_back(output.node()->name());
}
TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
return impl()->session_->Run(run_options, feeds, output_tensor_names,
target_node_names, outputs, run_metadata,
threadpool_options);
}
Status ClientSession::MakeCallable(const CallableOptions& callable_options,
CallableHandle* out_handle) {
TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());

View File

@ -93,6 +93,14 @@ class ClientSession {
const std::vector<Operation>& run_outputs,
std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
/// Same as above. Additionally allows the user to provide a custom threadpool
/// implementation via ThreadPoolOptions.
Status Run(const RunOptions& run_options, const FeedType& inputs,
const std::vector<Output>& fetch_outputs,
const std::vector<Operation>& run_outputs,
std::vector<Tensor>* outputs, RunMetadata* run_metadata,
const thread::ThreadPoolOptions& threadpool_options) const;
/// \brief A handle to a subgraph, created with
/// `ClientSession::MakeCallable()`.
typedef int64 CallableHandle;

View File

@ -112,7 +112,7 @@ TEST(ClientSessionTest, Extend) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({31, 42}, {2}));
}
TEST(ClientSessionTest, MultiThreaded) {
TEST(ClientSessionTest, MultiThreadedWithDefaultThreadpool) {
Scope root = Scope::NewRootScope();
auto a = Add(root, {1, 2}, {3, 4});
auto b = Mul(root, {1, 2}, {3, 4});
@ -138,6 +138,49 @@ TEST(ClientSessionTest, MultiThreaded) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
TEST(ClientSessionTest, MultiThreadedWithCustomThreadpool) {
Scope root = Scope::NewRootScope();
int num_threads = 3;
auto a = Add(root, {1, 2}, {3, 4});
auto b = Mul(root, {1, 2}, {3, 4});
ClientSession session(root);
auto inter_op_threadpool =
absl::make_unique<CustomThreadPoolImpl>(num_threads);
ASSERT_EQ(inter_op_threadpool->GetNumScheduleCalled(), 0);
auto intra_op_threadpool =
absl::make_unique<CustomThreadPoolImpl>(num_threads);
ASSERT_EQ(intra_op_threadpool->GetNumScheduleCalled(), 0);
tensorflow::thread::ThreadPoolOptions threadPoolOptions;
threadPoolOptions.inter_op_threadpool = inter_op_threadpool.get();
threadPoolOptions.intra_op_threadpool = intra_op_threadpool.get();
{
thread::ThreadPool thread_pool(Env::Default(), "pool", 2);
thread_pool.Schedule([&session, a]() {
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {a}, {},
&outputs, nullptr, thread::ThreadPoolOptions()));
test::ExpectTensorEqual<int>(outputs[0],
test::AsTensor<int>({4, 6}, {2}));
});
thread_pool.Schedule([&session, b]() {
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {b}, {},
&outputs, nullptr, thread::ThreadPoolOptions()));
test::ExpectTensorEqual<int>(outputs[0],
test::AsTensor<int>({3, 8}, {2}));
});
}
auto c = Sub(root, b, a);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {c}, {},
&outputs, nullptr, thread::ThreadPoolOptions()));
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
TEST(ClientSessionTest, CallableWithDefaultThreadPool) {
Scope root = Scope::NewRootScope();
auto a = Placeholder(root, DT_INT32);

View File

@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define _USE_MATH_DEFINES
#include <cmath>
#include "tensorflow/cc/ops/array_ops_internal.h"

View File

@ -125,11 +125,11 @@ cc_library(
deps = [
":constants",
"@com_google_absl//absl/container:flat_hash_set",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/util/tensor_bundle",
] + if_not_mobile([
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/util/tensor_bundle",
]),
)

View File

@ -75,8 +75,8 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
"@llvm//:support", # fixdeps: keep
"@llvm//:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:support", # fixdeps: keep
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
],
)
@ -104,11 +104,11 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm//:aarch64_code_gen", # fixdeps: keep
"@llvm//:arm_code_gen", # fixdeps: keep
"@llvm//:powerpc_code_gen", # fixdeps: keep
"@llvm//:target",
"@llvm//:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
],
)
@ -205,9 +205,9 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
"@llvm-project//llvm:core",
"@llvm-project//llvm:support",
"@llvm-project//llvm:target",
],
)

View File

@ -407,6 +407,7 @@ def target_llvm_triple():
"//tensorflow:android_arm64": "aarch64-none-android",
"//tensorflow:android_x86": "i686-none-android",
"//tensorflow:ios": "arm64-none-ios",
"//tensorflow:ios_x86_64": "x86_64-apple-ios",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:macos": "x86_64-none-darwin",
"//conditions:default": "x86_64-pc-linux",

View File

@ -500,6 +500,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)

View File

@ -22,8 +22,9 @@ limitations under the License.
namespace tensorflow {
// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and
// executes (using XLA) TF function calls marked with "_XlaCompiledKernel".
// Replaces TF function calls marked with `_XlaCompiledKernel` with _XlaCompile
// and _XlaRun nodes (which compile and launch, respectively, the corresponding
// HLO module).
class BuildXlaOpsPass : public GraphOptimizationPass {
public:
// If enable_lazy_compilation is not nullopt then *enable_lazy_compilation

View File

@ -17,6 +17,8 @@ limitations under the License.
namespace tensorflow {
const char* const kXlaMustCompileAttr = "_XlaMustCompile";
const char* const kXlaCompileAttr = "_XlaCompile";
// User-provided through jit_scope APIs. Effective only when auto_jit is OFF.

View File

@ -22,7 +22,16 @@ limitations under the License.
namespace tensorflow {
// Name of attribute used to tag operators for compilation with XLA
// Implies must-compile semantics: either it will be compiled
// with XLA, or an error will be thrown.
extern const char* const kXlaMustCompileAttr; // "_XlaMustCompile"
// Implies auto-clustering: tagged nodes will be clustered and compiled with XLA
// on a best-effort basis.
extern const char* const kXlaCompileAttr; // "_XlaCompile"
// Implies auto-clustering within the given scope.
extern const char* const kXlaScopeAttr; // "_XlaScope"
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"
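For context, these attributes are normally attached from the Python API rather than written by hand. A hedged sketch follows; the exact API-to-attribute mapping is an assumption based on this era of TensorFlow:

```python
import tensorflow as tf

# Must-compile semantics (roughly corresponds to _XlaMustCompile):
# compilation is required rather than best-effort.
@tf.function(experimental_compile=True)
def must_compile(a, b):
  return tf.matmul(a, b)

# Best-effort auto-clustering within a scope (roughly corresponds to
# _XlaCompile / _XlaScope as set by jit_scope).
@tf.function
def scoped_add(a, b):
  with tf.xla.experimental.jit_scope(compile_ops=True):
    return a + b
```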

View File

@ -27,6 +27,15 @@ limitations under the License.
namespace tensorflow {
// EncapsulateSubgraphs pass takes all the nodes with the same cluster ID
// (derived from kXlaClusterAttr=ID (kXlaClusterAttr) attribute), puts them into
// a TF function, and replaces the subgraph in the main graph with a call to
// that TF function annotated with kXlaCompiledKernelAttr (_XlaCompiledKernel).
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
// A rewriting function to apply to each subgraph during encapsulation.
// 'arg_source_tensors' are the tensors corresponding to the arguments in the
// original source graph (*not* 'graph').
@ -100,11 +109,6 @@ extern const char* const kXlaHasReferenceVarsAttr;
// TODO(hpucha): Move the utilities to a more appropriate place.
void SortControlInputs(GraphDef* gdef);
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_

View File

@ -2130,6 +2130,53 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
return Status::OK();
}
Status CopyOutsideCompilationConstNodes(
Graph* g, const string& outside_compilation_attr_name) {
for (Node* n : g->op_nodes()) {
if (!n->IsConstant() ||
!HasNodeAttr(n->def(), outside_compilation_attr_name)) {
continue;
}
std::vector<const Edge*> out_edges(n->out_edges().begin(),
n->out_edges().end());
bool has_non_oc_output = false;
for (const Edge* e : out_edges) {
if (!e->IsControlEdge() &&
!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
has_non_oc_output = true;
break;
}
}
if (!has_non_oc_output) {
continue;
}
NodeDef copy_def = n->def();
copy_def.set_name(g->NewName(n->name()));
copy_def.mutable_attr()->erase(outside_compilation_attr_name);
Status s;
Node* copy_node = g->AddNode(copy_def, &s);
TF_RETURN_IF_ERROR(s);
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
g->AddControlEdge(e->src(), copy_node);
}
}
for (const Edge* e : out_edges) {
if (!e->IsControlEdge() &&
!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
Node* dst = e->dst();
int dst_input = e->dst_input();
g->RemoveEdge(e);
g->AddEdge(copy_node, 0, dst, dst_input);
}
}
}
return Status::OK();
}
} // namespace
Status RewriteOutsideCompilationSubgraphFn::operator()(
@ -2279,6 +2326,10 @@ Status ExtractOutsideCompilationForFunction(
std::vector<string> outside_compilation_host_graphs;
std::vector<string> shape_inference_graphs_to_rewrite;
if (*has_outside_compilation) {
// Copy outside compilation Const nodes with non outside compilation users.
TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
fbody->graph, outside_compilation_attr_name));
// Find dependencies between outside compilation clusters.
TF_ASSIGN_OR_RETURN(auto cluster_deps,
OutsideCompilationClusterDependencies(

View File

@ -1187,7 +1187,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
}
if (!whitelist.empty() && !whitelist.contains(node->def().op())) {
VLOG(1) << "Rejecting " << node->name()
VLOG(1) << "Rejecting TF operation " << node->def().op()
<< " as it is not listed in --tf_xla_ops_to_cluster.";
continue;
}
@ -1770,9 +1770,10 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
{"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
"Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
"Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
"Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", "Rsqrt",
"Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", "Square",
"Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Lgamma", "Digamma",
"Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round",
"Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
"Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv",
"Lgamma", "Digamma",
// Binary
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
"MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
@ -2035,6 +2036,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"XlaDynamicSlice",
"XlaDynamicUpdateSlice",
"XlaEinsum",
"XlaGather",
"XlaIf",
"XlaKeyValueSort",
"XlaPad",
@ -2042,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"XlaReduce",
"XlaReduceWindow",
"XlaReplicaId",
"XlaScatter",
"XlaSelectAndScatter",
"XlaSelfAdjointEig",
"XlaSend",

View File

@ -34,8 +34,9 @@ extern const char* const kXlaClusterAttr;
// compilation by the encapsulate subgraphs pass.
extern const char* const kXlaOutsideCompilationAttr;
// Pass that marks a subset of operators in the graph with attribute
// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
// Marks a subset of nodes in the graph which are to be clustered
// with an attribute _XlaCluster=<cluster id> so they are picked up by the
// EncapsulateSubgraphsPass.
class MarkForCompilationPass : public GraphOptimizationPass {
public:
MarkForCompilationPass() = default;

View File

@ -17,7 +17,10 @@ limitations under the License.
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/util/dump_graph.h"
@ -39,7 +42,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}
Status PropagateShapes(const Graph& graph,
Status PropagateShapes(Graph* graph,
const std::map<int, InferredShape>& arg_shapes,
const std::vector<BackEdgeHelper::BackEdge>& back_edges,
ShapeRefiner* shape_refiner) {
@ -54,7 +57,7 @@ Status PropagateShapes(const Graph& graph,
// shapes.
// TODO(phawkins): handle cyclic graphs.
std::vector<Node*> order;
GetReversePostOrder(graph, &order);
GetReversePostOrder(*graph, &order);
for (Node* n : order) {
// Ignore the status returned by the shape_refiner. We want the best effort
@ -99,6 +102,67 @@ Status PropagateShapes(const Graph& graph,
}
}
// Sometimes we have VariableShape nodes in while loop (after Enter nodes).
// They won't be constant-folded because TensorFlow constant folding does
// not handle Enter nodes (and thus does not handle any nodes after Enter
// nodes). We try to replace such VariableShape nodes with Const nodes here.
if (n->type_string() == "VariableShape") {
shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
shape_inference::ShapeHandle handle =
handle_shapes_and_types->at(0).shape;
TensorShapeProto shape_proto;
context->ShapeHandleToProto(handle, &shape_proto);
if (!shape_proto.unknown_rank()) {
NodeDef const_def;
const_def.set_op("Const");
Node* var_node;
TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
const_def.set_name(
graph->NewName(absl::StrCat("var_shape_", var_node->name())));
DataType dtype = n->output_type(0);
AddNodeAttr("dtype", dtype, &const_def);
TensorProto value;
value.set_dtype(dtype);
value.mutable_tensor_shape()->add_dim()->set_size(
shape_proto.dim_size());
for (const auto& dim : shape_proto.dim()) {
if (dtype == DT_INT32) {
value.add_int_val(dim.size());
} else {
value.add_int64_val(dim.size());
}
}
AddNodeAttr("value", value, &const_def);
for (auto const& attr : n->attrs()) {
if (*attr.first.begin() == '_') {
AddNodeAttr(attr.first, attr.second, &const_def);
}
}
Status s;
Node* const_node = graph->AddNode(const_def, &s);
TF_RETURN_IF_ERROR(s);
graph->AddControlEdge(var_node, const_node);
std::vector<const Edge*> out_edges(n->out_edges().begin(),
n->out_edges().end());
for (const Edge* e : out_edges) {
if (e->IsControlEdge()) {
graph->AddControlEdge(const_node, e->dst());
graph->RemoveEdge(e);
} else {
Node* dst = e->dst();
int dst_input = e->dst_input();
graph->RemoveEdge(e);
graph->AddEdge(const_node, 0, dst, dst_input);
}
}
}
}
}
// Merge node causes a loop so we remove NextIteration->Merge edge before
// performing shape inference. But removing those edges also prevents us
// from inferring output shape for Merge node (we need shapes for all its
@ -196,7 +260,7 @@ Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
// the shape inference is complete.
BackEdgeHelper back_edge;
TF_RETURN_IF_ERROR(back_edge.Remove(graph));
TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes,
TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
back_edge.RemovedEdges(), &shape_refiner));
TF_RETURN_IF_ERROR(back_edge.Replace());

View File

@ -191,7 +191,7 @@ class XlaAssignVariableOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
data::IteratorGetNextAsOptionalOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \
data::IteratorGetNextSyncOp); \
data::IteratorGetNextOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \
.Device(DEVICE) \
.HostMemory("string_handle"), \

View File

@ -21,7 +21,7 @@ namespace tensorflow {
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const {
return CanCreateXlaKernel(flr, node_def);
return CanCreateXlaKernel(node_def);
}
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,

View File

@ -95,15 +95,17 @@ AttrValue BoolAttr(bool b) {
TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
FunctionDef fdef = XTimesY();
(*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
Init({fdef});
XlaKernelCreator xla_kernel_creator;
Status status = xla_kernel_creator.CreateKernel(
flr_, ToNodeDef(R"pb(
NodeDef callsite =
ToNodeDef(R"pb(
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
)pb"),
&kernel_);
)pb");
(*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
// Note: need to set attribute on the created node.
Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
ASSERT_TRUE(status.ok()) << status.ToString();
EXPECT_EQ("XTimesY", kernel_->name());
@ -137,7 +139,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
FunctionDef fdef = XTimesY();
(*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(false);
Init({fdef});
XlaKernelCreator xla_kernel_creator;

View File

@ -23,7 +23,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
@ -68,40 +70,10 @@ class SinglePassSearch {
};
} // namespace
bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) {
const FunctionDef* function_def =
flr.GetFunctionLibraryDefinition()->Find(node_def.name());
if (function_def == nullptr) {
// The node def is not calling a function. Individual ops can be
// run directly using on-demand mode, no need to create XlaLaunch
// kernel for them.
return false;
}
// If kXlaCompileAttr is set on the node_def, use its value.
const auto& it = node_def.attr().find(kXlaCompileAttr);
if (it != node_def.attr().end()) {
return it->second.b();
}
// kXlaCompileAttr is not set on node_def, check if it is set on
// FunctionDef.
bool xla_compile = false;
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
node_def, kXlaCompileAttr, &xla_compile);
if (!status.ok() || !xla_compile) {
if (VLOG_IS_ON(3)) {
if (!status.ok()) {
VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
<< node_def.op() << ". status=" << status.ToString();
} else {
VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
}
}
return false;
}
return true;
bool CanCreateXlaKernel(const NodeDef& node_def) {
// If kXlaMustCompileAttr is set on the node_def, use its value.
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
return it != node_def.attr().end() && it->second.b();
}
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
@ -118,8 +90,11 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
FunctionLibraryRuntime::Handle handle;
// If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
TF_RETURN_IF_ERROR(
flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
*fbody = flr->GetFunctionBody(handle);
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
const DataTypeVector& arg_types = (*fbody)->arg_types;
@ -149,7 +124,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
if (!CanCreateXlaKernel(*flr, node_def)) {
if (!CanCreateXlaKernel(node_def)) {
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
}
@ -241,9 +216,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// Create the kernel.
NameAttrList function;
function.set_name(node_def.op());
*(function.mutable_attr()) = node_def.attr();
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(

View File

@ -24,11 +24,9 @@ namespace tensorflow {
class FunctionLibraryRuntime;
class OpKernel;
// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
// true if 'node_def' is a call to a compilable function defined in 'flr',
// with the kXlaCompileAttr set.
bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def);
// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
// set.
bool CanCreateXlaKernel(const NodeDef& node_def);
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,

View File

@ -6,7 +6,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
package(
default_visibility = [
"//tensorflow/compiler/tf2xla:__subpackages__",
"@local_config_mlir//:friends",
"@llvm-project//mlir:friends",
],
licenses = ["notice"], # Apache 2.0
)
@ -30,8 +30,8 @@ cc_library(
hdrs = ["op_or_arg_name_mapper.h"],
deps = [
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)
@ -43,11 +43,11 @@ cc_library(
":passes",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm//:support",
"@local_config_mlir//:MlirOptLib",
"@local_config_mlir//:Pass",
"@local_config_mlir//:Support",
"@local_config_mlir//test:TestTransforms",
"@llvm-project//llvm:support",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir/test:TestTransforms",
],
)
@ -80,9 +80,9 @@ cc_library(
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower",
"@local_config_mlir//:AffineDialectRegistration",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"@llvm-project//mlir:AffineDialectRegistration",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
],
)
@ -92,7 +92,7 @@ cc_library(
hdrs = ["init_mlir.h"],
deps = [
"//tensorflow/core:lib",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)
@ -122,11 +122,11 @@ tf_cc_binary(
"//tensorflow/core:tensorflow",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
"@local_config_mlir//:TranslateClParser",
"@local_config_mlir//:Translation",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TranslateClParser",
"@llvm-project//mlir:Translation",
],
)

View File

@ -1,11 +1,11 @@
# MLIR dialects and utilities for TensorFlow, TensorFlow Lite and XLA.
This module contains the MLIR
([Multi-Level Intermediate Representation](https://github.com/tensorflow/mlir))
([Multi-Level Intermediate Representation](https://mlir.llvm.org))
dialects and utilities for
1. TensorFlow
2. XLA
3. TF Lite
See [MLIR repo](https://github.com/tensorflow/mlir) for complete documentation.
See [MLIR's website](https://mlir.llvm.org) for complete documentation.

View File

@ -10,7 +10,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths")
# Default values used by the test runner.
_default_test_file_exts = ["mlir", ".pbtxt", ".td"]
_default_driver = "@local_config_mlir//:run_lit.sh"
_default_driver = "@llvm-project//mlir:run_lit.sh"
_default_size = "small"
_default_tags = ["no_rocm"]
@ -50,16 +50,16 @@ def _run_lit_test(name, data, size, tags, driver, features):
native.py_test(
name = name,
srcs = ["@llvm//:lit"],
srcs = ["@llvm-project//llvm:lit"],
tags = tags,
args = [
"tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
] + features,
data = data + [
"//tensorflow/compiler/mlir:litfiles",
"@llvm//:FileCheck",
"@llvm//:count",
"@llvm//:not",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:count",
"@llvm-project//llvm:not",
],
size = size,
main = "lit.py",

View File

@ -1,6 +1,6 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
load(
"@local_config_mlir//:tblgen.bzl",
"//third_party/mlir:tblgen.bzl",
"gentbl",
)
@ -8,13 +8,14 @@ package(
default_visibility = [
# TODO(jpienaar): Make the visibility more restrictive.
":friends",
"//tensorflow/lite/experimental/tf_runtime:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["@local_config_mlir//:subpackages"],
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//learning/brain/google/xla/...",
@ -27,7 +28,7 @@ filegroup(
srcs = [
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@local_config_mlir//:OpBaseTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
],
)
@ -47,7 +48,7 @@ gentbl(
"g3doc/tfl_ops.md",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfl_ops.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
@ -62,11 +63,11 @@ gentbl(
"transforms/generated_prepare_tf.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/prepare_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_optimize_td_files",
],
@ -80,11 +81,11 @@ gentbl(
"transforms/generated_lower_static_tensor_list.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/tensorlist_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
@ -97,11 +98,11 @@ gentbl(
"transforms/generated_legalize_tf.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/legalize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
@ -114,11 +115,12 @@ gentbl(
"transforms/generated_optimize.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/optimize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
@ -130,11 +132,11 @@ gentbl(
"transforms/generated_quantize.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/quantize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
],
)
@ -146,11 +148,11 @@ gentbl(
"transforms/generated_post_quantize.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/post_quantize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
],
)
@ -163,9 +165,9 @@ cc_library(
"utils/validators.h",
],
deps = [
"@local_config_mlir//:Dialect",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
)
@ -183,21 +185,21 @@ cc_library(
"transforms/passes.h",
"utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
"@local_config_mlir//:include/mlir/Transforms/InliningUtils.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:Dialect",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
@ -214,10 +216,10 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
@ -231,9 +233,9 @@ cc_library(
],
deps = [
":tensorflow_lite",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
)
@ -246,10 +248,10 @@ tf_cc_test(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
@ -290,14 +292,14 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:Transforms",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)
@ -315,12 +317,12 @@ cc_library(
":tensorflow_lite",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
@ -346,13 +348,13 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
@ -374,7 +376,7 @@ genrule(
"utils/generated_op_quant_spec_getters.inc",
],
cmd = ("$(location //tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen) " +
"-I external/local_config_mlir/include " +
"-I external/llvm-project/mlir/include " +
"-I external/org_tensorflow " +
"$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"),
tools = ["//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen"],
@ -388,7 +390,7 @@ cc_library(
],
deps = [
":tensorflow_lite",
"@local_config_mlir//:IR",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
)
@ -399,9 +401,9 @@ tf_native_cc_binary(
"operator_converter_gen.cc",
],
deps = [
"@llvm//:support",
"@llvm//:tablegen",
"@local_config_mlir//:TableGen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//mlir:TableGen",
],
)
@ -434,12 +436,17 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:TransformUtils",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils",
],
)
@ -462,7 +469,7 @@ cc_library(
],
deps = [
"//tensorflow/lite/core/api",
"@local_config_mlir//:IR",
"@llvm-project//mlir:IR",
],
)
@ -507,14 +514,14 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"@local_config_mlir//:StandardDialectRegistration",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:Translation",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:StandardDialectRegistration",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
],
alwayslink = 1,
)
@ -523,7 +530,7 @@ tf_cc_binary(
name = "flatbuffer_translate",
deps = [
":flatbuffer_translate_lib",
"@local_config_mlir//:MlirTranslateMain",
"@llvm-project//mlir:MlirTranslateMain",
],
)
@ -536,7 +543,7 @@ cc_library(
"tf_tfl_translate_cl.h",
],
deps = [
"@llvm//:support",
"@llvm-project//llvm:support",
],
alwayslink = 1,
)
@ -548,7 +555,7 @@ cc_library(
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)
@ -576,9 +583,9 @@ tf_cc_binary(
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
@ -589,16 +596,15 @@ tf_cc_binary(
":flatbuffer_translate_lib",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform/default/build_config:base",
"//tensorflow/lite:framework",
"//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/kernels:builtin_ops",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
],
)
@ -621,12 +627,12 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"@local_config_mlir//:Transforms",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Transforms",
],
)
@ -653,15 +659,15 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"@local_config_mlir//:Support",
"@local_config_mlir//:Transforms",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <cstdarg>
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:llvm-project
#include "tensorflow/lite/core/api/error_reporter.h"
namespace tflite {

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
@ -43,24 +44,24 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Translation.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Translation.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
@ -103,12 +104,26 @@ using llvm::cl::opt;
// Commandline flag to enable the control of flatbuffer import.
bool use_external_constant;
// Commandline flag to enable graph pruning.
bool experimental_prune_unreachable_nodes_unconditionally;
// NOLINTNEXTLINE
static opt<bool, true> use_external_constant_flag(
"use-external-constant",
llvm::cl::desc("Use external constant during flatbuffer import"),
llvm::cl::location(use_external_constant), llvm::cl::init(false));
// TODO(b/147111261): After the importer supports generic custom ops, we should
// change the flag to a more lightweight flag, e.g.
// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE prune
// the operations.
// NOLINTNEXTLINE
static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
"experimental-prune-unreachable-nodes-unconditionally",
llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
llvm::cl::init(false));
namespace {
bool IsScalar(const TensorT& tensor) {
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
@ -212,12 +227,12 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
// type, thus no stats op is required and nullptr is returned.
// If the min max information is invalid, nullptr is returned.
mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
Value* res) {
Value res) {
// If the `tensor` has scale/zero_point, it must have been quantized, then the
// min/max stats is just for comments, so ignore it.
if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
// If the result isn't float and unquantizable, the min/max is ignored.
if (!res->getType()
if (!res.getType()
.cast<mlir::ShapedType>()
.getElementType()
.isa<mlir::FloatType>()) {
@ -255,10 +270,23 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
}
StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
// TODO(krzysd) Support custom ops
// TODO(b/143872630): Support custom ops
if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
return errors::Unimplemented("unsupported custom operation: ",
opcode.custom_code);
// Handle some custom ops that are supported on GPU.
const absl::string_view custom_name = opcode.custom_code;
if (custom_name == "MaxPoolingWithArgmax2D") {
return std::string("tfl.max_pooling_with_argmax_2d");
}
if (custom_name == "Convolution2DTransposeBias") {
return std::string("tfl.convolution_2d_transpose_bias");
}
if (custom_name == "MaxUnpooling2D") {
return std::string("tfl.max_unpooling_2d");
}
// Use an unsupported op name instead of returning an error here in case the
// op is pruned during the import.
return std::string(
llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str());
}
if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
return std::string("tf.If");
@ -495,14 +523,21 @@ bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
}
}
// Returns true if this is a custom op.
bool IsCustomOp(const std::string& op_name) {
return op_name == "tfl.max_pooling_with_argmax_2d" ||
op_name == "tfl.max_unpooling_2d" ||
op_name == "tfl.convolution_2d_transpose_bias";
}
// TODO(krzysd) Handle function calls
StatusOr<Operation*> ConvertOp(
const tflite::OperatorT& op, const std::vector<Value*>& vals_map,
Value* optional_arg_marker, const std::vector<std::string>& op_names,
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
Value optional_arg_marker, const std::vector<std::string>& op_names,
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
OpBuilder builder) {
llvm::SmallVector<Value*, 4> operands;
llvm::SmallVector<Value, 4> operands;
llvm::SmallVector<mlir::Type, 2> outputTypes;
if (op.outputs.empty()) {
@ -557,7 +592,15 @@ StatusOr<Operation*> ConvertOp(
}
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
if (IsCustomOp(op_name)) {
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
builder, loc, &attrs);
if (!status.ok()) {
return emitError(loc, status.ToString()), status;
}
} else {
mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
}
op_state.addAttributes(attrs);
// Handle the conversion from subgraph index to functions for If and While
@ -619,6 +662,49 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}
// Given a list of output indices, traverses the subgraph and returns the set of
// ops that are ancestors of the output tensors.
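// For example (illustration only): if an output tensor is produced by op C,
// C consumes a tensor produced by op A, and op B produces a tensor that never
// reaches any output, the returned set is {A, C} and the caller skips B.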
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
// Create a map from tensor index to defining op.
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
for (const auto& op : subgraph.operators) {
for (int32_t output : op->outputs) {
defining_op[output] = op.get();
}
}
std::vector<const tflite::OperatorT*> queue;
for (int32_t output : output_indices) {
if (auto& op = defining_op[output]) {
queue.push_back(op);
} else {
return errors::InvalidArgument("Output tensor doesn't have defining op");
}
}
// Traverse the graph towards inputs.
absl::flat_hash_set<const tflite::OperatorT*> visited;
while (!queue.empty()) {
const tflite::OperatorT* op = queue.back();
queue.pop_back();
if (!visited.insert(op).second) {
// The node has already been visited.
continue;
}
for (int32_t input : op->inputs) {
// Input tensor may not have a defining op in case it is a subgraph input
// or a constant tensor.
if (auto& op = defining_op[input]) {
queue.push_back(op);
}
}
}
return visited;
}
// Build a FuncOp from a tflite SubGraph
// The op_names are a mapping from indexes into the TFLite operators array to
// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
@ -635,7 +721,8 @@ StatusOr<FuncOp> ConvertSubgraph(
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
Location base_loc, Builder builder,
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
bool use_external_constant) {
bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
llvm::SmallVector<mlir::Type, 2> ret_types;
llvm::SmallVector<mlir::Type, 4> input_types;
@ -692,19 +779,19 @@ StatusOr<FuncOp> ConvertSubgraph(
auto& body = func.getBody();
OpBuilder op_builder{body};
std::vector<Value*> vals_map(subgraph.tensors.size(), nullptr);
Value* maybe_optional_arg_marker = nullptr;
std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
Value maybe_optional_arg_marker = nullptr;
// Get or construct MLIR values for each input
for (int i = 0, e = subgraph.inputs.size(); i < e; i++) {
auto input_tensor = subgraph.inputs[i];
const auto& tensor = *subgraph.tensors.at(input_tensor);
auto loc = TensorLoc(tensor, builder, base_loc);
if (nullptr != vals_map[input_tensor]) {
if (vals_map[input_tensor]) {
auto err = errors::FailedPrecondition("duplicate input arguments");
return emitError(loc, err.ToString()), err;
}
Value* input_value = func.getArgument(i);
Value input_value = func.getArgument(i);
// If the `tensor` has min/max and doesn't have scale/zero_point
// information, a stats op is created to use the input_value, then the
@ -731,8 +818,19 @@ StatusOr<FuncOp> ConvertSubgraph(
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
}
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
if (experimental_prune_unreachable_nodes_unconditionally) {
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
PruneSubgraph(subgraph, func_outputs));
}
// Construct MLIR operators from TFLite operators
for (auto& op : subgraph.operators) {
if (experimental_prune_unreachable_nodes_unconditionally &&
!pruned_subgraph_ops.contains(op)) {
continue;
}
for (auto input_num : op->inputs) {
// The operators in a graph are topologically sorted
// and so if no previous operation has produced a tensor
@ -745,7 +843,7 @@ StatusOr<FuncOp> ConvertSubgraph(
builder.getUnitAttr())
.getResult();
}
} else if (nullptr == vals_map.at(input_num)) {
} else if (!vals_map.at(input_num)) {
auto& const_tensor = *subgraph.tensors[input_num];
auto const_loc = TensorLoc(const_tensor, builder, base_loc);
auto op_or_err =
@ -768,7 +866,7 @@ StatusOr<FuncOp> ConvertSubgraph(
? base_loc
: TensorLoc(*subgraph.tensors[op->outputs[0]], builder, base_loc);
// If there's an optional argument, maybe_optional_arg_marker has been set
// to a valid Value*
// to a valid Value
TF_ASSIGN_OR_RETURN(
auto* mlir_op,
ConvertOp(*op, vals_map, maybe_optional_arg_marker, op_names,
@ -791,9 +889,9 @@ StatusOr<FuncOp> ConvertSubgraph(
}
// Construct return values
llvm::SmallVector<Value*, 4> return_operands;
llvm::SmallVector<Value, 4> return_operands;
for (auto index : func_outputs) {
if (nullptr == vals_map.at(index)) {
if (!vals_map.at(index)) {
auto& const_tensor = *subgraph.tensors[index];
auto const_loc = TensorLoc(const_tensor, builder, base_loc);
auto op_or_err =
@ -837,7 +935,8 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
OwningModuleRef tflite::FlatBufferToMlir(
absl::string_view buffer, MLIRContext* context, Location base_loc,
const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant) {
bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
auto model_ptr =
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
if (nullptr == model_ptr) {
@ -892,7 +991,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
// TODO(b/131175224,b/132239787) Support multiple entry points
builder, ordered_output_arrays,
/*is_entry_point=*/e.index() == 0,
/*use_external_constant=*/use_external_constant);
/*use_external_constant=*/use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
if (!func_or_error.ok()) {
return emitError(base_loc, "could not translate function ")
<< subgraph->name,
@ -905,9 +1005,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
return OwningModuleRef(module);
}
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
MLIRContext* context,
bool use_external_constant) {
static OwningModuleRef FlatBufferFileToMlirTrans(
llvm::SourceMgr* source_mgr, MLIRContext* context,
bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
const llvm::MemoryBuffer* input =
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
std::string error;
@ -924,12 +1025,14 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
return tflite::FlatBufferToMlir(
absl::string_view(input->getBufferStart(), input->getBufferSize()),
context, loc, outputs, use_external_constant);
context, loc, outputs, use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
}
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
"tflite-flatbuffer-to-mlir",
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
return FlatBufferFileToMlirTrans(&source_mgr, context,
use_external_constant);
return FlatBufferFileToMlirTrans(
&source_mgr, context, use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
});

View File

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
#include "absl/strings/string_view.h"
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
namespace tflite {
// Converts a TFLite flatbuffer stored in `buffer` to an MLIR module
@ -31,11 +31,14 @@ namespace tflite {
// on failure, and more specific errors will be emitted via the context.
// If `use_external_constant` is true, it will create `tfl.external_const`
// instead of `tfl.const`.
// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes that
// are not ancestors of the output nodes will be pruned.
mlir::OwningModuleRef FlatBufferToMlir(
absl::string_view buffer, mlir::MLIRContext* context,
mlir::Location base_loc,
const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant = false);
bool use_external_constant = false,
bool experimental_prune_unreachable_nodes_unconditionally = false);
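// A minimal usage sketch (hypothetical caller; everything other than the
// FlatBufferToMlir signature above is an assumption):
//   mlir::MLIRContext context;
//   mlir::Location loc = mlir::UnknownLoc::get(&context);
//   mlir::OwningModuleRef module = tflite::FlatBufferToMlir(
//       flatbuffer_bytes, &context, loc, /*ordered_output_arrays=*/{},
//       /*use_external_constant=*/false,
//       /*experimental_prune_unreachable_nodes_unconditionally=*/true);
//   if (!module) { /* specific errors were already emitted via the context */ }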
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_

View File

@ -17,15 +17,45 @@ limitations under the License.
#include <vector>
#include "absl/strings/str_cat.h"
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace {
using ::tensorflow::Status;
using ::tensorflow::errors::InvalidArgument;
using ::xla::StatusOr;
StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
mlir::Builder builder,
mlir::Location loc) {
auto padding = tflite::Padding::Padding_VALID;
if (pad_params == TfLitePadding::kTfLitePaddingSame) {
padding = tflite::Padding_SAME;
} else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
padding = tflite::Padding_VALID;
} else {
return InvalidArgument(
absl::StrCat("Invalid padding type", std::to_string(pad_params)));
}
const char* option_name = tflite::EnumNamePadding(padding);
return builder.getStringAttr(option_name);
}
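// For example, kTfLitePaddingSame maps to the string attribute "SAME" and
// kTfLitePaddingValid maps to "VALID", the names produced by EnumNamePadding.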
} // namespace
// TODO(jpienaar): This is a placeholder. This should be done in a more
// efficient way as part of the translation of the module.
static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
@ -212,5 +242,44 @@ static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
return builder.getStringAttr(option_name);
}
Status mlir::CustomOptionsToAttributes(
const std::string& op_name, const std::vector<uint8_t>& custom_options,
mlir::Builder builder, mlir::Location loc,
llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
if (op_name == "tfl.max_pooling_with_argmax_2d" ||
op_name == "tfl.max_unpooling_2d") {
auto* pool_params =
reinterpret_cast<const TfLitePoolParams*>(custom_options.data());
TF_ASSIGN_OR_RETURN(auto padding_attribute,
GetPaddingAttr(pool_params->padding, builder, loc));
attributes->emplace_back(
builder.getNamedAttr("padding", padding_attribute));
attributes->emplace_back(builder.getNamedAttr(
"stride_h", builder.getI32IntegerAttr(pool_params->stride_height)));
attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
attributes->emplace_back(builder.getNamedAttr(
"filter_w", builder.getI32IntegerAttr(pool_params->filter_width)));
attributes->emplace_back(builder.getNamedAttr(
"filter_h", builder.getI32IntegerAttr(pool_params->filter_height)));
return Status::OK();
} else if (op_name == "tfl.convolution_2d_transpose_bias") {
auto* conv_params = reinterpret_cast<const TfLiteTransposeConvParams*>(
custom_options.data());
TF_ASSIGN_OR_RETURN(auto padding_attribute,
GetPaddingAttr(conv_params->padding, builder, loc));
attributes->emplace_back(
builder.getNamedAttr("padding", padding_attribute));
attributes->emplace_back(builder.getNamedAttr(
"stride_h", builder.getI32IntegerAttr(conv_params->stride_height)));
attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(conv_params->stride_width)));
return Status::OK();
}
return InvalidArgument(absl::StrCat("invalid custom op type: ", op_name));
}
// Pull in FlatBuffer writers for TFLite generated using TableGen
#include "tensorflow/compiler/mlir/lite/operator_converters.inc"

View File

@ -26,9 +26,10 @@ limitations under the License.
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {
@ -45,7 +46,7 @@ llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
flatbuffers::FlatBufferBuilder *fbb);
// Populate the array of mlir::NamedAttributes corresponding to the given
// Populates the array of mlir::NamedAttributes corresponding to the given
// tflite::FlatbufferOptionsUnion.
// We use an out parameter per LLVM convention
void BuiltinOptionsToAttributes(
@ -53,6 +54,15 @@ void BuiltinOptionsToAttributes(
// NOLINTNEXTLINE
llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes);
// Populates the array of mlir::NamedAttributes corresponding to the given
// custom_options.
// We use an out parameter per LLVM convention
tensorflow::Status CustomOptionsToAttributes(
const std::string &op_name, const std::vector<uint8_t> &custom_options,
mlir::Builder builder,
// NOLINTNEXTLINE
Location loc, llvm::SmallVectorImpl<mlir::NamedAttribute> *attributes);
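// Usage sketch (hypothetical caller mirroring the flatbuffer importer; `op`,
// `builder` and `loc` are assumed to exist in the calling scope):
//   llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
//   tensorflow::Status status = mlir::CustomOptionsToAttributes(
//       "tfl.max_unpooling_2d", op.custom_options, builder, loc, &attrs);
//   if (!status.ok()) { /* emitError(loc, status.ToString()); */ }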
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_

View File

@ -41,21 +41,22 @@ limitations under the License.
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Translation.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Translation.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -230,19 +231,19 @@ static bool IsConst(Operation* op) {
}
template <typename T>
static bool HasValidTFLiteType(Value* value, T& error_handler) {
static bool HasValidTFLiteType(Value value, T& error_handler) {
// None type is allowed to represent unspecified operands.
if (value->getType().isa<NoneType>()) return true;
if (value.getType().isa<NoneType>()) return true;
auto type = value->getType().dyn_cast<TensorType>();
auto type = value.getType().dyn_cast<TensorType>();
if (!type) {
if (auto op = value->getDefiningOp()) {
if (auto op = value.getDefiningOp()) {
error_handler.emitError()
<< '\'' << op << "' should produce value of tensor type instead of "
<< value->getType();
<< value.getType();
return false;
}
error_handler.emitError("expected tensor type, got ") << value->getType();
error_handler.emitError("expected tensor type, got ") << value.getType();
return false;
}
@ -279,9 +280,9 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
}
auto& bb = fn.getBlocks().front();
for (auto* arg : bb.getArguments()) {
for (auto arg : bb.getArguments()) {
if (!HasValidTFLiteType(arg, fn))
return fn.emitError("invalid TFLite type: ") << arg->getType(), false;
return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
}
// Verify that all operations except the terminator have exactly one
@ -289,9 +290,9 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
for (auto& inst : bb) {
if (inst.isKnownTerminator()) break;
for (auto* result : inst.getResults()) {
for (auto result : inst.getResults()) {
if (!HasValidTFLiteType(result, inst))
return fn.emitError("invalid TFLite type: ") << result->getType(),
return fn.emitError("invalid TFLite type: ") << result.getType(),
false;
}
}
@ -361,7 +362,7 @@ class Translator {
// Builds a TFLite tensor from the given value. `buffer_idx` is the index of
// the corresponding buffer. Emits an error and returns llvm::None on failure.
Optional<BufferOffset<tflite::Tensor>> BuildTensor(Value* value,
Optional<BufferOffset<tflite::Tensor>> BuildTensor(Value value,
const std::string& name,
unsigned buffer_idx);
@ -419,7 +420,7 @@ class Translator {
bool IsStatefulOperand(mlir::Operation* op, int operand_index);
// Returns a unique name for `val`.
std::string UniqueName(mlir::Value* val);
std::string UniqueName(mlir::Value val);
ModuleOp module_;
@ -449,7 +450,7 @@ class Translator {
std::vector<std::string> failed_custom_ops_;
};
std::string Translator::UniqueName(mlir::Value* val) {
std::string Translator::UniqueName(mlir::Value val) {
return name_mapper_.GetUniqueName(val);
}
@ -502,8 +503,8 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
}
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
Value* value, const std::string& name, unsigned buffer_idx) {
auto type = value->getType().cast<TensorType>();
Value value, const std::string& name, unsigned buffer_idx) {
auto type = value.getType().cast<TensorType>();
// TFLite requires tensor shape only for the inputs and constants.
// However, we output all known shapes for better round-tripping
@ -515,7 +516,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
return mlir::emitError(
value->getLoc(),
value.getLoc(),
"result shape dimensions out of 32 bit int type range");
return mlir::success();
@ -527,7 +528,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
} else if (auto* inst = value->getDefiningOp()) {
} else if (auto* inst = value.getDefiningOp()) {
if (IsConst(inst)) {
// Const op can have a result of dynamic shaped type (e.g. due to constant
// folding), but we can still derive the shape of a constant tensor for
@ -570,7 +571,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
// marked as stateful. If so, set the tensor's is_variable to true.
// This is v1 ref variable semantics in the TFLite runtime.
bool is_variable = false;
for (auto& use : value->getUses()) {
for (auto& use : value.getUses()) {
is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
if (is_variable) {
break;
@ -669,6 +670,16 @@ Translator::CreateFlexBuilderWithNodeAttrs(
case ::tensorflow::AttrValue::kS:
flex_builder->String(key, attr.s());
break;
case ::tensorflow::AttrValue::kType: {
auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type());
if (status_or_tfl_type.ok()) {
flex_builder->Int(key, status_or_tfl_type.ValueOrDie());
} else {
emitWarning(loc, "ignoring unsupported tensorflow type: ")
<< std::to_string(attr.type());
}
break;
}
case ::tensorflow::AttrValue::kI:
flex_builder->Int(key, attr.i());
break;
@ -906,13 +917,13 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
bool has_input_attr = false;
InitializeNamesFromAttribute(fn, &has_input_attr);
std::vector<BufferOffset<tflite::Tensor>> tensors;
llvm::DenseMap<Value*, int> tensor_index_map;
llvm::DenseMap<Value, int> tensor_index_map;
// Builds tensor and buffer for argument or operation result. Returns false
// on failure.
auto build_tensor_and_buffer = [&](Value* value, const std::string& name) {
auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
// NoneType represents optional and may be skipped here.
if (value->getType().isa<NoneType>()) {
if (value.getType().isa<NoneType>()) {
return true;
}
@ -925,7 +936,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
// make the Buffer empty apart from setting the buffer_idx=0 in the Tensor.
// This does not seem to affect runtime behavior for RNN/LSTM, but would be
// good for reducing memory footprint.
if (auto* inst = value->getDefiningOp()) {
if (auto* inst = value.getDefiningOp()) {
auto buffer_or = BuildBuffer(inst);
if (!buffer_or) return false;
buffers_.push_back(*buffer_or);
@ -942,7 +953,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
// have associated tensor and buffer. Build FlatBuffer tensor and buffer for
// other functions.
for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
mlir::BlockArgument* arg = bb.getArgument(i);
mlir::BlockArgument arg = bb.getArgument(i);
std::string name;
if (has_input_attr) name = name_mapper_.GetUniqueName(arg);
if (name.empty()) name = absl::StrCat("arg", i);
@ -964,15 +975,15 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
// Fetch operand and result tensor indices.
std::vector<int32_t> operands;
operands.reserve(inst.getNumOperands());
for (auto* operand : inst.getOperands()) {
if (operand->getType().isa<NoneType>())
for (auto operand : inst.getOperands()) {
if (operand.getType().isa<NoneType>())
operands.push_back(kTfLiteOptionalTensor);
else
operands.push_back(tensor_index_map.lookup(operand));
}
std::vector<int32_t> results;
results.reserve(inst.getNumResults());
for (auto* result : inst.getResults()) {
for (auto result : inst.getResults()) {
results.push_back(tensor_index_map.lookup(result));
}
@ -986,10 +997,10 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
// Get input and output tensor indices for the subgraph.
std::vector<int32_t> inputs, outputs;
for (auto* arg : bb.getArguments()) {
for (auto arg : bb.getArguments()) {
inputs.push_back(tensor_index_map[arg]);
}
for (auto* result : bb.getTerminator()->getOperands()) {
for (auto result : bb.getTerminator()->getOperands()) {
outputs.push_back(tensor_index_map[result]);
}

View File

@ -18,14 +18,15 @@ limitations under the License.
#include <string>
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
namespace tflite {
// Translates the given MLIR `module` into a FlatBuffer and stores the
// serialized flatbuffer into the string. This uses OpLocNameMapper to convert
// location of the op to name in flatbuffer.
// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to
// convert the location of an op to a name in the flatbuffer. Returns true if
// translation fails, otherwise returns false.
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops,

View File

@ -25,17 +25,17 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
@ -301,14 +301,14 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
return {};
}
void buildComparisonBinOp(Builder *builder, OperationState &result, Value *lhs,
Value *rhs) {
void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
Value rhs) {
auto result_type =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!result_type)
emitError(result.location)
<< "non-broadcastable operands: " << lhs->getType() << " and "
<< rhs->getType();
<< "non-broadcastable operands: " << lhs.getType() << " and "
<< rhs.getType();
result.addOperands({lhs, rhs});
// Comparison binary ops always return i1 tensor.
if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
@ -321,15 +321,15 @@ void buildComparisonBinOp(Builder *builder, OperationState &result, Value *lhs,
}
void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
Value *lhs, Value *rhs,
Value lhs, Value rhs,
StringAttr fused_activation_function) {
auto result_type =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!result_type)
emitError(result.location)
<< "non-broadcastable operands: " << lhs->getType() << " and "
<< rhs->getType();
<< "non-broadcastable operands: " << lhs.getType() << " and "
<< rhs.getType();
result.addOperands({lhs, rhs});
result.addAttribute("fused_activation_function", fused_activation_function);
@ -358,7 +358,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
namespace {
int64_t GetConcatenationOpAxis(ConcatenationOp op) {
auto output_type = op.output()->getType().cast<RankedTensorType>();
auto output_type = op.output().getType().cast<RankedTensorType>();
int64_t axis = op.axis().getSExtValue();
if (axis < 0) axis += output_type.getRank();
return axis;
@ -452,7 +452,7 @@ LogicalResult VerifyConcatenationOpTypes(Operation *op,
}
LogicalResult Verify(ConcatenationOp op) {
auto output_type = op.output()->getType().dyn_cast<RankedTensorType>();
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
// If the output type is unranked, there is nothing else to be verified.
if (!output_type) return success();
@ -462,8 +462,8 @@ LogicalResult Verify(ConcatenationOp op) {
return op.emitOpError("concatenation dimension must be in [-rank, rank)");
SmallVector<TensorType, 4> operand_types;
for (Value *operand : op.values())
operand_types.push_back(operand->getType().cast<TensorType>());
for (Value operand : op.values())
operand_types.push_back(operand.getType().cast<TensorType>());
return VerifyConcatenationOpTypes(op.getOperation(), output_type,
operand_types, axis);
@ -520,7 +520,7 @@ DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
if (fused_activation_function() == "NONE") {
if (auto output_type = output()->getType().dyn_cast<RankedTensorType>()) {
if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
const int64_t axis = GetConcatenationOpAxis(*this);
if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
return ConstFoldConcatenateOpDense(operands, output_type, axis);
@ -528,9 +528,9 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
}
// Remove all empty values.
SmallVector<Value *, 4> non_empty_values;
for (Value *value : this->values()) {
const auto shaped_type = value->getType().cast<ShapedType>();
SmallVector<Value, 4> non_empty_values;
for (Value value : this->values()) {
const auto shaped_type = value.getType().cast<ShapedType>();
if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
continue;
}
@ -559,8 +559,8 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
LogicalResult Verify(FullyConnectedOp op) {
ShapedType input_type = op.input()->getType().cast<ShapedType>();
ShapedType filter_type = op.filter()->getType().cast<ShapedType>();
ShapedType input_type = op.input().getType().cast<ShapedType>();
ShapedType filter_type = op.filter().getType().cast<ShapedType>();
if (filter_type.hasRank() && filter_type.getRank() != 2) {
return op.emitOpError("expect 2d filter, got ") << filter_type;
}
@ -582,7 +582,7 @@ LogicalResult Verify(FullyConnectedOp op) {
// format.
if (op.weights_format() == "DEFAULT") {
ShapedType output_type =
(*op.output().begin())->getType().cast<ShapedType>();
(*op.output().begin()).getType().cast<ShapedType>();
if (!output_type.hasStaticShape()) {
return mlir::success();
}
@ -609,9 +609,9 @@ LogicalResult Verify(FullyConnectedOp op) {
//===----------------------------------------------------------------------===//
static void BuildGatherOp(Builder *builder, OperationState &result,
Value *params, Value *indices, IntegerAttr axis) {
auto params_type = params->getType().cast<TensorType>();
auto indices_type = indices->getType().cast<TensorType>();
Value params, Value indices, IntegerAttr axis) {
auto params_type = params.getType().cast<TensorType>();
auto indices_type = indices.getType().cast<TensorType>();
// If params/indices is unranked, then output is unranked.
if (!params_type.hasRank() || !indices_type.hasRank())
@ -704,8 +704,8 @@ static LogicalResult Verify(PackOp op) {
if (op.getOperation()->getNumOperands() != op.values_count())
return op.emitOpError("input count should match 'values_count' attribute");
Value *operand0 = op.getOperand(0);
auto input_type = operand0->getType().cast<ShapedType>();
Value operand0 = op.getOperand(0);
auto input_type = operand0.getType().cast<ShapedType>();
// Check axis bounds.
if (input_type.hasRank()) {
@ -717,8 +717,8 @@ static LogicalResult Verify(PackOp op) {
// Make sure all inputs have the same shape and element type.
// TODO(rahulsp): Simplify once b/135032064 is fixed.
for (Value *operand : op.getOperands()) {
auto other_type = operand->getType().cast<ShapedType>();
for (Value operand : op.getOperands()) {
auto other_type = operand.getType().cast<ShapedType>();
if (input_type != other_type)
return op.emitOpError("operands should be of the same type. got ")
<< input_type << ", " << other_type;
@ -732,9 +732,9 @@ static LogicalResult Verify(PackOp op) {
//===----------------------------------------------------------------------===//
static LogicalResult Verify(PReluOp op) {
auto input_type = op.input()->getType().cast<ShapedType>();
auto alpha_type = op.alpha()->getType().cast<ShapedType>();
auto output_type = op.output()->getType().cast<ShapedType>();
auto input_type = op.input().getType().cast<ShapedType>();
auto alpha_type = op.alpha().getType().cast<ShapedType>();
auto output_type = op.output().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
if (input_type.getRank() != alpha_type.getRank() + 1) {
@ -783,13 +783,13 @@ struct RemoveAdjacentReshape : public RewritePattern {
PatternMatchResult match(Operation *op) const override {
auto thisOp = cast<ReshapeOp>(op);
auto prevOp = thisOp.getOperand(0)->getDefiningOp();
auto prevOp = thisOp.getOperand(0).getDefiningOp();
return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
auto thisOp = cast<ReshapeOp>(op);
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0)->getDefiningOp());
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
// Replace
// %1 = "tfl.reshape"(%0, %shape0)
@ -807,7 +807,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
// Remove identity reshape with both static result and input shape.
auto result_type = getType().cast<ShapedType>();
auto input_type = getOperand(0)->getType().cast<ShapedType>();
auto input_type = getOperand(0).getType().cast<ShapedType>();
if (result_type.hasStaticShape() && result_type == input_type) {
return getOperand(0);
}
@ -865,7 +865,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TFL::PackOp pack_op = cast<TFL::PackOp>(op);
Operation *first_input = pack_op.getOperand(0)->getDefiningOp();
Operation *first_input = pack_op.getOperand(0).getDefiningOp();
if (!first_input) return matchFailure();
auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
if (!input_unpack_op) return matchFailure();
@ -880,8 +880,8 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
return matchFailure();
for (auto input_output :
llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
Value *pack_input = std::get<0>(input_output);
Value *unpack_output = std::get<1>(input_output);
Value pack_input = std::get<0>(input_output);
Value unpack_output = std::get<1>(input_output);
// Make sure the ordering is the same for the pack op & unpack op.
if (pack_input != unpack_output) return matchFailure();
}
@ -905,9 +905,9 @@ void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SliceOp op) {
auto input_type = op.input()->getType().cast<ShapedType>();
auto begin_type = op.begin()->getType().cast<ShapedType>();
auto size_type = op.size()->getType().cast<ShapedType>();
auto input_type = op.input().getType().cast<ShapedType>();
auto begin_type = op.begin().getType().cast<ShapedType>();
auto size_type = op.size().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
size_type.hasStaticShape()) {
if (input_type.getRank() != begin_type.getNumElements()) {
@ -984,8 +984,8 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
// TopKOp
//===----------------------------------------------------------------------===//
static void BuildTopKOp(Builder *builder, OperationState &result, Value *input,
Value *k) {
static void BuildTopKOp(Builder *builder, OperationState &result, Value input,
Value k) {
// Output size is only known if k is a constant value. A negative dimension is
// considered dynamic so use -1 here if k is not a constant value.
int const_k = -1;
@ -995,7 +995,7 @@ static void BuildTopKOp(Builder *builder, OperationState &result, Value *input,
// TODO(jpienaar): This should use a helper function.
const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
auto val_type = input->getType().cast<TensorType>();
auto val_type = input.getType().cast<TensorType>();
// If value is unranked, then so is results.
if (!val_type.hasRank())
return TFL::TopKV2Op::build(
@ -1035,7 +1035,7 @@ struct DropFakeQuant : public RewritePattern {
// If all the users of this op have valid "minmax" attributes, it is matched
// and can be removed.
auto fakeQuantOp = cast<FakeQuantOp>(op);
for (auto *operand : fakeQuantOp.getResult()->getUsers())
for (auto *operand : fakeQuantOp.getResult().getUsers())
if (!HasValidMinMaxAttribute(operand)) return matchFailure();
return matchSuccess();
@ -1075,7 +1075,7 @@ static LogicalResult Verify(UnpackOp op) {
// Extracts and returns the signed integer constant in a 0-rank integer tensor
// or 1-element 1-rank integer tensor if 'value' is a constant.
static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value *value) {
static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value value) {
ElementsAttr attr;
if (!matchPattern(value, m_Constant(&attr))) return {};
if (attr.getNumElements() != 1) return {};
@ -1101,8 +1101,8 @@ static LogicalResult VerifySplitOpOutputTypes(
ExpectedOutputTypeGetter get_expected_output_type) {
for (int64_t i = 0; i < num_splits; ++i) {
auto expected_output_type = get_expected_output_type(i);
Value *output = op->getResult(i);
auto output_type = output->getType().dyn_cast<RankedTensorType>();
Value output = op->getResult(i);
auto output_type = output.getType().dyn_cast<RankedTensorType>();
if (!output_type || output_type != expected_output_type)
return op->emitOpError()
<< "output #" << i << " should be " << expected_output_type;
@ -1121,7 +1121,7 @@ static LogicalResult Verify(SplitOp op) {
if (!split_dim_opt) return success();
// If 'input' is not a ranked tensor, there are no other checks.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t split_dim = split_dim_opt.getValue();
@ -1157,7 +1157,7 @@ static LogicalResult Verify(SplitVOp op) {
if (!split_dim_opt) return success();
// If 'input' is not a ranked tensor, there are no other checks.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t split_dim = split_dim_opt.getValue();
@ -1177,8 +1177,7 @@ static LogicalResult Verify(SplitVOp op) {
return success();
if (size_splits_attr.getNumElements() != num_splits) {
auto size_splits_type =
op.size_splits()->getType().cast<RankedTensorType>();
auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
RankedTensorType expected_size_splits_type =
RankedTensorType::get({num_splits}, size_splits_type.getElementType());
return op.emitOpError("'size_splits' should be ")
@ -1414,7 +1413,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
}
// Also fold if `input` has a known rank.
auto input_type = input()->getType().cast<ShapedType>();
auto input_type = input().getType().cast<ShapedType>();
// Do not fold if rank is zero because the TFLite converter doesn't
// distinguish between unranked input and scalar input due to b/138865275.
// TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
@ -1438,6 +1437,56 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
return value();
}
//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//
static void BuildSelectV2Op(Builder *builder, OperationState &result,
Value cond, Value x, Value y) {
auto operand_type =
OpTrait::util::getBroadcastedType(x.getType(), y.getType());
if (!operand_type)
emitError(result.location) << "non-broadcastable operands: " << x.getType()
<< " and " << y.getType();
bool has_static_cond_shape = false;
bool has_static_operand_shape = false;
ArrayRef<int64_t> cond_shape;
ArrayRef<int64_t> operand_shape;
if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
if (shaped_type.hasStaticShape()) {
has_static_cond_shape = true;
cond_shape = shaped_type.getShape();
}
}
if (auto shaped_type = operand_type.dyn_cast<ShapedType>()) {
if (shaped_type.hasStaticShape()) {
has_static_operand_shape = true;
operand_shape = shaped_type.getShape();
}
}
SmallVector<int64_t, 4> broadcastedShape;
if (has_static_cond_shape && has_static_operand_shape &&
!OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
broadcastedShape)) {
emitError(result.location) << "non-broadcastable operands: " << operand_type
<< " and " << cond.getType();
}
result.addOperands({cond, x, y});
auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
if (has_static_cond_shape && has_static_operand_shape) {
result.types.push_back(
RankedTensorType::get(broadcastedShape, elementType));
} else {
result.types.push_back(UnrankedTensorType::get(elementType));
}
}
//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//
@ -1521,9 +1570,8 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TransposeConvOp op) {
ShapedType output_type = op.output()->getType().cast<ShapedType>();
ShapedType output_shape_type =
op.output_shape()->getType().cast<ShapedType>();
ShapedType output_type = op.output().getType().cast<ShapedType>();
ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
if (output_type.getRank() != output_shape_type.getDimSize(0)) {
return op.emitOpError(llvm::formatv(
@ -1629,9 +1677,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
}
static LogicalResult Verify(TransposeOp op) {
auto input_type = op.x()->getType().cast<ShapedType>();
auto perm_type = op.perm()->getType().cast<ShapedType>();
auto output_type = op.y()->getType().cast<ShapedType>();
auto input_type = op.x().getType().cast<ShapedType>();
auto perm_type = op.perm().getType().cast<ShapedType>();
auto output_type = op.y().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
if (perm_type.getNumElements() != input_type.getRank()) {
return op.emitOpError(

View File

@ -18,15 +18,15 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/Traits.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Dialect.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

View File

@ -135,7 +135,7 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
//===----------------------------------------------------------------------===//
class TFL_OperandIsUnrankedPred<int n> :
CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;
CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
// TODO: Some of these could be generalized and/or moved to a more general
// location.
@ -144,38 +144,38 @@ class TFL_OperandHasRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() == " # m>]>>;
").getType().cast<ShapedType>().getRank() == " # m>]>>;
// Returns true if the n-th operand is ranked and has rank dim.
class TFL_OperandHasKnownRank<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() == "
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() == "
# dim>]>;
// True if operand n is ranked and has a rank > dim.
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
# dim>]>;
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
TFL_OperandIsRankedAndHasDimPred<n, dim>,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
".getShape()[" # dim # " ] == " # size>]>;
// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
Or<[CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() >= " # m>]>>;
").getType().cast<ShapedType>().getRank() >= " # m>]>>;
class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
CPred<"$_op.getOperand(" # x #
")->getType().cast<ShapedType>().getRank() == "
").getType().cast<ShapedType>().getRank() == "
"$_op.getOperand(" # y #
")->getType().cast<ShapedType>().getShape()[0]">>;
").getType().cast<ShapedType>().getShape()[0]">>;
class TFL_Operand0DOr1ElementTensor<int x> :
PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element",
@ -195,7 +195,7 @@ class TFL_OperandHasRankLessThan<int n, int m> :
PredOpTrait<"operand " # n # " is maximum " # m # "-D",
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() <= " # m>]>>;
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
// This is a quantization-aware version of TCresVTEtIsSameAsOp
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
@ -224,10 +224,10 @@ def BinaryOpSameElementTypeConstraint :
//===----------------------------------------------------------------------===//
def TFL_BroadcastableBinaryBuilder : OpBuilder<
"Builder *builder, OperationState &result, Value *lhs, Value *rhs",
"Builder *builder, OperationState &result, Value lhs, Value rhs",
[{
auto resultType =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!resultType)
mlir::emitError(result.location, "non-broadcastable operands");
result.addOperands({lhs, rhs});
@ -235,7 +235,7 @@ def TFL_BroadcastableBinaryBuilder : OpBuilder<
}]>;
def TFL_FusedBroadcastableBinaryBuilder : OpBuilder<
"Builder *builder, OperationState &result, Value *lhs, Value *rhs, "
"Builder *builder, OperationState &result, Value lhs, Value rhs, "
"StringAttr fusedActivationFunction",
[{
buildFusedBroadcastableBinOp(
@ -243,7 +243,7 @@ def TFL_FusedBroadcastableBinaryBuilder : OpBuilder<
}]>;
def TFL_ComparisonBinaryBuilder : OpBuilder<
"Builder *builder, OperationState &result, Value *lhs, Value *rhs",
"Builder *builder, OperationState &result, Value lhs, Value rhs",
[{
buildComparisonBinOp(builder, result, lhs, rhs);
}]>;
@ -427,6 +427,33 @@ def TFL_TransposeConvOp:
let verifier = [{ return Verify(*this); }];
}
def TFL_Convolution2DTransposeBiasOp :
Op<TFL_Dialect, "convolution_2d_transpose_bias", [NoSideEffect]> {
let summary = " Transpose convolution with bias operator";
let description = [{
Performs a transpose convolution operation on the inputs,
with the option of adding a bias.
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the filter weight tensor
`inputs[2]`: optional: the bias tensor
}];
let arguments = (
ins AnyTensor:$input,
AnyTensor:$filter,
TFL_TensorOfOrNone<[AnyType]>:$bias,
TFL_PaddingAttr:$padding,
I32Attr:$stride_h,
I32Attr:$stride_w
);
let results = (outs AnyTensor:$output);
}
def TFL_AveragePool2DOp:
TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Average_pool_2d operator";
@ -471,7 +498,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
let hasOptions = 1;
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType().
return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
@ -500,7 +527,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
let hasOptions = 1;
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType().
return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
@ -669,7 +696,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
let builders =
[
OpBuilder<"Builder *builder, OperationState &result, "
"Value *params, Value *indices, IntegerAttr axis",
"Value params, Value indices, IntegerAttr axis",
[{ BuildGatherOp(builder, result, params, indices, axis); }]>
];
@ -932,7 +959,7 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
let builders =
[
OpBuilder<
"Builder *builder, OperationState &result, Value *lhs, Value *rhs",
"Builder *builder, OperationState &result, Value lhs, Value rhs",
[{
buildComparisonBinOp(builder, result, lhs, rhs);
}]>
@ -1427,6 +1454,63 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
let customOption = "Pool2DOptions";
}
def TFL_MaxPoolingWithArgMax2DOp :
Op<TFL_Dialect, "max_pooling_with_argmax_2d", [NoSideEffect]> {
let summary = "Max Pool 2D with argmax op";
let description = [{
Performs max pooling on the input and outputs both max values and indices.
Each index is a flattened index within a sub-array of "filter_w" x "filter_h"
size (see the worked example after this definition).
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
}];
let arguments = (
ins AnyTensor:$input,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_w,
I32Attr:$filter_h
);
let results = (outs
AnyTensor:$value,
AnyTensor:$indices
);
}
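// Worked example for the argmax output (illustrative, assumed sizes): with
// filter_w = 2 and filter_h = 2, each element of `indices` is the row-major
// position of the maximum inside its own 2x2 pooling window, so every index
// value lies in [0, 3].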
def TFL_MaxUnpooling2DOp :
Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect]> {
let summary = "Max Unpool 2D";
let description = [{
Performs a max unpooling operation.
To some extent this is the reverse operation of max pooling:
the elements in the input activation tensor are stored into the positions
specified by the input indices.
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the input indices
}];
let arguments = (
ins AnyTensor:$input,
AnyTensor:$indices,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_w,
I32Attr:$filter_h
);
let results = (outs AnyTensor:$outputs);
}
def TFL_MaximumOp : TFL_Op<"maximum", [
Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
@ -1996,7 +2080,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
let results = (outs AnyTensor:$output);
DerivedTypeAttr out_type = DerivedTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType();
return getResult().getType().cast<TensorType>().getElementType();
}]>;
let hasOptions = 1;
@ -2081,9 +2165,9 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
// TODO(jpienaar): autogenerate this.
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
"Value *condition, Value *x, Value *y",
"Value condition, Value x, Value y",
[{
auto resultType = x->getType();
auto resultType = x.getType();
result.addOperands({condition, x, y});
result.types.push_back(resultType);
}]>];
@ -2091,6 +2175,32 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
let hasOptions = 1;
}
def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> {
let summary = "SelectV2 operator";
let description = [{
Selects values of 'x' if the corresponding value of 'condition' is true, or
the value of 'y' if false. The valid condition input shapes are:
1. Either the same shape (in which case the select is elementwise), or
2. Broadcastable shapes between 'condition', 'x' and 'y'
(see the shape example after this definition).
}];
let arguments = (ins
TFL_BoolTensor:$condition,
TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x,
TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y);
let results = (outs AnyTensor:$output);
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
"Value cond, Value x, Value y",
[{
BuildSelectV2Op(builder, result, cond, x, y);
}]>];
let hasOptions = 1;
}
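// Illustrative shape example (assumed operand types): with condition of type
// tensor<2xi1> and x, y of type tensor<3x2xf32>, the operands broadcast to
// 3x2 and the inferred result type is tensor<3x2xf32>, as computed by
// BuildSelectV2Op.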
def TFL_SinOp: TFL_Op<"sin", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
let summary = "Sine operator";
@ -2277,7 +2387,7 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
I32Tensor:$indices);
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
"Value *input, Value *k",
"Value input, Value k",
[{ BuildTopKOp(builder, result, input, k); }]>];
let hasOptions = 1;
@ -2333,14 +2443,14 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
}];
let arguments = (ins
TensorOf<[F32, I8, I32, QI8, QUI8]>:$input,
TensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$input,
I32Attr:$num,
I32Attr:$axis
);
let results = (outs
Variadic<TensorOf<[F32, I8, I32, QI8, QUI8]>>:$outputs
Variadic<TensorOf<[F32, I1, I8, I32, QI8, QUI8]>>:$outputs
);
let verifier = [{ return Verify(*this); }];
@ -2707,7 +2817,7 @@ in the unique output `y`. In other words:
);
DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
return getResult(1)->getType().cast<TensorType>().getElementType().
return getResult(1).getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;

View File

@ -19,7 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:llvm-project
namespace mlir {
namespace OpTrait {

View File

@ -30,10 +30,10 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Parser.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Parser.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
#include "tensorflow/core/platform/init_main.h"

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h" // TF:local_config_mlir
#include "mlir/TableGen/Attribute.h" // TF:llvm-project
using llvm::DefInit;
using llvm::dyn_cast;

View File

@ -28,10 +28,10 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:Support",
"@local_config_mlir//:ViewOpGraph",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:ViewOpGraph",
],
)

View File

@ -19,11 +19,11 @@ limitations under the License.
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/Transforms/ViewOpGraph.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
@ -151,10 +151,9 @@ Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
return errors::InvalidArgument("fail to parse extra OpDef");
}
// Make sure the op is not already registered. If it is registered, continue.
const OpRegistrationData* op_reg = nullptr;
auto status =
tensorflow::OpRegistry::Global()->LookUp(opdef.name(), &op_reg);
if (status.ok()) continue;
const OpRegistrationData* op_reg =
tensorflow::OpRegistry::Global()->LookUp(opdef.name());
if (op_reg) continue;
tensorflow::OpRegistry::Global()->Register(
[opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
@ -278,7 +277,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.

View File

@ -13,7 +13,7 @@ package(
package_group(
name = "friends",
includes = ["@local_config_mlir//:subpackages"],
includes = ["//third_party/mlir:subpackages"],
packages = ["//tensorflow/compiler/mlir/..."],
)
@ -26,8 +26,8 @@ filegroup(
name = "quantization_td_files",
srcs = [
"quantization.td",
"@local_config_mlir//:OpBaseTdFiles",
"@local_config_mlir//:QuantizationOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:QuantizationOpsTdFiles",
],
)
@ -53,13 +53,13 @@ cc_library(
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
@ -75,11 +75,11 @@ cc_library(
],
deps = [
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(fengliuai): remove this dependence.
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/core:lib_proto_parsing",
@ -97,7 +97,7 @@ cc_library(
deps = [
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)
@ -107,8 +107,8 @@ tf_native_cc_binary(
"tools/op_quant_spec_getters_gen.cc",
],
deps = [
"@llvm//:support",
"@llvm//:tablegen",
"@local_config_mlir//:TableGen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//mlir:TableGen",
],
)

View File

@ -23,18 +23,18 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/AffineExpr.h" // TF:local_config_mlir
#include "mlir/IR/AffineMap.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
#include "mlir/IR/AffineMap.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
@ -70,16 +70,16 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
const QuantParamsEntry &info);
void InsertStatsOpAtResult(OpBuilder b, Value *res, ElementsAttr layer_stats,
void InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats,
ElementsAttr axis_stats, IntegerAttr axis);
// If the index is out of range, this method returns false. Otherwise it
// returns true if the value is a float tensor.
bool IsQuantizableResult(Operation *op, int index) {
if (index < 0 || index >= op->getNumResults()) return false;
Value *res = op->getResult(index);
return res->getType().isa<ShapedType>() &&
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
Value res = op->getResult(index);
return res.getType().isa<ShapedType>() &&
res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
}
// A method to retrieve the name for the given op.
@ -117,13 +117,13 @@ bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) {
return false;
}
void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value *res,
void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
ElementsAttr layer_stats,
ElementsAttr axis_stats,
IntegerAttr axis) {
auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
layer_stats, axis_stats, axis);
res->replaceAllUsesWith(stats_op);
res.replaceAllUsesWith(stats_op);
stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
}

View File

@ -9,7 +9,7 @@ package(
package_group(
name = "friends",
includes = ["@local_config_mlir//:subpackages"],
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//tensorflow/lite/...",
@ -36,9 +36,9 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
)
@ -53,6 +53,6 @@ tf_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)

View File

@ -17,11 +17,11 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"

View File

@ -23,17 +23,17 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
@ -146,14 +146,14 @@ class QuantizationDriver {
// Adds all the users of index-th result of op to the work list.
void AddUserToList(Operation *op, int index) {
for (auto *user : op->getResult(index)->getUsers()) {
for (auto *user : op->getResult(index).getUsers()) {
work_list_.push_back(user);
}
}
// Adds the defining op of index-th operand of op to the work list.
void AddOperandToList(Operation *op, int index) {
if (auto *inst = op->getOperand(index)->getDefiningOp()) {
if (auto *inst = op->getOperand(index).getDefiningOp()) {
work_list_.push_back(inst);
}
}
@ -183,20 +183,20 @@ class QuantizationDriver {
// of the op.
void QuantizeOpResult(Operation *op, int index, QuantParams params);
void QuantizeArg(BlockArgument *arg, QuantParams params);
void QuantizeArg(BlockArgument arg, QuantParams params);
// Inserts the Quantize and Dequantize ops to quantize the value and returns
// the Quantize op.
void QuantizeValue(Value *value, QuantParams params, Location loc);
void QuantizeValue(Value value, QuantParams params, Location loc);
// Inserts the Quantize ops for requantizing the index-th result of the op.
void RequantizeOpResult(Operation *op, int index, RequantizeState *state);
void RequantizeArg(BlockArgument *arg, RequantizeState *state);
void RequantizeArg(BlockArgument arg, RequantizeState *state);
// Inserts the Quantize and Dequantize ops to quantize the value and returns
// the Quantize op.
void RequantizeValue(Value *value, RequantizeState *state, Location loc);
void RequantizeValue(Value value, RequantizeState *state, Location loc);
// A heuristic to get a quantization parameter that satisfies the same scale
// constraints for the op. Returns an empty option if this quantization
@ -213,7 +213,7 @@ class QuantizationDriver {
return states_[result_states_[{op, index}]];
}
QuantState &GetArgQuantState(BlockArgument *arg) {
QuantState &GetArgQuantState(BlockArgument arg) {
return states_[arg_states_[arg]];
}
@ -227,7 +227,7 @@ class QuantizationDriver {
return rescale_states_[result_states_[{op, index}]];
}
RequantizeState &GetArgRequantizeState(BlockArgument *arg) {
RequantizeState &GetArgRequantizeState(BlockArgument arg) {
return rescale_states_[arg_states_[arg]];
}
@ -235,32 +235,45 @@ class QuantizationDriver {
// `as_result` is true or index-th operand if `as_result` is false. The state
// is immutable if the type is a quantized type. Returns the index of this
// new state in the state vector.
int InitializeState(Operation *op, int index, Value *val, bool as_result);
int InitializeState(Operation *op, int index, Value val, bool as_result);
// Sets the state of an argument. If this value is cached, uses the cached
// result without creating a new entry in the state vector. Otherwise,
// allocates a new entry in the state vector.
void InitializeArgState(BlockArgument arg, Value in,
llvm::DenseMap<Value, int> *cache) {
auto cached = cache->insert({in, 0});
if (!cached.second) {
arg_states_[arg] = cached.first->second;
return;
}
QuantParams params =
quant::QuantizedType::getQuantizedElementType(in.getType());
bool immutable = !EmptyParams(params);
int next_state_index = states_.size();
states_.push_back({params, immutable});
arg_states_[arg] = next_state_index;
cached.first->second = next_state_index;
}
// Sets the state of the index-th operand of the op. If this operand is
// cached, uses the cached result without creating a new entry in the state
// vector. Otherwise, allocates a new entry in the state vector.
void InitializeOperandState(Operation *op, int index, Value *in,
llvm::DenseMap<Value *, int> *cache,
bool is_argument) {
void InitializeOperandState(Operation *op, int index, Value in,
llvm::DenseMap<Value, int> *cache) {
auto cached = cache->insert({in, 0});
if (!cached.second) {
operand_states_.insert({{op, index}, cached.first->second});
return;
}
cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
if (is_argument) {
auto *arg = llvm::cast<BlockArgument>(in);
arg_states_[arg] = cached.first->second;
args_.push_back(arg);
}
}
// Sets the state of the index-th result of the op. If this result is cached,
// uses the cached result without creating a new entry in the state vector.
// Otherwise, allocates a new entry in the state vector.
void InitializeResultState(Operation *op, int index, Value *res,
llvm::DenseMap<Value *, int> *cache) {
void InitializeResultState(Operation *op, int index, Value res,
llvm::DenseMap<Value, int> *cache) {
auto cached = cache->insert({res, 0});
if (!cached.second) {
result_states_.insert({{op, index}, cached.first->second});
@ -279,7 +292,8 @@ class QuantizationDriver {
// rest are weights.
llvm::DenseSet<Operation *> weights_;
// The weights require narrow_range quantization. If the value of this map is
// The weights require narrow_range quantization. This map collects all the
// weight operands defined by the op quant spec. If the value of the entry is
// positive, per-channel quantization is required.
llvm::DenseMap<Operation *, int> optimized_weights_;
@ -300,11 +314,11 @@ class QuantizationDriver {
// results and arguments.
llvm::DenseMap<OpValue, int> operand_states_;
llvm::DenseMap<OpValue, int> result_states_;
llvm::DenseMap<BlockArgument *, int> arg_states_;
llvm::DenseMap<BlockArgument, int> arg_states_;
// This vector is to preserve the arguments order, so the newly inserted
// quantized ops for the arguments are deterministically ordered.
llvm::SmallVector<BlockArgument *, 4> args_;
llvm::SmallVector<BlockArgument, 4> args_;
OpQuantSpecGetter op_quant_spec_getter_;
};
@ -321,10 +335,10 @@ bool QuantizationDriver::IsQuantized(Operation *op) {
return true;
}
int QuantizationDriver::InitializeState(Operation *op, int index, Value *val,
int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
bool as_result) {
QuantParams params =
quant::QuantizedType::getQuantizedElementType(val->getType());
quant::QuantizedType::getQuantizedElementType(val.getType());
bool immutable = !EmptyParams(params);
int next_state_index = states_.size();
states_.push_back({params, immutable});
@ -338,7 +352,7 @@ int QuantizationDriver::InitializeState(Operation *op, int index, Value *val,
bool QuantizationDriver::SetConstantResultParams(Operation *op) {
ElementsAttr attr;
Value *res = op->getResult(0);
Value res = op->getResult(0);
if (!matchPattern(res, m_Constant(&attr))) {
return false;
}
@ -362,7 +376,7 @@ bool QuantizationDriver::SetConstantResultParams(Operation *op) {
} else {
// per-tensor quantization weight
final_type = GetUniformQuantizedTypeForWeight(
attr, /*symmetric=*/is_weight_with_per_channel_support,
attr, /*symmetric=*/is_weight && is_signed_,
/*num_bits=*/8, is_signed_,
/*narrow_range_=*/is_weight);
}
@ -428,18 +442,18 @@ bool QuantizationDriver::SetOperandParams(Operation *op, int index,
void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
QuantParams params) {
builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
Value *original_result = op->getResult(index);
Value original_result = op->getResult(index);
QuantizeValue(original_result, params, op->getLoc());
}
void QuantizationDriver::QuantizeArg(BlockArgument *arg, QuantParams params) {
builder_.setInsertionPointToStart(arg->getOwner());
void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
builder_.setInsertionPointToStart(arg.getOwner());
QuantizeValue(arg, params, builder_.getUnknownLoc());
}
void QuantizationDriver::QuantizeValue(Value *value, QuantParams params,
void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
Location loc) {
Type expressed_type = value->getType();
Type expressed_type = value.getType();
Type new_type = params.castFromExpressedType(expressed_type);
// This value isn't an expressed type (float), skip.
if (!new_type) return;
@ -451,7 +465,7 @@ void QuantizationDriver::QuantizeValue(Value *value, QuantParams params,
quantize.output());
// `original_result` has a use to `quantize`, so this will replace that use
// by the result of `dequantize`. Remember to reset that use afterwards
value->replaceAllUsesWith(dequantize);
value.replaceAllUsesWith(dequantize);
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}
@ -459,9 +473,9 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
RequantizeState *state) {
if (state->pos == RequantizeState::NO_REQUANTIZE) return;
builder_.setInsertionPointAfter(op);
Value *value = op->getResult(index);
Value value = op->getResult(index);
if (state->pos == RequantizeState::ON_OUTPUT) {
Operation *user = value->getUses().begin().getUser();
Operation *user = value.getUses().begin().getUser();
if (llvm::isa<TFL::QuantizeOp>(user)) {
// The requantize op is inserted between `quantize` and `dequantize` ops.
value = user->getResult(0);
@ -471,31 +485,31 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
RequantizeValue(value, state, op->getLoc());
}
void QuantizationDriver::RequantizeArg(BlockArgument *arg,
void QuantizationDriver::RequantizeArg(BlockArgument arg,
RequantizeState *state) {
Value *value = arg;
builder_.setInsertionPointToStart(arg->getOwner());
if (value->hasOneUse()) {
auto user = value->use_begin().getUser();
Value value = arg;
builder_.setInsertionPointToStart(arg.getOwner());
if (value.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
builder_.setInsertionPoint(arg->getOwner(), ++Block::iterator(user));
builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
}
}
RequantizeValue(value, state, builder_.getUnknownLoc());
}
void QuantizationDriver::RequantizeValue(Value *value, RequantizeState *state,
void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
Location loc) {
Type new_type;
if (state->pos == RequantizeState::ON_INPUT) {
Type expressed_type = value->getType();
Type expressed_type = value.getType();
// The value needs to be requantized. A Quantize op will be created to use
// it as the operand and replace its uses.
new_type = state->params.castFromExpressedType(expressed_type);
} else {
Type expressed_type =
quant::QuantizedType::castToExpressedType(value->getType());
quant::QuantizedType::castToExpressedType(value.getType());
if (!expressed_type) return;
// The value needs to be requantized. A Quantize op will be created to use
@ -508,7 +522,7 @@ void QuantizationDriver::RequantizeValue(Value *value, RequantizeState *state,
TypeAttr type_attr = TypeAttr::get(new_type);
auto requantize_op =
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
value->replaceAllUsesWith(requantize_op);
value.replaceAllUsesWith(requantize_op);
requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}
@ -586,10 +600,10 @@ void QuantizationDriver::PreprocessConstantOps() {
auto type = cst.getType().dyn_cast<ShapedType>();
if (!type || !type.getElementType().isa<FloatType>()) return;
Value *value = cst.getResult();
Value value = cst.getResult();
SmallVector<std::pair<Operation *, int>, 4> bias_users;
bool used_as_weight = false;
for (auto &use : value->getUses()) {
for (auto &use : value.getUses()) {
auto spec = GetQuantSpec(use.getOwner());
auto biases = spec->biases_params;
Operation *user = use.getOwner();
@ -629,7 +643,20 @@ void QuantizationDriver::PreprocessConstantOps() {
}
void QuantizationDriver::SetupAllStates() {
llvm::DenseMap<Value *, int> value_to_state;
llvm::DenseMap<Value, int> value_to_state;
for (auto arg : fn_.getArguments()) {
args_.push_back(arg);
Value value = arg;
// If the argument is quantized, it should only have one user.
if (arg.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
}
}
InitializeArgState(arg, value, &value_to_state);
}
fn_.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
@ -638,26 +665,24 @@ void QuantizationDriver::SetupAllStates() {
work_list_.push_back(op);
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
auto *operand = op->getOperand(i);
bool is_argument = true;
if (auto *inst = operand->getDefiningOp()) {
auto operand = op->getOperand(i);
if (auto *inst = operand.getDefiningOp()) {
// If the operand comes from a tfl.dequantize op, we use the quantized
// input of this tfl.dequantize op to set the state.
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
operand = dq.input();
}
is_argument = false;
}
InitializeOperandState(op, i, operand, &value_to_state, is_argument);
InitializeOperandState(op, i, operand, &value_to_state);
}
for (int res = 0, e = op->getNumResults(); res != e; ++res) {
auto *result = op->getResult(res);
Value result = op->getResult(res);
// If the result has been quantized, it should only be used by a
// tfl.quantize op. For this case, we use the quantized result to
// create the state and mark it immutable.
if (result->hasOneUse()) {
auto user = result->use_begin().getUser();
if (result.hasOneUse()) {
auto user = result.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
result = q.output();
}
@ -746,7 +771,7 @@ bool QuantizationDriver::PropagateParams() {
}
void QuantizationDriver::Finalize() {
for (auto *arg : args_) {
for (auto arg : args_) {
auto &state = GetArgQuantState(arg);
auto &requantize = GetArgRequantizeState(arg);
if (state.IsEmpty() ||

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
namespace mlir {
namespace quant {

View File

@ -18,8 +18,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
namespace mlir {
namespace OpTrait {
@ -70,7 +70,7 @@ class FixedResultUniformScale {
QuantizedType GetResultQuantizedType(int index) {
auto op = this->getOperation();
auto result_type =
op->getResult(index)->getType().template cast<TensorType>();
op->getResult(index).getType().template cast<TensorType>();
Builder builder(op->getContext());
IntegerType storage_type = builder.getIntegerType(BitWidth);
const double scale = static_cast<double>(ScaleMantissa) *

View File

@ -21,15 +21,15 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
namespace mlir {
@ -367,7 +367,7 @@ ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
static bool PreferResultScale(Operation* op) {
int float_operands = 0;
for (auto operand : op->getOperands()) {
if (auto operand_type = operand->getType().dyn_cast<ShapedType>()) {
if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
if (operand_type.getElementType().isa<FloatType>()) {
if (float_operands++ > 1) return true;
}
@ -400,22 +400,22 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
quant::StatisticsOp stats_op = all_stats_ops.back();
all_stats_ops.pop_back();
if (auto def = stats_op.arg()->getDefiningOp()) {
if (auto def = stats_op.arg().getDefiningOp()) {
if (IsStatsRedundant(def, op_quant_spec_getter)) {
redundant_stats_ops.insert(stats_op);
}
}
for (auto user : stats_op.getResult()->getUsers()) {
for (auto user : stats_op.getResult().getUsers()) {
// We don't propagate this parameter down if it has multiple operands.
// We want to use the result parameter scales instead.
if (user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
!PreferResultScale(user)) {
for (Value* res : user->getResults()) {
if (res->hasOneUse()) {
for (Value res : user->getResults()) {
if (res.hasOneUse()) {
if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
*res->getUsers().begin())) {
*res.getUsers().begin())) {
// quantization parameters can be propagated to next_stats
redundant_stats_ops.insert(next_stats);
// add next_stats to the work list so propagation can
@ -440,12 +440,12 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
quant::StatisticsOp stats_op = all_stats_ops.back();
all_stats_ops.pop_back();
if (auto def = stats_op.arg()->getDefiningOp()) {
if (auto def = stats_op.arg().getDefiningOp()) {
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
PreferResultScale(def)) {
for (auto input : def->getOperands()) {
if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
input->getDefiningOp())) {
input.getDefiningOp())) {
redundant_stats_ops.insert(next_stats);
all_stats_ops.push_back(next_stats);
}
@ -458,7 +458,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
for (auto it : redundant_stats_ops) {
if (!llvm::isa<quant::StatisticsOp>(it)) return true;
auto stats_op = llvm::cast<quant::StatisticsOp>(it);
stats_op.getResult()->replaceAllUsesWith(stats_op.arg());
stats_op.getResult().replaceAllUsesWith(stats_op.arg());
stats_op.erase();
}

View File

@ -23,18 +23,18 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir {
@ -116,7 +116,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
TypeAttr::get(result_type));
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
op.getResult()->replaceAllUsesWith(dq);
op.getResult().replaceAllUsesWith(dq);
q.getOperation()->replaceUsesOfWith(dq, op.arg());
op.erase();
@ -161,8 +161,8 @@ struct QuantizationPattern : public RewritePattern {
if (op->getNumResults() != 1) {
return matchFailure();
}
Value* quantized_value = op->getResult(0);
for (Operation* quantized_op : quantized_value->getUsers()) {
Value quantized_value = op->getResult(0);
for (Operation* quantized_op : quantized_value.getUsers()) {
// If it is requantize op, we shouldn't rewrite this op.
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
return matchFailure();
@ -176,17 +176,17 @@ struct QuantizationPattern : public RewritePattern {
// Collect all the quantized inputs and "clone" the matched op by these
// inputs.
SmallVector<Value*, 4> inputs;
SmallVector<Value, 4> inputs;
inputs.reserve(quantized_op->getNumOperands());
for (auto operand : quantized_op->getOperands()) {
Type operand_type = operand->getType();
Type operand_type = operand.getType();
if (operand_type.isa<NoneType>()) {
inputs.push_back(operand);
continue;
}
auto ele_type = operand->getType().cast<TensorType>().getElementType();
if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
auto ele_type = operand.getType().cast<TensorType>().getElementType();
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
inputs.push_back(op_inst.input());
} else if (ele_type.isa<IntegerType>()) {
// If the operand is an integer tensor, then it doesn't require the
@ -201,13 +201,13 @@ struct QuantizationPattern : public RewritePattern {
// Collect all the quantized outputs and replace them by the results of
// the new quantized op.
llvm::SmallDenseMap<Value*, int> outputs_replaced;
llvm::SmallDenseMap<Value, int> outputs_replaced;
SmallVector<Type, 4> output_types;
output_types.reserve(quantized_op->getNumResults());
for (auto enumerated_result :
llvm::enumerate(quantized_op->getResults())) {
Value* result = enumerated_result.value();
Type result_type = result->getType();
Value result = enumerated_result.value();
Type result_type = result.getType();
// Add this to the test coverage once we create test ops with none type
// results.
if (result_type.isa<NoneType>()) {
@ -216,20 +216,20 @@ struct QuantizationPattern : public RewritePattern {
continue;
}
Type result_ele_type =
result->getType().cast<TensorType>().getElementType();
result.getType().cast<TensorType>().getElementType();
// If the user is the Quantize op, it must be the only user.
if (result->hasOneUse() && llvm::isa<Q>(*result->user_begin())) {
auto user = llvm::cast<Q>(*result->user_begin());
if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
auto user = llvm::cast<Q>(*result.user_begin());
outputs_replaced.insert({user.output(), enumerated_result.index()});
output_types.push_back(user.getType());
} else if (result_ele_type.template isa<IntegerType>()) {
// If the result is an integer tensor, then it doesn't require the
// D op in the pattern.
outputs_replaced.insert({result, enumerated_result.index()});
output_types.push_back(result->getType());
output_types.push_back(result.getType());
} else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
outputs_replaced.insert({result, enumerated_result.index()});
output_types.push_back(result->getType());
output_types.push_back(result.getType());
} else {
return matchFailure();
}
@ -241,7 +241,7 @@ struct QuantizationPattern : public RewritePattern {
output_types, quantized_op->getAttrs());
Operation* new_op = rewriter.createOperation(new_state);
for (auto output : outputs_replaced) {
output.getFirst()->replaceAllUsesWith(
output.getFirst().replaceAllUsesWith(
new_op->getResult(output.getSecond()));
}
@ -252,7 +252,7 @@ struct QuantizationPattern : public RewritePattern {
// For constant operands, the floating-point constant is duplicated in
// case it is quantized.
for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
auto def = new_op->getOperand(i)->getDefiningOp();
auto def = new_op->getOperand(i).getDefiningOp();
if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
DenseFPElementsAttr attr;
if (!matchPattern(q.input(), m_Constant(&attr))) {
@ -265,7 +265,7 @@ struct QuantizationPattern : public RewritePattern {
for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
if (!quantized_op->getResult(i)
->getType()
.getType()
.cast<ShapedType>()
.getElementType()
.isa<FloatType>()) {
@ -283,13 +283,13 @@ struct QuantizationPattern : public RewritePattern {
// Find the Dequantize/Dequantize users of the new op results, and
// replace the usage. Then all the floating-point ops are connected.
// N.B. the return op will use this floating-point result.
for (auto user : new_op->getResult(i)->getUsers()) {
for (auto user : new_op->getResult(i).getUsers()) {
// Skip the Requantize op, and we know it has a single user.
if (llvm::isa<Q>(user)) {
user = *user->getResult(0)->getUsers().begin();
user = *user->getResult(0).getUsers().begin();
}
if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
dequantize.getResult()->replaceAllUsesWith(
dequantize.getResult().replaceAllUsesWith(
quantized_op->getResult(i));
}
}
@ -316,7 +316,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
PatternMatchResult matchAndRewrite(Q op,
PatternRewriter& rewriter) const override {
Type output_type = op.output()->getType();
Type output_type = op.output().getType();
auto qtype = QType::getQuantizedElementType(output_type);
if (!qtype || qtype.isSigned()) return this->matchFailure();

View File

@ -4,7 +4,7 @@ package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
@ -14,6 +14,6 @@ filegroup(
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm//:FileCheck",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Operator.h" // TF:local_config_mlir
#include "mlir/TableGen/Operator.h" // TF:llvm-project
using llvm::LessRecord;
using llvm::raw_ostream;
@ -36,7 +36,7 @@ using mlir::tblgen::Operator;
// NOLINTNEXTLINE
static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
llvm::Regex acc_uniform_trait_regex{"AccumulatorUniformScale<([0-9]*),"};
llvm::Regex coeff_index_trait_regex{"AffineOpCoefficient<([0-9]*),"};
llvm::Regex coeff_index_trait_regex{"AffineOpCoefficient<(-?[0-9]*),"};
llvm::Regex fixed_uniform_trait_regex{
"FixedResultUniformScale<([0-9]+).*(true|false)>"};
emitSourceFileHeader("Generated Ops Quant Spec Getters", os);

View File

@ -4,7 +4,7 @@ package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
@ -14,6 +14,6 @@ filegroup(
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm//:FileCheck",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -7,10 +7,12 @@ glob_lit_tests(
":debug_info_files",
":test_utilities",
],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = [
"pbtxt",
"py",
# TODO(fengliuai): reenable these tests after the fused loc is
# supported in the diagnostic handler.
# "py",
],
)
@ -31,8 +33,8 @@ filegroup(
":saved_model_error",
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm//:FileCheck",
"@llvm//:not",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

View File

@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function
import sys
from absl import app
import tensorflow.compat.v2 as tf

View File

@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function
import sys
from absl import app
import tensorflow.compat.v2 as tf

View File

@ -7,7 +7,7 @@ glob_lit_tests(
":quant_stats_files",
":test_utilities",
],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = [
"pbtxt",
],
@ -20,7 +20,7 @@ filegroup(
data = [
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm//:FileCheck",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -38,6 +38,6 @@ versions {
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %0 : tensor<*xi32>
# CHECK-NEXT: }

View File

@ -8,7 +8,7 @@ glob_lit_tests(
":extra_files",
":test_utilities",
],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = [
"mlir",
"cc",
@ -24,7 +24,7 @@ filegroup(
":importer_test_min_max",
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
"@llvm//:FileCheck",
"@llvm-project//llvm:FileCheck",
],
)
@ -51,7 +51,7 @@ tf_native_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)
@ -67,6 +67,6 @@ tf_native_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)

View File

@ -11,6 +11,8 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
%3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
// CHECK: %[[EXP:.*]] = "tfl.exp"
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
// tfl.neg should not be pruned
// CHECK: %[[NEG:.*]] = "tfl.neg"
%5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
return %5 : tensor<4xf32>

View File

@ -0,0 +1,19 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -output-arrays=mul,exp,div --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// Confirm graph pruning.
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
// CHECK: %[[MUL:.*]] = tfl.mul
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
// CHECK: %[[DIV:.*]] = tfl.div
%3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
// CHECK: %[[EXP:.*]] = "tfl.exp"
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
// tfl.neg should be pruned
// CHECK-NOT: "tfl.neg"
%5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
return %5 : tensor<4xf32>
}

View File

@ -521,21 +521,30 @@ func @select_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor
// CHECK: return
}
func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
func @select_v2_same_shape(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
return %0: tensor<8xf32>
// CHECK-LABEL: select_v2
// CHECK-LABEL: select_v2_same_shape
// CHECK: "tfl.select"(%arg0, %arg1, %arg2)
// CHECK: return
}
func @select_v2_multidim(%arg0: tensor<8xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
func @select_v2_multidim(%arg0: tensor<3xi1>, %arg1: tensor<8x3xf32>, %arg2: tensor<8x3xf32>) -> tensor<8x3xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<8x3xf32>, tensor<8x3xf32>) -> tensor<8x3xf32>
return %0: tensor<8x3xf32>
// CHECK-LABEL: select_v2_multidim
// CHECK: "tfl.select"(%arg0, %arg1, %arg2)
// CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2)
// CHECK: return
}
func @select_v2_broadcast(%arg0: tensor<4xi1>, %arg1: tensor<3x4xf32>, %arg2: tensor<8x3x4xf32>) -> tensor<8x3x4xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor<3x4xf32>, tensor<8x3x4xf32>) -> tensor<8x3x4xf32>
return %0: tensor<8x3x4xf32>
// CHECK-LABEL: select_v2_broadcast
// CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2)
// CHECK: return
}

View File

@ -4,7 +4,7 @@ licenses(["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
@ -15,7 +15,7 @@ filegroup(
data = [
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
"@llvm//:FileCheck",
"@llvm//:not",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

View File

@ -0,0 +1,40 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s
// CHECK: {
// CHECK: version: 3,
// CHECK: operator_codes: [ {
// CHECK: builtin_code: CUSTOM,
// CHECK: custom_code: "SomeOperation"
// CHECK: } ],
// CHECK: subgraphs: [ {
// CHECK: tensors: [ {
// CHECK: shape: [ ],
// CHECK: type: INT32,
// CHECK: buffer: 1,
// CHECK: name: "tf.SomeOperation",
// CHECK: quantization: {
// CHECK-EMPTY
// CHECK: }
// CHECK: } ],
// CHECK: inputs: [ ],
// CHECK: outputs: [ 0 ],
// CHECK: operators: [ {
// CHECK: inputs: [ ],
// CHECK: outputs: [ 0 ],
// CHECK: custom_options: [ 100, 116, 121, 112, 101, 0, 1, 7, 1, 1, 1, 2, 4, 2, 36, 1 ]
// CHECK: } ],
// CHECK: name: "main"
// CHECK: } ],
// CHECK: description: "MLIR Converted.",
// CHECK: buffers: [ {
// CHECK-EMPTY
// CHECK: }, {
// CHECK-EMPTY
// CHECK: } ]
// CHECK: }
func @main() -> tensor<*xi32> {
// Tests that the below type attribute is convertible into the corresponding custom option in flatbuffer.
%0 = "tf.SomeOperation"() {dtype = i32 } : () -> tensor<*xi32>
return %0 : tensor<*xi32>
}

View File

@ -518,6 +518,20 @@ func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform<i9:f32
// -----
func @testMaxPoolingWithArgMax2D(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
%0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
}
// -----
func @testMaxUnpooling2D(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
%0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
return %0 : tensor<1x8x8x128xf32>
}
// -----
// CHECK-LABEL: testLogistic
func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
^bb0(%arg0: tensor<1x2x3x4x5xbf16>):
@ -1942,6 +1956,13 @@ func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %ar
// -----
func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
return %0 : tensor<1x64x84x32xf32>
}
// -----
func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> {
// expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}}
%0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32>

View File

@ -140,22 +140,6 @@ func @fuseAddWithRelu6IntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1:
// CHECK-SAME: fused_activation_function = "RELU6"
}
// CHECK-LABEL: intermOpUsedTwice
func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
%cst = constant dense<1.5> : tensor<16xf32>
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %0, %1 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
// CHECK: %cst = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00,
// CHECK: %cst_0 = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00,
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
// CHECK: %1 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
// CHECK: return %0, %1
}
// CHECK-LABEL: @fuseMulIntoFullyConnected
func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
%cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
@ -167,8 +151,8 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
return %1 : tensor<4x2xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %[[RES]] : tensor<4x2xf32>
}
@ -233,8 +217,8 @@ func @fuseMulIntoFullyConnectedBroadcast(%arg0: tensor<1x3xf32>) -> tensor<1x2xf
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32>
return %1 : tensor<1x2xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"{{.*}} dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
// CHECK: %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: return %[[RES]] : tensor<1x2xf32>
}
@ -249,7 +233,7 @@ func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> te
return %1 : tensor<4x2xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"{{.*}} dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
// CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
// CHECK: return %[[RES]] : tensor<4x2xf32>
}
@ -631,3 +615,18 @@ func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor
// CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"}
// CHECK: return %[[RES]]
}
// CHECK-LABEL: NotfuseAddIntoConv2d_MultipleUsers
func @NotfuseAddIntoConv2d_MultipleUsers(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
%cst = constant dense<1.5> : tensor<16xf32>
%cst_1 = constant dense<3.5> : tensor<16xf32>
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
%0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
%2 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %1, %2 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
// CHECK: %[[tfl_conv2d:[0-9].*]] = "tfl.conv_2d"
// CHECK: tfl.add
// CHECK-NEXT: tfl.add
}

View File

@ -125,3 +125,21 @@ func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// PerTensor: %[[conv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[dq]]
}
// CHECK-LABEL: QuantizeFullyConnected
func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
%w = constant dense<127.0> : tensor<32x12xf32>
%b = constant dense<0.0> : tensor<32xf32>
%fc = "tfl.fully_connected"(%arg0, %w, %b) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<32x12xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
return %fc : tensor<1x112x112x32xf32>
// CHECK: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<32x12xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>} : (tensor<32x12xf32>) -> tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<32x12xf32>
// CHECK: "tfl.fully_connected"(%arg0, %[[dq]]
// PerTensor: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<32x12xf32>
// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>} : (tensor<32x12xf32>) -> tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<32x12xf32>
// PerTensor: "tfl.fully_connected"(%arg0, %[[dq]]
}

View File

@ -379,26 +379,26 @@ func @QuantizeConcatResToAllNoRequantize(tensor<1x2x!quant.uniform<u8:f32, 0.1:1
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHeCK: return %4 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: return %4 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}
// CHECK-LABEL: QuantizeConcatResToAllRequantize
func @QuantizeConcatResToAllRequantize(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1:128>> {
^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>):
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 2.0:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>
%0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>
%1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>) -> tensor<1x2xf32>
%2 = "tfl.concatenation"(%1, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 2.000000e+00:128>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 2.000000e+00:128>>
// CHECK %1 = "tfl.quantize"(%0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK %2 = "tfl.dequantize"(%1) : (tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<2xf32>
// CHECK %3 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK %4 = "tfl.dequantize"(%3) : (tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<2xf32>
// CHECK %5 = "tfl.concatenation"(%2, %4) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK %6 = "tfl.quantize"(%5) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK return %6 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>
// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}
// CHECK-LABEL: QuantizeConcatResToAllRequantizeArg
@ -409,13 +409,13 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform<u8:f32, 2.0:
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK %1 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK %2 = "tfl.dequantize"(%1) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK %3 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK %4 = "tfl.dequantize"(%3) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK %5 = "tfl.concatenation"(%2, %4) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK %6 = "tfl.quantize"(%5) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK return %6 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}
// CHECK-LABEL: RequantizeAlreadyQuantizedModel


@ -204,8 +204,9 @@ func @QuantizeConcatRequantize(tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>, tens
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
// CHECK: %[[cc:.*]] = "tfl.concatenation"(%arg0, %[[q]]) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q0]], %[[q1]]) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: return %[[cc]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}


@ -15,11 +15,11 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"


@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
namespace tensorflow {


@ -20,11 +20,11 @@ limitations under the License.
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
@ -103,7 +103,7 @@ static int PrintFunctionResultMapping(const std::string &result,
i = 0;
for (auto output : *subgraph->outputs()) {
print_buffer(*subgraph, i, output, [&](int i) {
return terminator ? terminator->getOperand(i)->getLoc() : unknown_loc;
return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
});
}
}


@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Parser.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Parser.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"


@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/stream_executor/lib/statusor.h"


@ -21,26 +21,26 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -188,10 +188,10 @@ struct OphintCompositeOp {
// This function will process the aggregated inputs based on different
// strategies like "first", "last", "stack".
std::map<int, Value*> GetAggregatedInputs(OpBuilder* builder) {
std::map<int, Value*> aggregated_inputs;
std::map<int, Value> GetAggregatedInputs(OpBuilder* builder) {
std::map<int, Value> aggregated_inputs;
for (const auto& kv : inputs) {
Value* op_input = nullptr;
Value op_input = nullptr;
const AggregatedOperand& operand = kv.second;
// Dealing with "stack" strategy:
// This breaks into two parts:
@ -203,9 +203,9 @@ struct OphintCompositeOp {
if (operand.ops.size() == 1) {
// If ops size is 1, it will be simply expanding dimensions at dim 0.
Operation* current_identity_op = operand.ops.begin()->second;
Value* input = current_identity_op->getOperand(0);
Value input = current_identity_op->getOperand(0);
RankedTensorType input_type =
input->getType().cast<RankedTensorType>();
input.getType().cast<RankedTensorType>();
// The Reshape will be {1, (original_shape)}
SmallVector<int64_t, 4> reshape_op_shape;
reshape_op_shape.push_back(1);
@ -234,21 +234,21 @@ struct OphintCompositeOp {
} else {
// Insert a pack op to pack all the inputs together.
std::vector<Value*> pack_input_operands;
std::vector<Value*> packed_input_consumers;
std::vector<Value> pack_input_operands;
std::vector<Value> packed_input_consumers;
for (int i = 0, e = operand.ops.size(); i < e; ++i) {
pack_input_operands.push_back(operand.ops.at(i)->getOperand(0));
packed_input_consumers.push_back(operand.ops.at(i)->getResult(0));
}
// Find the first op that consumes the last value of the aggregated
// inputs.
Operation* first_use = *(packed_input_consumers.back()->user_begin());
Operation* first_use = *(packed_input_consumers.back().user_begin());
// The pack reshape will be {N, (original_shape)}
SmallVector<int64_t, 4> pack_shape;
pack_shape.push_back(pack_input_operands.size());
RankedTensorType type = operand.ops.at(0)
->getResult(0)
->getType()
.getType()
.cast<RankedTensorType>();
for (const auto& dim : type.getShape()) {
pack_shape.push_back(dim);
@ -288,9 +288,9 @@ struct OphintCompositeOp {
const AggregatedOperand& operand = kv.second;
if (operand.aggregation == kStrategyStack) {
const int output_numer = operand.ops.size();
Value* first_output = operand.ops.at(0)->getOperand(0);
Value first_output = operand.ops.at(0)->getOperand(0);
RankedTensorType first_output_type =
first_output->getType().cast<RankedTensorType>();
first_output.getType().cast<RankedTensorType>();
// The aggregated output shape will be {N, original_shape}.
SmallVector<int64_t, 4> shape;
shape.push_back(output_numer);
@ -300,12 +300,12 @@ struct OphintCompositeOp {
aggregated_output_types[kv.first] =
RankedTensorType::get(shape, first_output_type.getElementType());
} else if (operand.aggregation == kStrategyLast) {
Value* last_output =
Value last_output =
operand.ops.at(operand.ops.size() - 1)->getOperand(0);
aggregated_output_types[kv.first] = last_output->getType();
aggregated_output_types[kv.first] = last_output.getType();
} else {
Value* first_output = operand.ops.at(0)->getOperand(0);
aggregated_output_types[kv.first] = first_output->getType();
Value first_output = operand.ops.at(0)->getOperand(0);
aggregated_output_types[kv.first] = first_output.getType();
}
}
return aggregated_output_types;
@ -329,7 +329,7 @@ struct OphintCompositeOp {
Operation* first_output = operand.ops.at(0);
Location insert_loc = first_output->getLoc();
SmallVector<Type, 4> unpack_output_types(
output_number, first_output->getOperand(0)->getType());
output_number, first_output->getOperand(0).getType());
builder->setInsertionPoint(first_output);
Operation* unpack_op = builder->create<TFL::UnpackOp>(
@ -404,7 +404,7 @@ void PreprocessTopoSortGraph(
// should only count as one.
llvm::DenseSet<Operation*> input_ops;
for (int i = 0; i < op.getNumOperands(); ++i) {
Operation* input_op = op.getOperand(i)->getDefiningOp();
Operation* input_op = op.getOperand(i).getDefiningOp();
if (input_op) input_ops.insert(input_op);
}
if (input_ops.empty()) {
@ -507,15 +507,15 @@ LogicalResult TopoSortOperations(OpBuilder* builder) {
Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
Operation* insert_before_op,
const std::map<int, Value*>& inputs,
const std::map<int, Value>& inputs,
const std::map<int, Type>& output_types,
OpBuilder* builder, ModuleOp* module_op) {
SmallVector<Type, 4> input_types;
SmallVector<Value*, 4> input_values;
SmallVector<Value, 4> input_values;
SmallVector<int, 4> input_indexes;
for (const auto& kv : inputs) {
Value* input = kv.second;
input_types.push_back(input->getType());
Value input = kv.second;
input_types.push_back(input.getType());
input_values.push_back(input);
input_indexes.push_back(kv.first);
}
@ -588,8 +588,8 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
llvm::DenseSet<Operation*> reachable_ops;
std::queue<Operation*> ops_queue;
for (auto& input_op : input_ops) {
for (Value* value : input_op->getOperands()) {
Operation* op = value->getDefiningOp();
for (Value value : input_op->getOperands()) {
Operation* op = value.getDefiningOp();
if (op != nullptr) ops_queue.push(op);
}
}
@ -598,8 +598,8 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
Operation* current_op = ops_queue.front();
ops_queue.pop();
reachable_ops.insert(current_op);
for (Value* value : current_op->getOperands()) {
Operation* upstream_op = value->getDefiningOp();
for (Value value : current_op->getOperands()) {
Operation* upstream_op = value.getDefiningOp();
// Not visited, put it into the queue.
if (upstream_op != nullptr &&
!llvm::is_contained(reachable_ops, upstream_op)) {
@ -625,7 +625,7 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
BfsForReachableOps(ophint_composite_op.GetAllOutputOps());
// Step 3, deal with inputs aggregation strategies.
const std::map<int, Value*>& aggregated_inputs =
const std::map<int, Value>& aggregated_inputs =
ophint_composite_op.GetAggregatedInputs(builder);
// Step 4, get aggregated output types.
@ -642,7 +642,7 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
aggregated_inputs, aggregated_output_types, builder, module_op);
for (const auto& kv : aggregated_inputs) {
Operation* op = kv.second->getDefiningOp();
Operation* op = kv.second.getDefiningOp();
if (op == nullptr) return failure();
op->moveBefore(fused_op);
}


@ -15,23 +15,23 @@ limitations under the License.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
@ -92,18 +92,18 @@ LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op,
if (call_op.getNumResults() != 1) return failure();
// Inputs is indexed at 0.
Value* input = call_op.getOperand(0);
Value input = call_op.getOperand(0);
// Input_weight is indexed at 1.
Value* weight = call_op.getOperand(1);
Value weight = call_op.getOperand(1);
// Recurrent_weight is indexed at 2.
Value* recurrent_weight = call_op.getOperand(2);
Value recurrent_weight = call_op.getOperand(2);
// Bias is indexed at 3.
Value* bias = call_op.getOperand(3);
Value bias = call_op.getOperand(3);
// Hidden_state is indexed at 4.
Value* hidden_state = call_op.getOperand(4);
Value hidden_state = call_op.getOperand(4);
// Build Output.
auto output_type = call_op.getResult(0)->getType();
auto output_type = call_op.getResult(0).getType();
// Currently, ophinted RNN only supports time_major = True.
const bool time_major = true;
@ -127,7 +127,7 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
auto input_index_attr = composite_func_op.getAttr(kTfLiteFunctionInputIndex)
.cast<ArrayAttr>()
.getValue();
llvm::DenseMap<int, Value*> fused_ops_index_to_call_op_args;
llvm::DenseMap<int, Value> fused_ops_index_to_call_op_args;
for (int i = 0; i < call_op.getNumOperands(); ++i) {
int input_index = input_index_attr[i].cast<IntegerAttr>().getInt();
@ -139,7 +139,7 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
// Some optional arguments may not be filled in, so we need to create an
// empty Value.
Value* none_value;
Value none_value;
if (call_op.getNumOperands() <
kUnidirectionalSequenceLSTMOpTotalIArgumentNum) {
builder->setInsertionPoint(call_op.getOperation());
@ -148,7 +148,7 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
}
// Prepare all operands for the UnidirectionalSequenceLSTMOp.
SmallVector<Value*, kUnidirectionalSequenceLSTMOpTotalIArgumentNum> operands;
SmallVector<Value, kUnidirectionalSequenceLSTMOpTotalIArgumentNum> operands;
for (int i = 0; i < kUnidirectionalSequenceLSTMOpTotalIArgumentNum; ++i) {
auto operand_it = fused_ops_index_to_call_op_args.find(i);
if (operand_it == fused_ops_index_to_call_op_args.end()) {
@ -169,12 +169,12 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
if (call_op.getNumResults() > 1) {
for (int i = 0; i < call_op.getNumResults() - 1; ++i) {
// This one should not be used.
Value* unused_output = call_op.getResult(i);
if (!unused_output->use_empty()) return failure();
Value unused_output = call_op.getResult(i);
if (!unused_output.use_empty()) return failure();
}
}
output_types.push_back(
call_op.getResult(call_op.getNumResults() - 1)->getType());
call_op.getResult(call_op.getNumResults() - 1).getType());
// Prepare attributes.
SmallVector<NamedAttribute, 4> attributes;
@ -206,11 +206,11 @@ LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name,
LogicalResult build_fused_op_result = BuildUnidirectionalSequenceLSTMOp(
composite_func_op, call_op, builder, &fused_op);
if (failed(build_fused_op_result)) return build_fused_op_result;
Value* call_output = call_op.getResult(call_op.getNumResults() - 1);
if (call_output->getType() != fused_op->getResult(0)->getType()) {
Value call_output = call_op.getResult(call_op.getNumResults() - 1);
if (call_output.getType() != fused_op->getResult(0).getType()) {
return failure();
}
call_output->replaceAllUsesWith(fused_op->getResult(0));
call_output.replaceAllUsesWith(fused_op->getResult(0));
} else { // If we support more fused op, we should add the conversion here.
return failure();
}


@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
// Use the tensor type information from $0 and convert min $1, max $2 and
// numBits $3 and narrowRange $4 to a QuantizedType.
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
"GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
// Converts an integer attribute $0 to 32-bit with builder.
def convertIntAttrTo32Bit : NativeCodeCall<
@ -49,6 +49,11 @@ def convertIntAttrTo32Bit : NativeCodeCall<
def ExtractSingleElementAsInteger : NativeCodeCall<
"ExtractSingleElementAsInteger($_self.cast<ElementsAttr>())">;
// Checks whether all inputs of the given operation have static shapes and share the same shape.
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have the same static input shapes">;
def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must not have the same static input shapes">;
//===----------------------------------------------------------------------===//
// Nullary ops patterns.
//===----------------------------------------------------------------------===//
@ -145,10 +150,9 @@ def : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
// TODO(jpienaar): this is not true for all selects, TF's select supports rank 0
// condition
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
def : Pat<(TF_SelectV2Op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>;
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>;
def : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>;
def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>;
def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
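For illustration only (not part of this change), here is a minimal sketch of how the two SelectV2 patterns above would apply, with the shapes assumed for the example: when every operand carries the same static shape the op lowers to tfl.select, otherwise to tfl.select_v2.
// Illustrative sketch, not one of the original test cases.
func @select_v2_same_static_shapes(%cond: tensor<2x3xi1>, %x: tensor<2x3xf32>, %y: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // All operands share the static shape 2x3, so HasSameStaticShapes holds and
  // the first pattern would rewrite this op to "tfl.select".
  %0 = "tf.SelectV2"(%cond, %x, %y) : (tensor<2x3xi1>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}
func @select_v2_broadcast(%cond: tensor<2x3xi1>, %x: tensor<2x3xf32>, %y: tensor<f32>) -> tensor<2x3xf32> {
  // %y is a scalar, so the operand shapes differ and the second pattern would
  // rewrite this op to "tfl.select_v2" instead.
  %0 = "tf.SelectV2"(%cond, %x, %y) : (tensor<2x3xi1>, tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}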


@ -28,15 +28,15 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -66,6 +66,28 @@ struct LegalizeTF : public FunctionPass<LegalizeTF> {
void runOnFunction() override;
};
// Returns true if every tensor value in `values` has a static shape and all of
// them share the same shape.
bool HasSameStaticShapes(Operation* op) {
auto values = op->getOperands();
int index = 0;
ArrayRef<int64_t> shape;
for (Value value : values) {
auto shaped_type = value.getType().dyn_cast<ShapedType>();
    if (!shaped_type || !shaped_type.hasStaticShape()) {
return false;
}
if (index == 0) {
shape = shaped_type.getShape();
} else {
if (shape != shaped_type.getShape()) {
return false;
}
}
++index;
}
return true;
}
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
#define DECL_CONVERT_OP(tf_op) \
@ -100,7 +122,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
auto tf_concat_op = cast<TF::ConcatOp>(op);
auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output()->getType();
auto output_type = tf_concat_op.output().getType();
// Extract axis attribute from constant concat_dims tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
@ -119,7 +141,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output()->getType();
auto output_type = tf_concat_op.output().getType();
// Extract axis attribute from constant axis tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis)))
@ -145,7 +167,7 @@ PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
if (tf_matmul_op.transpose_a()) return matchFailure();
if (!tf_matmul_op.transpose_b()) return matchFailure();
Type output_type = tf_matmul_op.getResult()->getType();
Type output_type = tf_matmul_op.getResult().getType();
// TODO(jpienaar): Follow up post shuffle discussion.
auto no_input = rewriter.create<ConstantOp>(
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
@ -161,8 +183,8 @@ PatternMatchResult ConvertTFPackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_pack_op = cast<TF::PackOp>(op);
SmallVector<Value*, 4> values(tf_pack_op.values());
auto output_type = tf_pack_op.output()->getType();
SmallVector<Value, 4> values(tf_pack_op.values());
auto output_type = tf_pack_op.output().getType();
auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
// Axis can be negative.
auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue());
@ -176,10 +198,10 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_reshape_op = cast<TF::ReshapeOp>(op);
auto* input = tf_reshape_op.tensor();
auto* shape = tf_reshape_op.shape();
auto input = tf_reshape_op.tensor();
auto shape = tf_reshape_op.shape();
ShapedType shape_type = shape->getType().cast<ShapedType>();
ShapedType shape_type = shape.getType().cast<ShapedType>();
// The tfl reshape's #2 operand needs to be an i32 tensor type, so we have to cast.
if (!shape_type.getElementType().isInteger(32)) {
auto new_shape = shape_type.getShape();
@ -191,7 +213,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
rewriter.getBoolAttr(false))
.y();
}
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output()->getType(),
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
input, shape);
return matchSuccess();
}
@ -200,7 +222,7 @@ PatternMatchResult ConvertTFSplitOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_split_op = cast<TF::SplitOp>(op);
auto output_types = functional::map([](Value* v) { return v->getType(); },
auto output_types = functional::map([](Value v) { return v.getType(); },
tf_split_op.output());
// Number of splits cannot be negative.
auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());
@ -215,7 +237,7 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_splitv_op = cast<TF::SplitVOp>(op);
auto output_types = functional::map([](Value* v) { return v->getType(); },
auto output_types = functional::map([](Value v) { return v.getType(); },
tf_splitv_op.output());
// Number of splits cannot be negative.
auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());
@ -226,13 +248,13 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
return matchSuccess();
}
Value* PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
Value* attribute,
Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
Value attribute,
ArrayRef<int32_t> padding_val, int* mask) {
DenseIntElementsAttr dense_elem_attr;
SmallVector<int32_t, 8> padded_val;
auto ranked_attr_type = attribute->getType().dyn_cast<RankedTensorType>();
auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
if (!ranked_attr_type ||
!matchPattern(attribute, m_Constant(&dense_elem_attr))) {
// If the input attribute is neither ranked type nor constant, we
@ -258,14 +280,14 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
auto ranked_input_type =
tf_strided_slice_op.input()->getType().dyn_cast<RankedTensorType>();
tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
if (!ranked_input_type) {
// If input is not a ranked tensor, we can't deduce the padding dimensions
// from it, so we just do a plain conversion here.
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
op, tf_strided_slice_op.output()->getType(),
tf_strided_slice_op.input(), tf_strided_slice_op.begin(),
tf_strided_slice_op.end(), tf_strided_slice_op.strides(),
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
tf_strided_slice_op.strides(),
rewriter.getI32IntegerAttr(
tf_strided_slice_op.begin_mask().getSExtValue()),
rewriter.getI32IntegerAttr(
@ -283,20 +305,20 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
// Pad `begin` array with zero values and update the `begin_mask`.
SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0);
int begin_mask = tf_strided_slice_op.begin_mask().getSExtValue();
Value* padded_begin = PadStridedSliceAttributeArray(
Value padded_begin = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask);
// Pad `end` array with `input_shape` and update the `end_mask`.
int end_mask = tf_strided_slice_op.end_mask().getSExtValue();
auto input_shape = ranked_input_type.getShape();
SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end());
Value* padded_end = PadStridedSliceAttributeArray(
Value padded_end = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask);
// Pad `strides` array with ones.
SmallVector<int32_t, 8> strides_pad_val(num_input_dims, 1);
Value* padded_strides = PadStridedSliceAttributeArray(
Value padded_strides = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
op, tf_strided_slice_op.output()->getType(), tf_strided_slice_op.input(),
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
padded_begin, padded_end, padded_strides,
rewriter.getI32IntegerAttr(begin_mask),
rewriter.getI32IntegerAttr(end_mask),
@ -313,8 +335,8 @@ PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_unpack_op = cast<TF::UnpackOp>(op);
auto* input = tf_unpack_op.value();
auto output_types = functional::map([](Value* v) { return v->getType(); },
auto input = tf_unpack_op.value();
auto output_types = functional::map([](Value v) { return v.getType(); },
tf_unpack_op.output());
auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
// Axis can be negative.
@ -338,7 +360,7 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;
auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
auto output_type = tf_matrix_diag_v2_or_v3_op.output()->getType();
auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType();
// Extract k constant tensor and check value = 0.
ElementsAttr k;
@ -478,7 +500,7 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
auto status_or_const_op = CreateConstOpWithSingleValue(
&rewriter, op->getLoc(),
tf_reciprocal_op.x()->getType().cast<ShapedType>(), 1);
tf_reciprocal_op.x().getType().cast<ShapedType>(), 1);
if (!status_or_const_op.ok()) {
return matchFailure();
}


@ -19,11 +19,11 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -50,13 +50,13 @@ struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
// Create LSTM gates with different weights for input, recurrent and
// cell state, and also the layer normalization parameters.
Operation* CreateGate(Location loc, Value* in, Value* in_w, Value* rec,
Value* rec_w,
llvm::Optional<std::pair<Value*, Value*>> cell,
Value* ln_w, Value* ln_bias, OpBuilder* builder);
Operation* CreateGate(Location loc, Value in, Value in_w, Value rec,
Value rec_w,
llvm::Optional<std::pair<Value, Value>> cell,
Value ln_w, Value ln_bias, OpBuilder* builder);
Operation* CreateLayerNorm(Location loc, Value* in, Value* ln_w,
Value* ln_bias, OpBuilder* builder);
Operation* CreateLayerNorm(Location loc, Value in, Value ln_w, Value ln_bias,
OpBuilder* builder);
// Add the internal implementation of the LSTM to its regions.
void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder);
@ -71,7 +71,7 @@ struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
Type expressed_type =
lstm.input()->getType().cast<ShapedType>().getElementType();
lstm.input().getType().cast<ShapedType>().getElementType();
Type int8_storage_type = builder->getIntegerType(8);
Type int16_storage_type = builder->getIntegerType(16);
auto flag = quant::QuantizationFlags::FlagValue::Signed;
@ -88,12 +88,12 @@ void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
auto any_int16 = quant::AnyQuantizedType::get(
flag, int16_storage_type, expressed_type, int16_min, int16_max);
int8 = any_int8.castFromExpressedType(lstm.input()->getType());
int16 = any_int16.castFromExpressedType(lstm.input()->getType());
int8 = any_int8.castFromExpressedType(lstm.input().getType());
int16 = any_int16.castFromExpressedType(lstm.input().getType());
}
Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in,
Value* ln_w, Value* ln_bias,
Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in,
Value ln_w, Value ln_bias,
OpBuilder* builder) {
// Note that l2_normalization and add ops here are not the execution kernel
// implementation for layer_normalization and we just want to use them to
@ -105,8 +105,8 @@ Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in,
}
Operation* LoadQuantizationRecipe::CreateGate(
Location loc, Value* in, Value* in_w, Value* rec, Value* rec_w,
llvm::Optional<std::pair<Value*, Value*>> cell, Value* ln_w, Value* ln_bias,
Location loc, Value in, Value in_w, Value rec, Value rec_w,
llvm::Optional<std::pair<Value, Value>> cell, Value ln_w, Value ln_bias,
OpBuilder* builder) {
auto s1 = builder->create<FullyConnectedOp>(loc, int16, in, in_w, none_cst,
none_af, fc_format, keep_dims);
@ -119,13 +119,13 @@ Operation* LoadQuantizationRecipe::CreateGate(
cell.getValue().second, none_af);
s4 = builder->create<AddNOp>(
loc, int16,
llvm::ArrayRef<Value*>(
llvm::ArrayRef<Value>(
{*s1.output().begin(), *s2.output().begin(), s3.output()}));
} else {
s4 = builder->create<AddNOp>(
loc, int16,
llvm::ArrayRef<Value*>({*s1.output().begin(), *s2.output().begin()}));
llvm::ArrayRef<Value>({*s1.output().begin(), *s2.output().begin()}));
}
auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder);
@ -144,22 +144,20 @@ void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
region.push_back(new Block);
builder->setInsertionPointToEnd(&region.front());
Location loc = lstm.getLoc();
Type int32_type = builder->getIntegerType(32);
Type int32_tensor = UnrankedTensorType::get(int32_type);
none_cst = builder->create<ConstantOp>(loc, builder->getNoneType(),
builder->getUnitAttr());
auto input_gate = CreateGate(
loc, lstm.input(), lstm.input_to_input_weights(),
lstm.input_activation_state(), lstm.recurrent_to_input_weights(),
llvm::Optional<std::pair<Value*, Value*>>(
llvm::Optional<std::pair<Value, Value>>(
{lstm.input_cell_state(), lstm.cell_to_input_weights()}),
lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);
auto forget_gate = CreateGate(
loc, lstm.input(), lstm.input_to_forget_weights(),
lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
llvm::Optional<std::pair<Value*, Value*>>(
llvm::Optional<std::pair<Value, Value>>(
{lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);
@ -179,7 +177,7 @@ void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
auto output_gate = CreateGate(
loc, lstm.input(), lstm.input_to_output_weights(),
lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
llvm::Optional<std::pair<Value*, Value*>>(
llvm::Optional<std::pair<Value, Value>>(
{new_cell, lstm.cell_to_output_weights()}),
lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);


@ -29,28 +29,28 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -84,7 +84,7 @@ struct LowerStaticTensorListPass
TensorListPatternRewriter *rewriter);
};
Value *CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
ArrayRef<int64_t> shape, int32_t val) {
RankedTensorType type =
RankedTensorType::get(shape, rewriter->getIntegerType(32));
@ -93,9 +93,9 @@ Value *CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
return rewriter->create<ConstantOp>(loc, type, attr);
}
Value *CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
Value *shape_tensor, int32_t val) {
Value *scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
Value shape_tensor, int32_t val) {
Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
return rewriter->create<TF::FillOp>(
loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
shape_tensor, scalar_val);
@ -131,32 +131,32 @@ Type GetTensorTypeForTensorList(Type element_type, TF::VariantType handle_dtype,
// Requires that `start_index` and `size` are scalar tensors and
// `item_position_shape` is a 1-D tensor with only one element equal to the rank
// of an item in the tensorlist.
TF::SliceOp CreateSliceOpForTensorList(Location loc, Value *input_list,
Value *start_index, Value *size,
Value *item_rank, Type result_type,
TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
Value start_index, Value size,
Value item_rank, Type result_type,
PatternRewriter *rewriter) {
// Create the start position of slice. This is done by concatenating
// `start_index` and `partial_start_position` together.
IntegerType shape_dtype = rewriter->getIntegerType(32);
RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
Value *partial_start_position =
Value partial_start_position =
CreateI32SplatTensor(loc, rewriter, item_rank, 0);
Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
loc, vector_type, start_index, scalar_zero);
auto start_position = rewriter->create<TF::ConcatOp>(
loc, position_type, scalar_zero,
ArrayRef<Value *>({expanded_start_index, partial_start_position}));
ArrayRef<Value>({expanded_start_index, partial_start_position}));
// Create the slice size tensor. This is done by concatenating `size` and
// `partial_size`.
auto size_leading_dim =
rewriter->create<TF::ExpandDimsOp>(loc, vector_type, size, scalar_zero);
Value *partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
auto slice_size = rewriter->create<TF::ConcatOp>(
loc, position_type, scalar_zero,
ArrayRef<Value *>({size_leading_dim, partial_size}));
ArrayRef<Value>({size_leading_dim, partial_size}));
return rewriter->create<TF::SliceOp>(loc, result_type, input_list,
start_position, slice_size);
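As a hedged illustration (not part of this change; shapes and values are assumed for the example), for a tensorlist whose items are 3-D, a start_index of 2 and a size of 1, the ops built by this helper amount to a tf.Slice with begin [2, 0, 0, 0] and size [1, -1, -1, -1]:
// Illustrative sketch only.
func @slice_one_item(%list: tensor<5x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
  // begin = [start_index, 0, 0, 0]; size = [size, -1, -1, -1], where -1 keeps
  // the full extent of the remaining dimensions.
  %begin = "tf.Const"() {value = dense<[2, 0, 0, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
  %size = "tf.Const"() {value = dense<[1, -1, -1, -1]> : tensor<4xi32>} : () -> tensor<4xi32>
  %0 = "tf.Slice"(%list, %begin, %size) : (tensor<5x2x3x4xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x4xf32>
  return %0 : tensor<1x2x3x4xf32>
}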
@ -180,31 +180,31 @@ struct ConvertTensorListSetItem : public ConversionPattern {
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value *> operands,
Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListSetItemOp>(operation);
Location loc = op.getLoc();
Value *input = operands[0];
Value *index = operands[1];
Value *item = operands[2];
Value input = operands[0];
Value index = operands[1];
Value item = operands[2];
IntegerType shape_dtype = rewriter.getIntegerType(32);
auto item_rank = rewriter.create<TF::RankOp>(
loc, RankedTensorType::get({}, shape_dtype), item);
Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
// Calculate `index` + 1, which is used to generate the start position for
// the second slice op.
auto suffix_start =
rewriter.create<TF::AddOp>(loc, index->getType(), index,
rewriter.create<TF::AddOp>(loc, index.getType(), index,
CreateI32SplatConst(loc, &rewriter, {}, 1));
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
// Create two slice ops.
Type element_type = input->getType().cast<TensorType>().getElementType();
Type element_type = input.getType().cast<TensorType>().getElementType();
UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
Value *scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
TF::SliceOp slice1 =
CreateSliceOpForTensorList(loc, /*input_list=*/input,
/*start_index=*/scalar_zero,
@ -225,8 +225,8 @@ struct ConvertTensorListSetItem : public ConversionPattern {
// Concatenate three parts together to generate the final result.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
op, input->getType(), scalar_zero,
ArrayRef<Value *>({slice1, expanded_item, slice2}));
op, input.getType(), scalar_zero,
ArrayRef<Value>({slice1, expanded_item, slice2}));
return matchSuccess();
}
};
@ -241,14 +241,14 @@ struct ConvertTensorListInitOp : public ConversionPattern {
// Create and return a 1-d tensor with exactly one element equal to the number
// of list elements to initialize the output tensor list with.
virtual Value *GetNumElements(OpT op, ArrayRef<Value *> operands,
virtual Value GetNumElements(OpT op, ArrayRef<Value> operands,
PatternRewriter *rewriter) const = 0;
// Rewrites the original op into `tf.fill`. The result tensor shape is
// [num_element, element_shape]. All the values in the result tensor will be
// initialized to 0.
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value *> operands,
Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OpT op = llvm::cast<OpT>(operation);
@ -263,8 +263,8 @@ struct ConvertTensorListInitOp : public ConversionPattern {
return matchFailure();
}
Value *element_shape = operands[0];
Type shape_dtype = getElementTypeOrSelf(element_shape->getType());
Value element_shape = operands[0];
Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
DenseIntElementsAttr dense_elem_attr;
if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
@ -297,11 +297,10 @@ struct ConvertTensorListInitOp : public ConversionPattern {
new_element_shape_values.push_back(dim_value);
}
auto attr =
DenseIntElementsAttr::get(element_shape->getType().cast<ShapedType>(),
new_element_shape_values);
auto attr = DenseIntElementsAttr::get(
element_shape.getType().cast<ShapedType>(), new_element_shape_values);
auto new_element_shape = rewriter.create<ConstantOp>(
op.getLoc(), element_shape->getType(), attr);
op.getLoc(), element_shape.getType(), attr);
element_shape = new_element_shape;
}
@ -330,11 +329,11 @@ struct ConvertTensorListInitOp : public ConversionPattern {
Location loc = op.getLoc();
// Add number of elements as the prefix to the element shape to get shape of
// the output tensor.
Value *leading_dim = GetNumElements(op, operands, &rewriter);
Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
Value leading_dim = GetNumElements(op, operands, &rewriter);
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
auto list_shape = rewriter.create<TF::ConcatOp>(
loc, shape_type, scalar_zero,
ArrayRef<Value *>({leading_dim, element_shape}));
ArrayRef<Value>({leading_dim, element_shape}));
// Create a zero-initialized constant tensor that has the same type
// as specified by element_dtype.
@ -352,11 +351,11 @@ struct ConvertTensorListReserve
explicit ConvertTensorListReserve(MLIRContext *context)
: ConvertTensorListInitOp(context) {}
Value *GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value *> operands,
Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
PatternRewriter *rewriter) const override {
Value *scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
Type shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
Value *num_elements = operands[1];
Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
Value num_elements = operands[1];
return rewriter->create<TF::ExpandDimsOp>(
op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
scalar_zero);
@ -371,7 +370,7 @@ struct ConvertEmptyTensorList
explicit ConvertEmptyTensorList(MLIRContext *context)
: ConvertTensorListInitOp(context) {}
Value *GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value *> operands,
Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands,
PatternRewriter *rewriter) const override {
return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0);
}
@ -383,23 +382,23 @@ struct ConvertTensorListPushBack : public ConversionPattern {
context) {}
PatternMatchResult matchAndRewrite(
Operation *op, ArrayRef<Value *> operands,
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
Value *input_handle = operands[0];
Value *item = operands[1];
Value input_handle = operands[0];
Value item = operands[1];
// Expand the shape of the item so that it will have rank same as the input
// tensor and it is compatible for the Concat Op.
Type expanded_item_type =
PrependLeadingDimIfRanked(1, item->getType(), &rewriter);
Value *scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), expanded_item_type, item, scalar_zero);
Type elem_type = getElementTypeOrSelf(item);
auto handle_dtype =
getElementTypeOrSelf(push_back_op.output_handle()->getType())
getElementTypeOrSelf(push_back_op.output_handle().getType())
.cast<TF::VariantType>();
Type result_type =
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
@ -408,7 +407,7 @@ struct ConvertTensorListPushBack : public ConversionPattern {
// get a tensor equivalent to the TensorList generated by this op.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
push_back_op, result_type, scalar_zero,
ArrayRef<Value *>({input_handle, expanded_item}));
ArrayRef<Value>({input_handle, expanded_item}));
return matchSuccess();
}
};
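The tensor-list lowerings in this pass all follow the same ConversionPattern shape under the updated API: operands arrive as ArrayRef<Value>, the pattern builds replacement ops through the ConversionPatternRewriter, and it returns matchSuccess() or matchFailure(). A reduced, hypothetical skeleton (TF::SomeOp is a placeholder op name for illustration, not an op in this file):

struct ConvertSomeOp : public ConversionPattern {
  explicit ConvertSomeOp(MLIRContext *context)
      : ConversionPattern(TF::SomeOp::getOperationName(), /*benefit=*/1,
                          context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto some_op = cast<TF::SomeOp>(op);  // placeholder op
    Value input = operands[0];            // value-typed, no '*'
    if (!input.getType().isa<TensorType>()) return matchFailure();
    rewriter.replaceOpWithNewOp<TF::IdentityOp>(some_op, input.getType(),
                                                operands, some_op.getAttrs());
    return matchSuccess();
  }
};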
@ -429,14 +428,14 @@ struct ConvertTensorListResize : public ConversionPattern {
context) {}
PatternMatchResult matchAndRewrite(
Operation *op, ArrayRef<Value *> operands,
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
TF::TensorListResizeOp resize_op = cast<TF::TensorListResizeOp>(op);
Value *input_handle = operands[0];
Value *size = operands[1];
Value input_handle = operands[0];
Value size = operands[1];
Location loc = resize_op.getLoc();
Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
// Compute the input tensorlist's length and store it in `input_size`.
IntegerType shape_dtype = rewriter.getIntegerType(32);
@ -446,7 +445,7 @@ struct ConvertTensorListResize : public ConversionPattern {
// Infer result type of this op based on TF's shape inference result.
Type elem_type = getElementTypeOrSelf(input_handle);
auto handle_dtype =
getElementTypeOrSelf(resize_op.output_handle()->getType())
getElementTypeOrSelf(resize_op.output_handle().getType())
.cast<TF::VariantType>();
Type result_type =
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
@ -463,8 +462,8 @@ struct ConvertTensorListResize : public ConversionPattern {
auto input_shape = rewriter.create<TF::ShapeOp>(
loc, RankedTensorType::get({-1}, shape_dtype), input_handle);
Type branch_args_type[] = {input_handle->getType(), input_shape.getType(),
size_diff.getType(), size->getType()};
Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
size_diff.getType(), size.getType()};
Type branch_result_type[] = {result_type};
auto func_type = FunctionType::get(branch_args_type, branch_result_type,
rewriter.getContext());
@ -491,7 +490,7 @@ struct ConvertTensorListResize : public ConversionPattern {
rewriter.replaceOpWithNewOp<TF::IfOp>(
op, result_type, if_cond,
/*input=*/
ArrayRef<Value *>({input_handle, input_shape, size_diff, size}),
ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
/*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op),
/*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
/*output_shapes=*/rewriter.getStrArrayAttr({"{}"}),
@ -517,14 +516,14 @@ struct ConvertTensorListResize : public ConversionPattern {
Location loc = resize_op.getLoc();
// Get the element shape by slicing from index 1 in the input shape.
Value *slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
Value *slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
auto elem_shape = rewriter->create<TF::SliceOp>(
loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
slice_size);
auto extended_part = rewriter->create<TF::TensorListReserveOp>(
loc, resize_op.output_handle()->getType(), elem_shape, size_diff);
loc, resize_op.output_handle().getType(), elem_shape, size_diff);
// `ConcatOp` expects non-variant-typed input. Insert a
// `TensorListStackOp` here to convert type from variant to non-variant.
// Note that we are using the same `result_type` for both the
@ -536,8 +535,8 @@ struct ConvertTensorListResize : public ConversionPattern {
/*num_elements=*/rewriter->getI32IntegerAttr(-1));
auto concat_op = rewriter->create<TF::ConcatOp>(
loc, result_type, scalar_zero,
ArrayRef<Value *>({input, stacked_extended_part}));
rewriter->create<ReturnOp>(loc, ArrayRef<Value *>({concat_op}));
ArrayRef<Value>({input, stacked_extended_part}));
rewriter->create<ReturnOp>(loc, ArrayRef<Value>({concat_op}));
}
void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type,
@ -550,8 +549,8 @@ struct ConvertTensorListResize : public ConversionPattern {
Block *block = branch_func.addEntryBlock();
rewriter->setInsertionPointToStart(block);
Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
Value *vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
auto input = block->getArgument(0);
auto size = block->getArgument(3);
@ -566,7 +565,7 @@ struct ConvertTensorListResize : public ConversionPattern {
/*start_index=*/scalar_zero, /*size=*/size,
/*item_rank=*/partial_position_shape,
/*result_type=*/result_type, rewriter);
rewriter->create<ReturnOp>(loc, ArrayRef<Value *>({slice_op}));
rewriter->create<ReturnOp>(loc, ArrayRef<Value>({slice_op}));
}
};
@ -576,11 +575,11 @@ struct ConvertTensorListGetItem : public ConversionPattern {
context) {}
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value *> operands,
Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListGetItemOp>(operation);
Value *input = operands[0];
Value *index = operands[1];
Value input = operands[0];
Value index = operands[1];
rewriter.replaceOpWithNewOp<TF::GatherOp>(
operation, op.getType(), input, index, rewriter.getBoolAttr(true));
return matchSuccess();
@ -593,11 +592,11 @@ struct ConvertTensorListLength : public ConversionPattern {
context) {}
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value *> operands,
Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListLengthOp>(operation);
Location loc = op.getLoc();
Value *input_handle = operands[0];
Value input_handle = operands[0];
BoolAttr true_attr = rewriter.getBoolAttr(true);
auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
@ -615,19 +614,19 @@ struct ConvertTensorListStack : public ConversionPattern {
context) {}
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value *> operands,
Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListStackOp>(operation);
Location loc = op.getLoc();
Value *input = operands[0];
Value *element_shape = operands[1];
Value input = operands[0];
Value element_shape = operands[1];
// If the `element_shape` is a known constant (which is defined when calling
// `tensor_list_stack`) and also valid (not scalar), we rewrite this op to a
// trivial Reshape op (that doesn't actually change the input's shape) and
// also populate the shape info to the op result. The shape of the
// tensorlist is inferred from `num_elements` and `element_shape`.
auto ranked_type = element_shape->getType().dyn_cast<RankedTensorType>();
auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
DenseIntElementsAttr dense_elem_attr;
if ((ranked_type && ranked_type.getRank() == 0) ||
!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
@ -655,11 +654,11 @@ struct ConvertIdentity : public ConversionPattern {
: ConversionPattern(TF::IdentityOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value *> operands,
Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::IdentityOp>(operation);
Value *input = operands[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input->getType(), operands,
Value input = operands[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
op.getAttrs());
return matchSuccess();
}
@ -687,7 +686,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
Type arg_type = func_type.getInput(i);
if (getElementTypeOrSelf(arg_type).isa<TF::VariantType>()) {
arg_type = UnrankedTensorType::get(
getElementTypeOrSelf(op.getOperand(i)->getType()));
getElementTypeOrSelf(op.getOperand(i).getType()));
}
updated_argument_types.push_back(arg_type);
}
@ -703,7 +702,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
// from the corresponding input operand. This is correct because the while
// body's inputs and results have the same types.
result_type = UnrankedTensorType::get(
getElementTypeOrSelf(op.getOperand(i)->getType()));
getElementTypeOrSelf(op.getOperand(i).getType()));
}
updated_result_types.push_back(result_type);
}
@ -717,7 +716,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
// Change the argument type for the first block.
Block &body_first_bb = func.front();
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
body_first_bb.getArgument(i)->setType(updated_argument_types[i]);
body_first_bb.getArgument(i).setType(updated_argument_types[i]);
}
}
return success();
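Both loops above apply the same mapping: a DT_VARIANT operand or result (a tensor list) gets its type rewritten to an unranked tensor of the list's element type, taken from the corresponding while operand. A minimal sketch of that mapping (assumes the MLIR and TF-dialect headers this file already includes):

// Maps a variant-typed `type` to an unranked tensor of the element type of
// `operand`; returns other types unchanged.
static mlir::Type MapVariantToUnranked(mlir::Value operand, mlir::Type type) {
  if (mlir::getElementTypeOrSelf(type).isa<mlir::TF::VariantType>()) {
    return mlir::UnrankedTensorType::get(
        mlir::getElementTypeOrSelf(operand.getType()));
  }
  return type;
}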
@ -728,19 +727,19 @@ struct ConvertWhile : public ConversionPattern {
: ConversionPattern(TF::WhileOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value *> operands,
Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::WhileOp>(operation);
llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumOperands());
for (int i = 0, e = operands.size(); i != e; ++i) {
Type result_ty = op.getResult(i)->getType();
Type result_ty = op.getResult(i).getType();
// If we notice the result type is a DT_VARIANT, we change the
// corresponding result type to unranked tensor type.
if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
Type element_ty = getElementTypeOrSelf(operands[i]->getType());
Type element_ty = getElementTypeOrSelf(operands[i].getType());
result_ty = UnrankedTensorType::get(element_ty);
}
result_types.push_back(result_ty);

View File

@ -30,14 +30,14 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
@ -50,16 +50,16 @@ namespace TFL {
// The actual Optimize Pass.
namespace {
bool L2NormalizeReduceAxis(Value *sq_op, DenseElementsAttr axis) {
if (sq_op->getType().cast<ShapedType>().getRank() - 1 ==
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
*axis.getValues<int>().begin() ||
*axis.getValues<int>().begin() == -1) {
return true;
}
if (sq_op->getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
return false;
}
auto shape = sq_op->getType().cast<ShapedType>();
auto shape = sq_op.getType().cast<ShapedType>();
SmallVector<int, 4> elems{axis.getValues<int>().begin(),
axis.getValues<int>().end()};
for (int i = 0; i < shape.getRank(); ++i) {
@ -142,8 +142,8 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
// Returns the shape of a ranked tensor.
// Precondition: output_val is a ranked tensor.
DenseElementsAttr GetShape(Value *output_val) {
auto output_type = output_val->getType().cast<RankedTensorType>();
DenseElementsAttr GetShape(Value output_val) {
auto output_type = output_val.getType().cast<RankedTensorType>();
auto shape_vector = output_type.getShape();
std::vector<int32_t> shape(shape_vector.size());
for (int i = 0; i < shape_vector.size(); ++i) {
@ -152,7 +152,7 @@ DenseElementsAttr GetShape(Value *output_val) {
return mlir::DenseElementsAttr::get(
RankedTensorType::get(
{static_cast<int>(shape.size())},
mlir::IntegerType::get(32, output_val->getContext())),
mlir::IntegerType::get(32, output_val.getContext())),
llvm::makeArrayRef(shape));
}
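GetShape above packs a ranked value's static shape into an i32 elements attribute, e.g. a value of type tensor<2x3xf32> yields dense<[2, 3]> : tensor<2xi32>. A small standalone sketch of the same construction from a plain shape array (illustrative helper name; assumes the headers this file already includes):

// Builds the i32 shape attribute GetShape would return for a value whose
// ranked shape is `shape`.
static mlir::DenseElementsAttr MakeShapeAttr(llvm::ArrayRef<int64_t> shape,
                                             mlir::MLIRContext *context) {
  llvm::SmallVector<int32_t, 4> shape_i32(shape.begin(), shape.end());
  auto type = mlir::RankedTensorType::get(
      {static_cast<int64_t>(shape_i32.size())},
      mlir::IntegerType::get(32, context));
  return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(shape_i32));
}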
@ -167,19 +167,19 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
PatternRewriter &rewriter) const override {
// Add.
DenseElementsAttr added_value;
Value *constant_val = add_op.rhs();
Value constant_val = add_op.rhs();
if (!matchPattern(constant_val, m_Constant(&added_value)))
return matchFailure();
// Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs()->getDefiningOp());
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
if (!fc_op) return matchFailure();
Value *filter = fc_op.filter();
Value *bias = fc_op.bias();
Value filter = fc_op.filter();
Value bias = fc_op.bias();
ElementsAttr bias_value;
const bool is_none_bias = bias->getType().isa<NoneType>();
const bool is_none_bias = bias.getType().isa<NoneType>();
if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
return matchFailure();
if (fc_op.fused_activation_function() != "NONE") return matchFailure();
@ -213,7 +213,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op,
PatternRewriter &rewriter) const override {
Operation *input = relu_op.getOperand()->getDefiningOp();
Operation *input = relu_op.getOperand().getDefiningOp();
if (!isa_and_nonnull<FullyConnectedOp>(input)) return matchFailure();
auto fully_connected_op = cast<FullyConnectedOp>(input);
if (fully_connected_op.fused_activation_function() != "NONE")
@ -242,18 +242,18 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
PatternRewriter &rewriter) const override {
// Mul.
DenseElementsAttr cst;
Value *constant_val = mul_op.rhs();
Value constant_val = mul_op.rhs();
if (!matchPattern(constant_val, m_Constant(&cst))) return matchFailure();
// Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
if (!fc_op) return matchFailure();
Value *filter = fc_op.filter();
Value *bias = fc_op.bias();
Value filter = fc_op.filter();
Value bias = fc_op.bias();
ElementsAttr cst_tmp;
if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
if (!bias->getType().isa<NoneType>() &&
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&cst_tmp)))
return matchFailure();
if (fc_op.fused_activation_function().equals("None")) return matchFailure();
@ -261,8 +261,8 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
// Broadcast the constant operand of Mul if it isn't compatible with the
// filter input. We only support broadcasting the operand along the depth
// dimension, when the operand's depth is 1.
Value *new_const_val = constant_val;
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
Value new_const_val = constant_val;
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
auto original_shape = cst.getType().getShape();
llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
original_shape.end());
@ -270,7 +270,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
auto new_cst = cst.reshape(RankedTensorType::get(
normalized_shape, cst.getType().getElementType()));
Type new_type = new_cst.getType();
if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
return matchFailure();
}
auto new_op =
@ -285,7 +285,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
auto new_filter =
rewriter.create<TF::MulOp>(loc, filter, new_const_val).z();
// If bias isn't None, it needs to be multiplied as well.
if (!bias->getType().isa<NoneType>()) {
if (!bias.getType().isa<NoneType>()) {
bias = rewriter.create<TF::MulOp>(loc, bias, constant_val).z();
}
@ -311,7 +311,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
PatternMatchResult matchAndRewrite(AffineOpType fc_op,
PatternRewriter &rewriter) const override {
// Binary op.
Operation *binary_op = fc_op.input()->getDefiningOp();
Operation *binary_op = fc_op.input().getDefiningOp();
if (!binary_op || binary_op->getNumOperands() != 2)
return this->matchFailure();
// We only handle the cases the RHS is a scalar.
@ -325,20 +325,20 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
APFloat cst_value = *cst.float_value_begin();
// Affine op.
Value *filter = fc_op.filter();
Value *bias = fc_op.bias();
Value filter = fc_op.filter();
Value bias = fc_op.bias();
DenseFPElementsAttr filter_cst, bias_cst;
if (!matchPattern(filter, m_Constant(&filter_cst))) {
// The filter may be quantized; in that case, set it to the real constant.
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter->getDefiningOp());
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
if (!dq) return this->matchFailure();
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input()->getDefiningOp());
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
return this->matchFailure();
}
filter = q.input();
}
if (!bias->getType().isa<NoneType>() &&
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&bias_cst)))
return this->matchFailure();
ShapedType filter_type = filter_cst.getType();
@ -362,7 +362,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
// The new bias should be a 1-D tensor with length equal to the bias
// dimension of the weight.
SmallVector<APFloat, 4> new_bias_values;
if (bias->getType().isa<NoneType>()) { // none bias, a list of zeros
if (bias.getType().isa<NoneType>()) { // none bias, a list of zeros
new_bias_values.resize(bias_size, APFloat(0.0));
} else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it
new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
@ -401,12 +401,12 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
// We recreate the constant op in case it is shared by the other ops. This
// might increase the model size.
auto new_filter_op = rewriter.create<ConstOp>(
fc_op.getLoc(), filter->getType(), new_filter);
fc_op.getLoc(), filter.getType(), new_filter);
fc_op.setOperand(0, binary_op->getOperand(0));
if (fc_op.filter() != filter) {
// This filter goes through quantize and dequantize ops. Then we just
// need to update the weight to the quantize op.
filter->replaceAllUsesWith(new_filter_op);
filter.replaceAllUsesWith(new_filter_op);
} else {
// This filter doesn't go through quantize and dequantize ops, so we
// update the weight of the affine op directly.

View File

@ -17,15 +17,15 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
@ -98,13 +98,13 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
for (int i = 0, e = func.getNumArguments(); i != e; ++i)
mapper.map(func.getArgument(i), op.getOperand(i + 1));
llvm::SmallVector<Value*, 4> updated_results;
llvm::SmallVector<Value, 4> updated_results;
for (auto& op_to_inline : func.getBody().front()) {
// If this is a terminator, identify the values to use to replace the
// original If op.
if (op_to_inline.isKnownTerminator()) {
updated_results.reserve(op_to_inline.getNumOperands());
for (Value* operand : op_to_inline.getOperands())
for (Value operand : op_to_inline.getOperands())
updated_results.push_back(mapper.lookup(operand));
break;
}

View File

@ -18,6 +18,7 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
@ -53,13 +54,15 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu1Op, TFL_AF_Relu1]] in
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// If we see a binary op (add, sub) op adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
// constant folding the bias and the binary op's constant operand. The following
// pattern restricts to float constant values for now.
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
def : Pat<(binaryOp (TFL_Conv2DOp $input, $filter,
def : Pat<(binaryOp (TFL_Conv2DOp:$output $input, $filter,
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w),
@ -68,8 +71,9 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
(binaryOp (ConstantOp $bias),
(ConstantOp $value), TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w)>;
def : Pat<(binaryOp (TFL_DepthwiseConv2DOp $input, $filter,
$padding, $stride_h, $stride_w),
[(HasOneUse $output)]>;
def : Pat<(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w,
@ -81,7 +85,8 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w,
$multiplier)>;
$multiplier),
[(HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
defm : FuseBinaryOpToPrecedingAffine<binaryOp>;
@ -101,7 +106,7 @@ def ExpandTo4DForDepthwiseConv: NativeCodeCall<
// The following pattern restricts to float constant values for now.
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp $input,
def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
@ -119,8 +124,9 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w,
$multiplier),
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value)]>;
def : Pat<(BinaryOp (TFL_Conv2DOp $input,
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
(HasOneUse $output)]>;
def : Pat<(BinaryOp (TFL_Conv2DOp:$conv_output $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
@ -135,7 +141,8 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w),
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value)]>;
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
(HasOneUse $conv_output)]>;
}
foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in
@ -154,7 +161,7 @@ def EqualOperands : Constraint<CPred<"$0 == $1">>;
// Checks if the operand has rank == n
class OperandHasRank<int n> : Constraint<
CPred<"$0->getType().cast<ShapedType>().getRank() == " # n>>;
CPred<"$0.getType().cast<ShapedType>().getRank() == " # n>>;
// Matching HardSwish
def : Pat<
@ -249,7 +256,7 @@ foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;
def AreBroadcastableTypes : Constraint<CPred<
"TFL::IsBroadcastableElementsAttrAndType($0->getType(), $1->getType())">>;
"TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;
// Pattern for skipping Tile if it is mainly for broadcasting and the
// Op is already supporting broadcasting.
@ -307,3 +314,7 @@ multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
TFL_MulOp, TFL_SubOp] in
defm : FusedBinaryActivationFuncOpPat<BinaryOps>;
// The constant folding in this pass might produce constants in the tf dialect.
// This rule legalizes these constants to the tfl dialect.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

View File

@ -16,8 +16,8 @@ limitations under the License.
// This transformation pass applies some clean up steps after quantization.
#include "llvm/Support/Casting.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -67,33 +67,33 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
// In each iteration, a new argument is appended to the end of the list
// and the current argument is erased, so here we always process the first
// argument in the list.
auto* arg = bb.getArgument(0);
auto arg = bb.getArgument(0);
auto remove_quantize_op = [&](QuantizeOp quantize_op) {
auto quantize_output = quantize_op.output();
auto quantize_type = quantize_output->getType();
auto quantize_type = quantize_output.getType();
input_types.push_back(quantize_type);
auto* new_arg = bb.addArgument(quantize_type);
quantize_output->replaceAllUsesWith(new_arg);
auto new_arg = bb.addArgument(quantize_type);
quantize_output.replaceAllUsesWith(new_arg);
quantize_op.erase();
arg->dropAllUses();
arg.dropAllUses();
bb.eraseArgument(0);
};
// This is looking for a pattern: arg -> tfl.quantize
if (arg->hasOneUse() && llvm::isa<QuantizeOp>(*arg->user_begin())) {
auto quantize_op = llvm::cast<QuantizeOp>(*arg->user_begin());
if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
remove_quantize_op(quantize_op);
continue;
}
// Make a copy of current argument and append it to the end of the list if
// the pattern isn't found.
Type arg_type = arg->getType();
Type arg_type = arg.getType();
input_types.push_back(arg_type);
auto* new_arg = bb.addArgument(arg_type);
arg->replaceAllUsesWith(new_arg);
arg->dropAllUses();
auto new_arg = bb.addArgument(arg_type);
arg.replaceAllUsesWith(new_arg);
arg.dropAllUses();
bb.eraseArgument(0);
}
@ -102,16 +102,16 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
llvm::SmallVector<Type, 4> output_types;
output_types.reserve(num_return_operands);
for (int i = 0; i != num_return_operands; ++i) {
auto* returned_value = terminator->getOperand(i);
Operation* returned_op = returned_value->getDefiningOp();
auto returned_value = terminator->getOperand(i);
Operation* returned_op = returned_value.getDefiningOp();
if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
Value* dequantized_result = dequantize_op.input();
output_types.push_back(dequantized_result->getType());
Value dequantized_result = dequantize_op.input();
output_types.push_back(dequantized_result.getType());
terminator->setOperand(i, dequantized_result);
returned_op->erase();
} else {
output_types.push_back(returned_value->getType());
output_types.push_back(returned_value.getType());
}
}
auto new_func_type = builder.getFunctionType(input_types, output_types);

View File

@ -22,19 +22,19 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Identifier.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
@ -53,8 +53,8 @@ class ConvertEmbeddedLookupFunc {
void RewriteFunc() {
func_.setAttr(kTFImplements,
StringAttr::get("embedding_lookup", func_.getContext()));
Value* lookup = func_.getArgument(1);
Value* value = func_.getArgument(0);
Value lookup = func_.getArgument(1);
Value value = func_.getArgument(0);
auto output_type = func_.getType().getResult(0);
OpBuilder builder(func_.getBody());

View File

@ -135,10 +135,10 @@ def : Pat<(TF_ReshapeOp
// Casts result type of $1 to a quantized type by using the quantization
// parameters from the type in $0.
class UpdateShapeWithAxis<int i> : NativeCodeCall<
"CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1->getType(), " # i # ")">;
"CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">;
class UsedBy<string op> : Constraint<
CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0->getUsers().begin())">>;
CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0.getUsers().begin())">>;
// When the op is passing-through, the output types of the quantized ops need
// to be updated as well. Since the quantize op manages its own type by the

View File

@ -21,10 +21,10 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -139,7 +139,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
BoolAttr narrow_range = builder.getBoolAttr(false);
auto add_quantize_op = [&](Location loc, Type input_type, Block* block,
Block::iterator insertion_point, Value* arg,
Block::iterator insertion_point, Value arg,
int i) {
if (auto shaped = input_type.dyn_cast<ShapedType>()) {
if (shaped.getElementType().isa<FloatType>()) {
@ -153,16 +153,16 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
params);
auto dq_op =
builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
arg->replaceAllUsesWith(dq_op.output());
arg.replaceAllUsesWith(dq_op.output());
q_op.setOperand(arg);
}
}
};
for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
BlockArgument* arg = func.getArgument(i);
auto* arg_block = arg->getOwner();
add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
BlockArgument arg = func.getArgument(i);
auto* arg_block = arg.getOwner();
add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
std::next(arg_block->begin(), i), arg, i);
}

View File

@ -38,17 +38,17 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -115,17 +115,17 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
PatternRewriter &rewriter) const override {
// We don't want to insert quantize/dequantize if the quantize op exists.
auto res = tf_op.outputs();
if (!res->hasOneUse() || isa<QuantizeOp>(*res->user_begin()))
if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin()))
return this->matchFailure();
// Extract the min/max constant values from the operands. We also consider
// a special case that there are tf.Identity ops between the min/max
// constants and the tf.FakeQuantWithMinMaxVarsOp.
Value *min = tf_op.min(), *max = tf_op.max();
Value min = tf_op.min(), max = tf_op.max();
DenseFPElementsAttr min_value, max_value;
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min->getDefiningOp()))
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
min = id1.input();
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max->getDefiningOp()))
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
max = id2.input();
if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
@ -133,7 +133,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
int quant_dim = -1;
if (PerAxis) {
// This is a special case where the quant_dim is the last dimension.
quant_dim = res->getType().template cast<ShapedType>().getRank() - 1;
quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
}
// Use the min/max from the operands and the num_bits and narrow_range
// attribute to create the quantization parameter for the new quantize op.
@ -150,12 +150,12 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
// Finally, use the quantization parameter to create the quantize and
// dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
// and its users.
Value *value = tf_op.outputs();
Value value = tf_op.outputs();
auto quantize = rewriter.create<TFL::QuantizeOp>(
tf_op.getLoc(), qtype.getValue(), value, qtype);
auto dequantize = rewriter.create<TFL::DequantizeOp>(
tf_op.getLoc(), res_type, quantize.output());
value->replaceAllUsesWith(dequantize);
value.replaceAllUsesWith(dequantize);
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
return this->matchSuccess();
@ -177,8 +177,8 @@ using PreparePerChannelFakeQuant =
//
// TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
// PatternRewriter &rewriter, Location loc,
// Type result_type, Value *input,
// Value *filter, Value *bias) const;
// Type result_type, Value input,
// Value filter, Value bias) const;
//
// And also the following method for getting the dimension for bias tensor:
//
@ -240,7 +240,7 @@ struct ConvertTFConvOp : public RewritePattern {
// that we can extract info from the shape (e.g., for constructing bias
// tensor, for setting depth_multiplier attribute, etc.).
auto filter_type =
tf_op.filter()->getType().template dyn_cast<RankedTensorType>();
tf_op.filter().getType().template dyn_cast<RankedTensorType>();
if (filter_type && filter_type.getRank() == 4)
return matchSuccess(std::move(state));
@ -262,7 +262,7 @@ struct ConvertTFConvOp : public RewritePattern {
// Get a splat zero tensor with the expected dimension for the bias tensor
auto filter = tf_op.filter();
auto filter_type = filter->getType().template cast<RankedTensorType>();
auto filter_type = filter.getType().template cast<RankedTensorType>();
auto elem_type = filter_type.getElementType();
auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
filter_type.getShape());
@ -294,8 +294,8 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
PatternRewriter &rewriter, Location loc,
Type result_type, Value *input, Value *filter,
Value *bias) const {
Type result_type, Value input, Value filter,
Value bias) const {
filter = legalizeFilter(rewriter, loc, filter);
return rewriter.create<TFL::Conv2DOp>(
loc, result_type, input, filter, bias,
@ -312,8 +312,8 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
// format HWIO to TFLite Conv2D op filter data format OHWI and return Value
// for the converted filter. Requires that filter is verified by the match
// method that it is a 4-D RankedTensorType.
Value *legalizeFilter(PatternRewriter &rewriter, Location loc,
Value *filter) const {
Value legalizeFilter(PatternRewriter &rewriter, Location loc,
Value filter) const {
// Create a constant op for HWIO to OHWI transpose permutation.
SmallVector<int, 4> perm = {3, 0, 1, 2};
auto perm_type = RankedTensorType::get({static_cast<int>(perm.size())},
@ -323,7 +323,7 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
// Create tensor type for the transpose result.
auto filter_type = filter->getType().cast<RankedTensorType>();
auto filter_type = filter.getType().cast<RankedTensorType>();
auto result_shape = functional::map(
[filter_type](int64_t dim) { return filter_type.getDimSize(dim); },
perm);
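The {3, 0, 1, 2} permutation above maps TensorFlow's HWIO filter layout to TFLite's OHWI layout, so a filter of shape [3, 3, 8, 16] (H, W, I, O) becomes [16, 3, 3, 8] after the transpose. A tiny standalone sketch of the shape bookkeeping (illustrative helper, not from this file):

#include <array>
#include <cstdint>

// Applies a static 4-D permutation to a shape, e.g. HWIO -> OHWI.
static std::array<int64_t, 4> PermuteShape(const std::array<int64_t, 4> &shape,
                                           const std::array<int, 4> &perm) {
  std::array<int64_t, 4> result;
  for (int i = 0; i < 4; ++i) result[i] = shape[perm[i]];
  return result;
}
// PermuteShape({3, 3, 8, 16}, {3, 0, 1, 2}) == {16, 3, 3, 8}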
@ -349,14 +349,14 @@ class ConvertTFDepthwiseConv2dNative
TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
PatternRewriter &rewriter, Location loc,
Type result_type, Value *input,
Value *filter, Value *bias) const {
Type result_type, Value input,
Value filter, Value bias) const {
// Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional
// 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not
// have a corresponding 'depth_multiplier' attribute; the multiplier is the
// fourth dimension in the 4-D filter tensor. We query the multiplier from
// tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
auto multiplier = filter->getType().cast<RankedTensorType>().getDimSize(3);
auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);
filter = legalizeFilter(rewriter, loc, filter);
return rewriter.create<TFL::DepthwiseConv2DOp>(
@ -378,9 +378,9 @@ class ConvertTFDepthwiseConv2dNative
/// filter data format is [1, filter_height, filter_width, out_channels].
/// Requires that filter is verified by the match method that it is a 4-D
/// RankedTensorType.
Value *legalizeFilter(PatternRewriter &rewriter, Location loc,
Value *filter) const {
auto filter_type = filter->getType().cast<RankedTensorType>();
Value legalizeFilter(PatternRewriter &rewriter, Location loc,
Value filter) const {
auto filter_type = filter.getType().cast<RankedTensorType>();
auto filterShape = filter_type.getShape();
SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
filterShape[2] * filterShape[3]};
@ -430,13 +430,13 @@ struct ConvertTFStridedSlice : public RewritePattern {
if (new_axis_mask == 0) return matchFailure();
// Insert a new reshape op.
Value *original_input = strided_slice_op.input();
Value original_input = strided_slice_op.input();
RankedTensorType original_input_type =
original_input->getType().cast<RankedTensorType>();
original_input.getType().cast<RankedTensorType>();
const ArrayRef<int64_t> &original_input_shape =
original_input_type.getShape();
RankedTensorType begin_type =
strided_slice_op.begin()->getType().cast<RankedTensorType>();
strided_slice_op.begin().getType().cast<RankedTensorType>();
const int dim_size = begin_type.getShape()[0];
SmallVector<int64_t, 4> new_shape;
int mask = 1;

Some files were not shown because too many files have changed in this diff.