Merge pull request #3 from tensorflow/master

post 2.1
This commit is contained in:
Basit Ayantunde 2020-01-09 12:23:25 +01:00 committed by GitHub
commit be6e1ce49a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1577 changed files with 28547 additions and 175090 deletions

View File

@ -381,9 +381,9 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
build:rbe_win --config=rbe
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:toolchain"
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:cc-toolchain-x64_windows"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"

View File

@ -1 +1 @@
1.1.0
1.2.1

View File

@ -29,20 +29,6 @@ to
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
See all the [mailing lists](https://www.tensorflow.org/community/forums).
## Feature Prioritization Survey
The TensorFlow team is working on building/improving features, and understands
that it is very important to prioritize these efforts based on what TF users
need.
The goal of this short, < 5 minute
[survey](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad), is to help
the TensorFlow team better understand what features to prioritize based on your
feedback. Participation is of course optional.
Take the survey
[HERE](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad).
## Install
See the [TensorFlow install guide](https://www.tensorflow.org/install) for the
@ -164,4 +150,3 @@ Learn more about the
## License
[Apache License 2.0](LICENSE)

View File

@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '1.0.0'
_TF_MAX_BAZEL_VERSION = '1.1.0'
_TF_MIN_BAZEL_VERSION = '1.2.1'
_TF_MAX_BAZEL_VERSION = '1.2.1'
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@ -147,14 +147,16 @@ def write_action_env_to_bazelrc(var_name, var):
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
def run_shell(cmd, allow_non_zero=False):
def run_shell(cmd, allow_non_zero=False, stderr=None):
if stderr is None:
stderr = sys.stdout
if allow_non_zero:
try:
output = subprocess.check_output(cmd)
output = subprocess.check_output(cmd, stderr=stderr)
except subprocess.CalledProcessError as e:
output = e.output
else:
output = subprocess.check_output(cmd)
output = subprocess.check_output(cmd, stderr=stderr)
return output.decode('UTF-8').strip()
@ -169,10 +171,12 @@ def get_python_path(environ_cp, python_bin_path):
if environ_cp.get('PYTHONPATH'):
python_paths = environ_cp.get('PYTHONPATH').split(':')
try:
stderr = open(os.devnull, 'wb')
library_paths = run_shell([
python_bin_path, '-c',
'import site; print("\\n".join(site.getsitepackages()))'
]).split('\n')
],
stderr=stderr).split('\n')
except subprocess.CalledProcessError:
library_paths = [
run_shell([

View File

@ -458,7 +458,7 @@ static void TF_Run_Helper(
EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
continue;
}
c_outputs[i] = TF_TensorFromTensor(src, status);
c_outputs[i] = TF_TensorFromTensor(src, &status->status);
if (!status->status.ok()) return;
}
}
@ -1493,7 +1493,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
Tensor t;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
if (!status->status.ok()) return;
*value = TF_TensorFromTensor(t, status);
*value = TF_TensorFromTensor(t, &status->status);
}
void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
@ -1504,7 +1504,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(ts.size()));
for (int i = 0; i < len; ++i) {
values[i] = TF_TensorFromTensor(ts[i], status);
values[i] = TF_TensorFromTensor(ts[i], &status->status);
}
}
@ -2398,7 +2398,7 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
graph->graph.versions().producer(), &evaluated, &result_tensor);
if (evaluated) {
DCHECK(status->status.ok());
*result = TF_TensorFromTensor(result_tensor, status);
*result = TF_TensorFromTensor(result_tensor, &status->status);
if (!status->status.ok()) evaluated = false;
}
return evaluated;

View File

@ -634,7 +634,7 @@ TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
std::unique_ptr<tensorflow::Tensor> tensor;
reader->GetTensor(name, &tensor, status);
if (!status->status.ok()) return nullptr;
return tensorflow::TF_TensorFromTensor(*tensor, status);
return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
}
void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,

View File

@ -188,7 +188,7 @@ namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out);

View File

@ -51,7 +51,7 @@ limitations under the License.
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
namespace {
@ -227,7 +227,7 @@ TEST(CAPI, LibraryLoadFunctions) {
void TestEncodeDecode(int line, const std::vector<string>& data) {
const tensorflow::int64 n = data.size();
TF_Status* status = TF_NewStatus();
Status status;
for (const std::vector<tensorflow::int64>& dims :
std::vector<std::vector<tensorflow::int64>>{
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
@ -236,8 +236,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
src.flat<tstring>()(i) = data[i];
}
TF_Tensor* dst = TF_TensorFromTensor(src, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* dst = TF_TensorFromTensor(src, &status);
ASSERT_TRUE(status.ok()) << status.error_message();
// Convert back to a C++ Tensor and ensure we get expected output.
Tensor output;
@ -249,7 +249,6 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
TF_DeleteTensor(dst);
}
TF_DeleteStatus(status);
}
TEST(CAPI, TensorEncodeDecodeStrings) {
@ -1394,8 +1393,9 @@ TEST(CAPI, SavedModel) {
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
Status status;
csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
ASSERT_TRUE(status.ok()) << status.error_message();
const tensorflow::string output_op_name(
tensorflow::ParseTensorName(output_name).first);
@ -2522,12 +2522,11 @@ TEST(CAPI, TestTensorIsNotAligned) {
// Take an unaligned slice.
Tensor y = x.Slice(1, 13);
TF_Status* status = TF_NewStatus();
TF_Tensor* a = TF_TensorFromTensor(y, status);
Status status;
TF_Tensor* a = TF_TensorFromTensor(y, &status);
if (EIGEN_MAX_ALIGN_BYTES > 0) {
EXPECT_FALSE(TF_TensorIsAligned(a));
}
TF_DeleteStatus(status);
TF_DeleteTensor(a);
}

View File

@ -464,7 +464,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
&new_remote_device_mgr));
remote_device_mgr = new_remote_device_mgr.get();
} else {
ctx->context->ClearCaches();
ctx->context->ClearCachesAndDefaultExecutor();
// TODO(b/143914772): Potential memory leak if rendezvous has pending
// tensors for removed / replaced workers.
@ -638,7 +638,7 @@ tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
void OpInferSingleTypeInputListAttrs(TFE_Op* op,
const tensorflow::OpDef::ArgDef& input_def,
TFE_TensorHandle** inputs,
const tensorflow::DataType dtype,
int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
@ -646,26 +646,20 @@ void OpInferSingleTypeInputListAttrs(TFE_Op* op,
ictx->attrs.insert(input_def.number_attr());
}
if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.type_attr(),
inputs[0]->handle->dtype);
op->operation.MutableAttrs()->Set(input_def.type_attr(), dtype);
ictx->attrs.insert(input_def.type_attr());
}
}
void OpInferMixedTypeInputListAttrs(TFE_Op* op,
const tensorflow::OpDef::ArgDef& input_def,
TFE_TensorHandle** inputs, int num_inputs) {
void OpInferMixedTypeInputListAttrs(
TFE_Op* op, const tensorflow::OpDef::ArgDef& input_def,
const std::vector<tensorflow::DataType>& dtypes) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
std::unique_ptr<tensorflow::DataType[]> dtypes(
new tensorflow::DataType[num_inputs]);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs[i]->handle->dtype;
}
op->operation.MutableAttrs()->Set(
input_def.type_list_attr(),
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
num_inputs));
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.data(),
dtypes.size()));
ictx->attrs.insert(input_def.type_list_attr());
}
}
@ -675,10 +669,15 @@ tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
if (!input_def.type_list_attr().empty()) {
OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
std::vector<tensorflow::DataType> dtypes(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs[i]->handle->dtype;
}
OpInferMixedTypeInputListAttrs(op, input_def, dtypes);
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
OpInferSingleTypeInputListAttrs(op, input_def, inputs[0]->handle->dtype,
num_inputs);
} else {
return tensorflow::errors::InvalidArgument("Invalid input list definition");
}
@ -754,7 +753,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
return list;
}
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); }
void TFE_ContextClearCaches(TFE_Context* ctx) {
ctx->context->ClearCachesAndThreadExecutors();
}
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
@ -990,7 +991,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
h_cpu->Unref();
return nullptr;
}
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, &status->status);
h_cpu->Unref();
return retval;
} else {
@ -1006,7 +1007,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
if (!status->status.ok()) return nullptr;
}
return tensorflow::TF_TensorFromTensor(tensor, status);
return tensorflow::TF_TensorFromTensor(tensor, &status->status);
}
}

View File

@ -206,7 +206,7 @@ typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
// error and nullptr is returned. This function can block till the operation
// that produces `handle` has completed.
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* handle, TF_Status* status);
TFE_TensorHandle* h, TF_Status* status);
// Deletes `debug_info`.
TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(

View File

@ -50,15 +50,15 @@ std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* handle, TF_Status* status) {
TFE_TensorHandle* h, TF_Status* status) {
const tensorflow::Tensor* tensor;
status->status = handle->handle->Tensor(&tensor);
status->status = h->handle->Tensor(&tensor);
if (TF_GetCode(status) != TF_OK) {
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = handle->handle->device();
tensorflow::Device* device = h->handle->device();
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =
@ -72,7 +72,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
return nullptr;
}
if (VLOG_IS_ON(3)) {
std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
std::vector<int64> shape_to_log = TensorShapeAsVector(h, status);
if (!status->status.ok()) {
// Ignore the status here as we are simply logging.
status->status = tensorflow::Status::OK();
@ -138,7 +138,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
std::vector<int64> dev_dims = TensorShapeAsVector(h, status);
if (TF_GetCode(status) != TF_OK) {
return nullptr;
}

View File

@ -27,11 +27,18 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/command_line_flags.h"
// TODO(b/143949264): Testing is not yet supported on Windows. Will implement
// testing on Windows when implementing modular filesystems on Windows.
#if defined(PLATFORM_WINDOWS)
#error Windows is not yet supported. Need mkdir().
#endif
// Make mkdir resolve to _mkdir to create the test temporary directory.
#include <direct.h>
#define mkdir(name, mode) _mkdir(name)
// Windows defines the following macros to convert foo to fooA or fooW,
// depending on the type of the string argument. We don't use these macros, so
// undefine them here.
#undef LoadLibrary
#undef CopyFile
#undef DeleteFile
#endif // defined(PLATFORM_WINDOWS)
// The tests defined here test the compliance of filesystems with the API
// defined by `filesystem_interface.h`.
@ -86,9 +93,6 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
}
void SetUp() override {
// TODO(b/143949264): Testing is not yet supported on Windows. Will
// implement testing on Windows when implementing modular filesystems on
// Windows.
if (mkdir(root_dir_.c_str(), 0755) != 0) {
int error_code = errno;
GTEST_SKIP() << "Cannot create working directory: "
@ -1668,30 +1672,40 @@ static std::vector<std::string>* SchemeVector() {
return schemes;
}
static std::vector<std::string> GetSchemes() {
std::vector<std::string>* user_schemes = SchemeVector();
std::vector<std::string> all_schemes;
// `INSTANTIATE_TEST_SUITE_P` is called once for every `TEST_P`. However, we
// only want to analyze the user provided schemes and those that are registered
// only once. Hence, this function keeping another static pointer to a vector
// which contains only the schemes under test.
//
// Without this additional step, when there are schemes available but the user
// only requests schemes that don't exist, first instantiation of the test would
// filter out all the user provided schemes (as they are not registered) but
// subsequent instantiations would return all registered schemes (since the
// vector with the user provided schemes is cleared).
static std::vector<std::string>* GetSchemesFromUserOrEnv() {
std::vector<std::string>* all_schemes = new std::vector<std::string>;
tensorflow::Status status =
tensorflow::Env::Default()->GetRegisteredFileSystemSchemes(&all_schemes);
tensorflow::Env::Default()->GetRegisteredFileSystemSchemes(all_schemes);
if (status.ok()) {
std::vector<std::string>* user_schemes = SchemeVector();
if (!user_schemes->empty()) {
auto is_registered_scheme = [&all_schemes](const auto& scheme) {
return std::find(all_schemes.begin(), all_schemes.end(), scheme) ==
all_schemes.end();
auto is_requested_scheme = [user_schemes](const auto& scheme) {
return std::find(user_schemes->begin(), user_schemes->end(), scheme) ==
user_schemes->end();
};
auto end = std::remove_if(user_schemes->begin(), user_schemes->end(),
is_registered_scheme);
user_schemes->erase(end, user_schemes->end());
return *user_schemes;
auto end = std::remove_if(all_schemes->begin(), all_schemes->end(),
is_requested_scheme);
all_schemes->erase(end, all_schemes->end());
}
}
// Next, try all schemes available
if (!all_schemes.empty()) return all_schemes;
}
return all_schemes;
}
// Fallback: no filesystems present, hence no tests
return std::vector<std::string>();
static std::vector<std::string> GetSchemes() {
static std::vector<std::string>* schemes = GetSchemesFromUserOrEnv();
return *schemes;
}
INSTANTIATE_TEST_SUITE_P(ModularFileSystem, ModularFileSystemTest,

View File

@ -1,35 +1,47 @@
# Experimental posix filesystem plugin.
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)
# Although this target results in a shared object that will be loaded at
# runtime, this target must be a `cc_library` instead of a `cc_binary`. Making
# it a `cc_binary` requires `linkshared = True`. In turn, this brings in several
# TensorFlow symbols under `tensorflow::` namespace, for which we have no ABI
# guarantees. Hence, in order to maintain ABI compatibility, this is marked as a
# `cc_library` for now and we will revisit in the future.
# TODO(mihaimaruseac): Determine if `cc_binary` makes more sense (when all
# filesystems are converted and BUILD files are refactored to be modular).
# TODO(b/144585140): The helpers should be separated into a different BUILD target
# but doing that would result in symbols not being visible when loading plugin.
# Revisit this once POSIX filesystem completely lands. See also the other TODO.
# This also has the unfortunate effect that both versions of copy_file get
# compiled, regardless of which one actually gets used!
# Filesystem implementation for POSIX environments: Linux, MacOS, Android, etc.
tf_cc_shared_object(
name = "libposix_filesystem.so",
framework_so = [],
linkstatic = False,
visibility = ["//visibility:public"],
deps = [":posix_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "posix_filesystem",
srcs = [
"posix_filesystem.cc",
"posix_filesystem_helper.cc",
"posix_filesystem_helper.h",
"copy_file.h",
] + select({
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
"//conditions:default": ["copy_file_portable.cc"],
}),
name = "posix_filesystem_impl",
srcs = ["posix_filesystem.cc"],
deps = [
":posix_filesystem_helper",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
],
)
# Library implementing helper functionality, so that the above only contains
# the API implementation for modular filesystems.
cc_library(
name = "posix_filesystem_helper",
srcs = ["posix_filesystem_helper.cc"],
hdrs = ["posix_filesystem_helper.h"],
deps = [":copy_file"],
)
# On Linux, we can copy files faster using `sendfile`. But not elsewhere.
# Hence, this private library to select which implementation to use.
cc_library(
name = "copy_file",
srcs = select({
"//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
"//conditions:default": ["copy_file_portable.cc"],
}),
hdrs = ["copy_file.h"],
)

View File

@ -181,7 +181,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
return;
}
const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
TF_Tensor* result =
::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
if (TF_GetCode(status) == TF_OK) {
*tensor = result;
}

View File

@ -170,6 +170,11 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
}
// --------------------------------------------------------------------------
void StringEncode(const char* src, size_t src_len, char* dst) {
dst = tensorflow::core::EncodeVarint64(dst, src_len);
memcpy(dst, src, src_len);
}
size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
size_t dst_len, TF_Status* status) {
const size_t sz = TF_StringEncodedSize(src_len);
@ -185,8 +190,7 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
src_len, "-byte string"));
return 0;
}
dst = tensorflow::core::EncodeVarint64(dst, src_len);
memcpy(dst, src, src_len);
StringEncode(src, src_len, dst);
return sz;
}
@ -245,13 +249,11 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype,
namespace tensorflow {
// Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
*status = tensorflow::Status::OK();
if (!src.IsInitialized()) {
Set_TF_Status_from_Status(
status, FailedPrecondition(
"attempt to use a tensor with an uninitialized value"));
*status = FailedPrecondition(
"attempt to use a tensor with an uninitialized value");
return nullptr;
}
if (src.NumElements() == 0) {
@ -259,14 +261,13 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
}
if (src.dtype() == tensorflow::DT_RESOURCE) {
if (src.shape().dims() != 0) {
Set_TF_Status_from_Status(
status, InvalidArgument(
*status = InvalidArgument(
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
src.shape().DebugString(),
"). Please file a bug at "
"https://github.com/tensorflow/tensorflow/issues/new, "
"ideally with a "
"short code snippet that reproduces this error."));
"short code snippet that reproduces this error.");
return nullptr;
}
const string str =
@ -305,23 +306,15 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
*offsets = (dst - data_start);
offsets++;
const string& s = srcarray(i);
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
if (TF_GetCode(status) != TF_OK) {
Set_TF_Status_from_Status(
status,
InvalidArgument("invalid string tensor encoding (string #", i, " of ",
srcarray.size(), "): ", TF_Message(status)));
delete[] base;
return nullptr;
}
const size_t consumed = TF_StringEncodedSize(s.size());
StringEncode(s.data(), s.size(), dst);
dst += consumed;
dst_len -= consumed;
}
if (dst != base + size) {
Set_TF_Status_from_Status(
status, InvalidArgument(
*status = InvalidArgument(
"invalid string tensor encoding (decoded ", (dst - base),
" bytes, but the tensor is encoded in ", size, " bytes"));
" bytes, but the tensor is encoded in ", size, " bytes");
delete[] base;
return nullptr;
}

View File

@ -259,6 +259,9 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
RunTest(x, x_init_value, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -271,6 +274,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
SetRandomValuesForMaxPooling<float>(&x_init_value);
RunTest(x, x_init_value, y, y_shape);
}
#endif
TEST_F(NNGradTest, AvgPoolGradHelper) {
TensorShape x_shape({1, 2, 2, 1});
@ -283,6 +287,9 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
RunTest(x, x_shape, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -293,6 +300,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
RunTest(x, x_shape, y, y_shape);
}
#endif
TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1});

View File

@ -75,8 +75,8 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
"@llvm//:support", # fixdeps: keep
"@llvm//:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:support", # fixdeps: keep
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
],
)
@ -104,11 +104,11 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm//:aarch64_code_gen", # fixdeps: keep
"@llvm//:arm_code_gen", # fixdeps: keep
"@llvm//:powerpc_code_gen", # fixdeps: keep
"@llvm//:target",
"@llvm//:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
],
)
@ -205,9 +205,9 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
"@llvm-project//llvm:core",
"@llvm-project//llvm:support",
"@llvm-project//llvm:target",
],
)

View File

@ -407,6 +407,7 @@ def target_llvm_triple():
"//tensorflow:android_arm64": "aarch64-none-android",
"//tensorflow:android_x86": "i686-none-android",
"//tensorflow:ios": "arm64-none-ios",
"//tensorflow:ios_x86_64": "x86_64-apple-ios",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:macos": "x86_64-none-darwin",
"//conditions:default": "x86_64-pc-linux",

View File

@ -4,12 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilati
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
package(
default_visibility = [
":internal",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
],
default_visibility = [":internal"],
licenses = ["notice"], # Apache 2.0
)
@ -500,6 +495,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/hash/hash.h"

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"

View File

@ -2130,6 +2130,53 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
return Status::OK();
}
Status CopyOutsideCompilationConstNodes(
Graph* g, const string& outside_compilation_attr_name) {
for (Node* n : g->op_nodes()) {
if (!n->IsConstant() ||
!HasNodeAttr(n->def(), outside_compilation_attr_name)) {
continue;
}
std::vector<const Edge*> out_edges(n->out_edges().begin(),
n->out_edges().end());
bool has_non_oc_output = false;
for (const Edge* e : out_edges) {
if (!e->IsControlEdge() &&
!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
has_non_oc_output = true;
break;
}
}
if (!has_non_oc_output) {
continue;
}
NodeDef copy_def = n->def();
copy_def.set_name(g->NewName(n->name()));
copy_def.mutable_attr()->erase(outside_compilation_attr_name);
Status s;
Node* copy_node = g->AddNode(copy_def, &s);
TF_RETURN_IF_ERROR(s);
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
g->AddControlEdge(e->src(), copy_node);
}
}
for (const Edge* e : out_edges) {
if (!e->IsControlEdge() &&
!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
Node* dst = e->dst();
int dst_input = e->dst_input();
g->RemoveEdge(e);
g->AddEdge(copy_node, 0, dst, dst_input);
}
}
}
return Status::OK();
}
} // namespace
Status RewriteOutsideCompilationSubgraphFn::operator()(
@ -2279,6 +2326,10 @@ Status ExtractOutsideCompilationForFunction(
std::vector<string> outside_compilation_host_graphs;
std::vector<string> shape_inference_graphs_to_rewrite;
if (*has_outside_compilation) {
// Copy outside compilation Const nodes with non outside compilation users.
TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
fbody->graph, outside_compilation_attr_name));
// Find dependencies between outside compilation clusters.
TF_ASSIGN_OR_RETURN(auto cluster_deps,
OutsideCompilationClusterDependencies(

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/node_matchers.h"
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph_node_util.h"
namespace tensorflow {
namespace testing {

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

View File

@ -17,7 +17,10 @@ limitations under the License.
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/util/dump_graph.h"
@ -39,7 +42,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}
Status PropagateShapes(const Graph& graph,
Status PropagateShapes(Graph* graph,
const std::map<int, InferredShape>& arg_shapes,
const std::vector<BackEdgeHelper::BackEdge>& back_edges,
ShapeRefiner* shape_refiner) {
@ -54,7 +57,7 @@ Status PropagateShapes(const Graph& graph,
// shapes.
// TODO(phawkins): handle cyclic graphs.
std::vector<Node*> order;
GetReversePostOrder(graph, &order);
GetReversePostOrder(*graph, &order);
for (Node* n : order) {
// Ignore the status returned by the shape_refiner. We want the best effort
@ -99,6 +102,67 @@ Status PropagateShapes(const Graph& graph,
}
}
// Sometimes we have VariableShape nodes in while loop (after Enter nodes).
// They won't be constant-folded because TensorFlow constant folding does
// not handle Enter nodes (and thus does not handle any nodes after Enter
// nodes). We try to replace such VariableShape nodes with Const nodes here.
if (n->type_string() == "VariableShape") {
shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
shape_inference::ShapeHandle handle =
handle_shapes_and_types->at(0).shape;
TensorShapeProto shape_proto;
context->ShapeHandleToProto(handle, &shape_proto);
if (!shape_proto.unknown_rank()) {
NodeDef const_def;
const_def.set_op("Const");
Node* var_node;
TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
const_def.set_name(
graph->NewName(absl::StrCat("var_shape_", var_node->name())));
DataType dtype = n->output_type(0);
AddNodeAttr("dtype", dtype, &const_def);
TensorProto value;
value.set_dtype(dtype);
value.mutable_tensor_shape()->add_dim()->set_size(
shape_proto.dim_size());
for (const auto& dim : shape_proto.dim()) {
if (dtype == DT_INT32) {
value.add_int_val(dim.size());
} else {
value.add_int64_val(dim.size());
}
}
AddNodeAttr("value", value, &const_def);
for (auto const& attr : n->attrs()) {
if (*attr.first.begin() == '_') {
AddNodeAttr(attr.first, attr.second, &const_def);
}
}
Status s;
Node* const_node = graph->AddNode(const_def, &s);
TF_RETURN_IF_ERROR(s);
graph->AddControlEdge(var_node, const_node);
std::vector<const Edge*> out_edges(n->out_edges().begin(),
n->out_edges().end());
for (const Edge* e : out_edges) {
if (e->IsControlEdge()) {
graph->AddControlEdge(const_node, e->dst());
graph->RemoveEdge(e);
} else {
Node* dst = e->dst();
int dst_input = e->dst_input();
graph->RemoveEdge(e);
graph->AddEdge(const_node, 0, dst, dst_input);
}
}
}
}
}
// Merge node causes a loop so we remove NextIteration->Merge edge before
// performing shape inference. But removing those edges also prevents us
// from inferring output shape for Merge node (we need shapes for all its
@ -196,7 +260,7 @@ Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
// the shape inference is complete.
BackEdgeHelper back_edge;
TF_RETURN_IF_ERROR(back_edge.Remove(graph));
TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes,
TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
back_edge.RemovedEdges(), &shape_refiner));
TF_RETURN_IF_ERROR(back_edge.Replace());

View File

@ -191,7 +191,7 @@ class XlaAssignVariableOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
data::IteratorGetNextAsOptionalOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \
data::IteratorGetNextSyncOp); \
data::IteratorGetNextOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \
.Device(DEVICE) \
.HostMemory("string_handle"), \

View File

@ -6,7 +6,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
package(
default_visibility = [
"//tensorflow/compiler/tf2xla:__subpackages__",
"@local_config_mlir//:friends",
"@llvm-project//mlir:friends",
],
licenses = ["notice"], # Apache 2.0
)
@ -30,8 +30,8 @@ cc_library(
hdrs = ["op_or_arg_name_mapper.h"],
deps = [
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)
@ -43,11 +43,14 @@ cc_library(
":passes",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm//:support",
"@local_config_mlir//:MlirOptLib",
"@local_config_mlir//:Pass",
"@local_config_mlir//:Support",
"@local_config_mlir//test:TestTransforms",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AffineDialectRegistration",
"@llvm-project//mlir:LoopDialectRegistration",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir/test:TestTransforms",
],
)
@ -80,9 +83,10 @@ cc_library(
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower",
"@local_config_mlir//:AffineDialectRegistration",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
"//tensorflow/compiler/mlir/xla:xla_test_passes",
"@llvm-project//mlir:AffineOps",
"@llvm-project//mlir:QuantOps",
],
)
@ -92,7 +96,7 @@ cc_library(
hdrs = ["init_mlir.h"],
deps = [
"//tensorflow/core:lib",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)
@ -122,11 +126,11 @@ tf_cc_binary(
"//tensorflow/core:tensorflow",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
"@local_config_mlir//:TranslateClParser",
"@local_config_mlir//:Translation",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TranslateClParser",
"@llvm-project//mlir:Translation",
],
)

View File

@ -1,11 +1,11 @@
# MLIR dialects and utilities for TensorFlow, TensorFlow Lite and XLA.
This module contains the MLIR
([Multi-Level Intermediate Representation](https://github.com/tensorflow/mlir))
([Multi-Level Intermediate Representation](https://mlir.llvm.org))
dialects and utilities for
1. TensorFlow
2. XLA
3. TF Lite
See [MLIR repo](https://github.com/tensorflow/mlir) for complete documentation.
See [MLIR's website](https://mlir.llvm.org) for complete documentation.

View File

@ -10,7 +10,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths")
# Default values used by the test runner.
_default_test_file_exts = ["mlir", ".pbtxt", ".td"]
_default_driver = "@local_config_mlir//:run_lit.sh"
_default_driver = "@llvm-project//mlir:run_lit.sh"
_default_size = "small"
_default_tags = ["no_rocm"]
@ -50,16 +50,16 @@ def _run_lit_test(name, data, size, tags, driver, features):
native.py_test(
name = name,
srcs = ["@llvm//:lit"],
srcs = ["@llvm-project//llvm:lit"],
tags = tags,
args = [
"tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
] + features,
data = data + [
"//tensorflow/compiler/mlir:litfiles",
"@llvm//:FileCheck",
"@llvm//:count",
"@llvm//:not",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:count",
"@llvm-project//llvm:not",
],
size = size,
main = "lit.py",

View File

@ -1,6 +1,6 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
load(
"@local_config_mlir//:tblgen.bzl",
"//third_party/mlir:tblgen.bzl",
"gentbl",
)
@ -15,7 +15,7 @@ package(
package_group(
name = "friends",
includes = ["@local_config_mlir//:subpackages"],
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//learning/brain/google/xla/...",
@ -28,7 +28,7 @@ filegroup(
srcs = [
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@local_config_mlir//:OpBaseTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
],
)
@ -48,7 +48,7 @@ gentbl(
"g3doc/tfl_ops.md",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfl_ops.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
@ -63,11 +63,11 @@ gentbl(
"transforms/generated_prepare_tf.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/prepare_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_optimize_td_files",
],
@ -81,11 +81,11 @@ gentbl(
"transforms/generated_lower_static_tensor_list.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/tensorlist_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
@ -98,11 +98,11 @@ gentbl(
"transforms/generated_legalize_tf.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/legalize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
@ -115,11 +115,11 @@ gentbl(
"transforms/generated_optimize.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/optimize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
@ -132,11 +132,11 @@ gentbl(
"transforms/generated_quantize.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/quantize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
],
)
@ -148,11 +148,11 @@ gentbl(
"transforms/generated_post_quantize.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/post_quantize_patterns.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
],
)
@ -165,9 +165,9 @@ cc_library(
"utils/validators.h",
],
deps = [
"@local_config_mlir//:Dialect",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
)
@ -185,21 +185,21 @@ cc_library(
"transforms/passes.h",
"utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
"@local_config_mlir//:include/mlir/Transforms/InliningUtils.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:Dialect",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
@ -216,10 +216,10 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
@ -233,9 +233,9 @@ cc_library(
],
deps = [
":tensorflow_lite",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
)
@ -248,10 +248,10 @@ tf_cc_test(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
@ -292,14 +292,14 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:Transforms",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)
@ -317,12 +317,12 @@ cc_library(
":tensorflow_lite",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
@ -330,6 +330,7 @@ cc_library(
cc_library(
name = "tensorflow_lite_quantize",
srcs = [
"transforms/default_quant_params.cc",
"transforms/generated_post_quantize.inc",
"transforms/generated_quantize.inc",
"transforms/load_quantization_recipe.cc",
@ -348,13 +349,13 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
@ -376,7 +377,7 @@ genrule(
"utils/generated_op_quant_spec_getters.inc",
],
cmd = ("$(location //tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen) " +
"-I external/local_config_mlir/include " +
"-I external/llvm-project/mlir/include " +
"-I external/org_tensorflow " +
"$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"),
tools = ["//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen"],
@ -390,7 +391,7 @@ cc_library(
],
deps = [
":tensorflow_lite",
"@local_config_mlir//:IR",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
)
@ -401,9 +402,9 @@ tf_native_cc_binary(
"operator_converter_gen.cc",
],
deps = [
"@llvm//:support",
"@llvm//:tablegen",
"@local_config_mlir//:TableGen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//mlir:TableGen",
],
)
@ -436,12 +437,17 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:TransformUtils",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils",
],
)
@ -464,7 +470,7 @@ cc_library(
],
deps = [
"//tensorflow/lite/core/api",
"@local_config_mlir//:IR",
"@llvm-project//mlir:IR",
],
)
@ -501,6 +507,7 @@ cc_library(
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/versioning:op_version",
"@com_google_absl//absl/base",
@ -509,14 +516,14 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"@local_config_mlir//:StandardDialectRegistration",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:Translation",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:StandardDialectRegistration",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
],
alwayslink = 1,
)
@ -525,7 +532,7 @@ tf_cc_binary(
name = "flatbuffer_translate",
deps = [
":flatbuffer_translate_lib",
"@local_config_mlir//:MlirTranslateMain",
"@llvm-project//mlir:MlirTranslateMain",
],
)
@ -538,7 +545,7 @@ cc_library(
"tf_tfl_translate_cl.h",
],
deps = [
"@llvm//:support",
"@llvm-project//llvm:support",
],
alwayslink = 1,
)
@ -550,7 +557,7 @@ cc_library(
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)
@ -578,9 +585,9 @@ tf_cc_binary(
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
@ -596,10 +603,10 @@ tf_cc_binary(
"//tensorflow/lite/kernels:builtin_ops",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
],
)
@ -622,12 +629,12 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"@local_config_mlir//:Transforms",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Transforms",
],
)
@ -654,15 +661,15 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",
"@local_config_mlir//:Support",
"@local_config_mlir//:Transforms",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <cstdarg>
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:llvm-project
#include "tensorflow/lite/core/api/error_reporter.h"
namespace tflite {

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
@ -43,24 +44,24 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Translation.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Translation.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
@ -103,12 +104,26 @@ using llvm::cl::opt;
// Commandline flag to enable the control of flatbuffer import.
bool use_external_constant;
// Commandline flag to enable graph pruning.
bool experimental_prune_unreachable_nodes_unconditionally;
// NOLINTNEXTLINE
static opt<bool, true> use_external_constant_flag(
"use-external-constant",
llvm::cl::desc("Use external constant during flatbuffer import"),
llvm::cl::location(use_external_constant), llvm::cl::init(false));
// TODO(b/147111261): After the importer supports generic custom ops, we should
// change the flag to a more lightwise flag, e.g.
// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune
// the operations.
// NOLINTNEXTLINE
static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
"experimental-prune-unreachable-nodes-unconditionally",
llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
llvm::cl::init(false));
namespace {
bool IsScalar(const TensorT& tensor) {
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
@ -217,7 +232,7 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
// min/max stats is just for comments, so ignore it.
if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
// If the result isn't float and unquantizable, the min/max is ignored.
if (!res->getType()
if (!res.getType()
.cast<mlir::ShapedType>()
.getElementType()
.isa<mlir::FloatType>()) {
@ -255,10 +270,23 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
}
StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
// TODO(krzysd) Support custom ops
// TODO(b/143872630): Support custom ops
if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
return errors::Unimplemented("unsupported custom operation: ",
opcode.custom_code);
// Adding some custom op supported on GPU.
const absl::string_view custom_name = opcode.custom_code;
if (custom_name == "MaxPoolingWithArgmax2D") {
return std::string("tfl.max_pooling_with_argmax_2d");
}
if (custom_name == "Convolution2DTransposeBias") {
return std::string("tfl.convolution_2d_transpose_bias");
}
if (custom_name == "MaxUnpooling2D") {
return std::string("tfl.max_unpooling_2d");
}
// Use an unsupported op name instead of throwing an error here in case the
// op is pruned during the import.
return std::string(
llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str());
}
if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
return std::string("tf.If");
@ -495,6 +523,13 @@ bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
}
}
// Returns true if this is a custom op.
bool IsCustomOp(const std::string& op_name) {
return op_name == "tfl.max_pooling_with_argmax_2d" ||
op_name == "tfl.max_unpooling_2d" ||
op_name == "tfl.convolution_2d_transpose_bias";
}
// TODO(krzysd) Handle function calls
StatusOr<Operation*> ConvertOp(
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
@ -557,7 +592,15 @@ StatusOr<Operation*> ConvertOp(
}
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
if (IsCustomOp(op_name)) {
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
builder, loc, &attrs);
if (!status.ok()) {
return emitError(loc, status.ToString()), status;
}
} else {
mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
}
op_state.addAttributes(attrs);
// Handle the conversion from subgraph index to functions for If and While
@ -619,6 +662,49 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}
// Given a list of output indices, traverses the subgraph and returns the set of
// ops that are ancestors of the output tensors.
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
// Create a map from tensor index to defining op.
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
for (const auto& op : subgraph.operators) {
for (int32_t output : op->outputs) {
defining_op[output] = op.get();
}
}
std::vector<const tflite::OperatorT*> queue;
for (int32_t output : output_indices) {
if (auto& op = defining_op[output]) {
queue.push_back(op);
} else {
return errors::InvalidArgument("Output tensor doesn't have defining op");
}
}
// Traverse the graph towards inputs.
absl::flat_hash_set<const tflite::OperatorT*> visited;
while (!queue.empty()) {
const tflite::OperatorT* op = queue.back();
queue.pop_back();
if (!visited.insert(op).second) {
// The node has already been visited.
continue;
}
for (int32_t input : op->inputs) {
// Input tensor may not have a defining op in case it is a subgraph input
// or a constant tensor.
if (auto& op = defining_op[input]) {
queue.push_back(op);
}
}
}
return visited;
}
// Build a FuncOp from a tflite SubGraph
// The op_names are a mapping from indexes into the TFLite operators array to
// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
@ -635,7 +721,8 @@ StatusOr<FuncOp> ConvertSubgraph(
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
Location base_loc, Builder builder,
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
bool use_external_constant) {
bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
llvm::SmallVector<mlir::Type, 2> ret_types;
llvm::SmallVector<mlir::Type, 4> input_types;
@ -731,8 +818,19 @@ StatusOr<FuncOp> ConvertSubgraph(
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
}
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
if (experimental_prune_unreachable_nodes_unconditionally) {
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
PruneSubgraph(subgraph, func_outputs));
}
// Construct MLIR operators from TFLite operators
for (auto& op : subgraph.operators) {
if (experimental_prune_unreachable_nodes_unconditionally &&
!pruned_subgraph_ops.contains(op)) {
continue;
}
for (auto input_num : op->inputs) {
// The operators in a graph are topologically sorted
// and so if no previous operation has produced a tensor
@ -837,7 +935,8 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
OwningModuleRef tflite::FlatBufferToMlir(
absl::string_view buffer, MLIRContext* context, Location base_loc,
const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant) {
bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
auto model_ptr =
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
if (nullptr == model_ptr) {
@ -892,7 +991,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
// TODO(b/131175224,b/132239787) Support multiple entry points
builder, ordered_output_arrays,
/*is_entry_point=*/e.index() == 0,
/*use_external_constant=*/use_external_constant);
/*use_external_constant=*/use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
if (!func_or_error.ok()) {
return emitError(base_loc, "could not translate function ")
<< subgraph->name,
@ -905,9 +1005,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
return OwningModuleRef(module);
}
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
MLIRContext* context,
bool use_external_constant) {
static OwningModuleRef FlatBufferFileToMlirTrans(
llvm::SourceMgr* source_mgr, MLIRContext* context,
bool use_external_constant,
bool experimental_prune_unreachable_nodes_unconditionally) {
const llvm::MemoryBuffer* input =
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
std::string error;
@ -924,12 +1025,14 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
return tflite::FlatBufferToMlir(
absl::string_view(input->getBufferStart(), input->getBufferSize()),
context, loc, outputs, use_external_constant);
context, loc, outputs, use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
}
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
"tflite-flatbuffer-to-mlir",
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
return FlatBufferFileToMlirTrans(&source_mgr, context,
use_external_constant);
return FlatBufferFileToMlirTrans(
&source_mgr, context, use_external_constant,
experimental_prune_unreachable_nodes_unconditionally);
});

View File

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
#include "absl/strings/string_view.h"
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
namespace tflite {
// Converts a TFLite flatbuffer stored in `buffer` to a MLIR module
@ -31,11 +31,14 @@ namespace tflite {
// on failure, and more specific errors will be emitted via the context.
// If `use_external_constant` is true, it will create `tfl.external_const`
// instead of `tfl.const`.
// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes that
// are not ancestors of the output nodes will be pruned.
mlir::OwningModuleRef FlatBufferToMlir(
absl::string_view buffer, mlir::MLIRContext* context,
mlir::Location base_loc,
const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant = false);
bool use_external_constant = false,
bool experimental_prune_unreachable_nodes_unconditionally = false);
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_

View File

@ -17,15 +17,45 @@ limitations under the License.
#include <vector>
#include "absl/strings/str_cat.h"
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace {
using ::tensorflow::Status;
using ::tensorflow::errors::InvalidArgument;
using ::xla::StatusOr;
StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
mlir::Builder builder,
mlir::Location loc) {
auto padding = tflite::Padding::Padding_VALID;
if (pad_params == TfLitePadding::kTfLitePaddingSame) {
padding = tflite::Padding_SAME;
} else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
padding = tflite::Padding_VALID;
} else {
return InvalidArgument(
absl::StrCat("Invalid padding type", std::to_string(pad_params)));
}
const char* option_name = tflite::EnumNamePadding(padding);
return builder.getStringAttr(option_name);
}
} // namespace
// TODO(jpienaar): This is a placeholder. This should be done in more efficient
// way when part of the translation of module.
static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
@ -212,5 +242,44 @@ static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
return builder.getStringAttr(option_name);
}
Status mlir::CustomOptionsToAttributes(
const std::string& op_name, const std::vector<uint8_t>& custom_options,
mlir::Builder builder, mlir::Location loc,
llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
if (op_name == "tfl.max_pooling_with_argmax_2d" ||
op_name == "tfl.max_unpooling_2d") {
auto* pool_params =
reinterpret_cast<const TfLitePoolParams*>(custom_options.data());
TF_ASSIGN_OR_RETURN(auto padding_attribute,
GetPaddingAttr(pool_params->padding, builder, loc));
attributes->emplace_back(
builder.getNamedAttr("padding", padding_attribute));
attributes->emplace_back(builder.getNamedAttr(
"stride_h", builder.getI32IntegerAttr(pool_params->stride_height)));
attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
attributes->emplace_back(builder.getNamedAttr(
"filter_h", builder.getI32IntegerAttr(pool_params->filter_height)));
attributes->emplace_back(builder.getNamedAttr(
"filter_w", builder.getI32IntegerAttr(pool_params->filter_width)));
return Status::OK();
} else if (op_name == "tfl.convolution_2d_transpose_bias") {
auto* conv_params = reinterpret_cast<const TfLiteTransposeConvParams*>(
custom_options.data());
TF_ASSIGN_OR_RETURN(auto padding_attribute,
GetPaddingAttr(conv_params->padding, builder, loc));
attributes->emplace_back(
builder.getNamedAttr("padding", padding_attribute));
attributes->emplace_back(builder.getNamedAttr(
"stride_h", builder.getI32IntegerAttr(conv_params->stride_height)));
attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(conv_params->stride_width)));
return Status::OK();
}
return InvalidArgument(absl::StrCat("invalid custom op type: ", op_name));
}
// Pull in FlatBuffer writers for TFLite generated using TableGen
#include "tensorflow/compiler/mlir/lite/operator_converters.inc"

View File

@ -26,9 +26,10 @@ limitations under the License.
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {
@ -45,7 +46,7 @@ llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
flatbuffers::FlatBufferBuilder *fbb);
// Populate the array of mlir::NamedAttributes corresponding to the given
// Populates the array of mlir::NamedAttributes corresponding to the given
// tflite::FlatbufferOptionsUnion.
// We use an out parameter per LLVM convention
void BuiltinOptionsToAttributes(
@ -53,6 +54,15 @@ void BuiltinOptionsToAttributes(
// NOLINTNEXTLINE
llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes);
// Populates the array of mlir::NamedAttributes corresponding to the given
// custom_options.
// We use an out parameter per LLVM convention
tensorflow::Status CustomOptionsToAttributes(
const std::string &op_name, const std::vector<uint8_t> &custom_options,
mlir::Builder builder,
// NOLINTNEXTLINE
Location loc, llvm::SmallVectorImpl<mlir::NamedAttribute> *attributes);
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_

View File

@ -41,19 +41,19 @@ limitations under the License.
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Translation.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Translation.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
@ -71,6 +71,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
@ -218,6 +219,13 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
}
case mlir::TF::TensorFlowTypes::RESOURCE: {
// Treat tf.resource values as integer values in flatbuffer.
// TODO(b/146131919): Maybe need to have a detailed design for supporting
// other resource types beyonds hash table resources and resource
// variables.
return tflite::TensorType_INT32;
}
default:
// TFLite export fills FLOAT32 for unknown data types. Returning an error
// for now for safety and this could be revisited when required.
@ -233,17 +241,17 @@ static bool IsConst(Operation* op) {
template <typename T>
static bool HasValidTFLiteType(Value value, T& error_handler) {
// None type is allowed to represent unspecified operands.
if (value->getType().isa<NoneType>()) return true;
if (value.getType().isa<NoneType>()) return true;
auto type = value->getType().dyn_cast<TensorType>();
auto type = value.getType().dyn_cast<TensorType>();
if (!type) {
if (auto op = value->getDefiningOp()) {
if (auto op = value.getDefiningOp()) {
error_handler.emitError()
<< '\'' << op << "' should produce value of tensor type instead of "
<< value->getType();
<< value.getType();
return false;
}
error_handler.emitError("expected tensor type, got ") << value->getType();
error_handler.emitError("expected tensor type, got ") << value.getType();
return false;
}
@ -282,7 +290,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
for (auto arg : bb.getArguments()) {
if (!HasValidTFLiteType(arg, fn))
return fn.emitError("invalid TFLite type: ") << arg->getType(), false;
return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
}
// Verify that all operations except the terminator have exactly one
@ -292,7 +300,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
for (auto result : inst.getResults()) {
if (!HasValidTFLiteType(result, inst))
return fn.emitError("invalid TFLite type: ") << result->getType(),
return fn.emitError("invalid TFLite type: ") << result.getType(),
false;
}
}
@ -317,6 +325,48 @@ static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
return std::move(status_or_node_def.ValueOrDie());
}
// Converts a mlir padding StringRef to TfLitePadding.
// Returns llvm::None if conversion fails.
static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
llvm::StringRef padding) {
const tflite::Padding padding_attr =
std::move(llvm::StringSwitch<tflite::Padding>(padding)
.Case("SAME", tflite::Padding_SAME)
.Case("VALID", tflite::Padding_VALID));
if (padding_attr == tflite::Padding_SAME) {
return kTfLitePaddingSame;
}
if (padding_attr == tflite::Padding_VALID) {
return kTfLitePaddingValid;
}
return inst->emitOpError() << "Invalid padding attribute: " << padding,
llvm::None;
}
// Extracts TfLitePoolParams from a TFL custom op.
// Template parameter, TFLOp, should be a TFL custom op containing attributes
// generated from TfLitePoolParams.
// Returns llvm::None if conversion fails.
template <typename TFLOp>
static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
TFLOp op) {
TfLitePoolParams pool_params;
pool_params.stride_height = op.stride_h().getSExtValue();
pool_params.stride_width = op.stride_w().getSExtValue();
pool_params.filter_height = op.filter_h().getSExtValue();
pool_params.filter_width = op.filter_w().getSExtValue();
const auto padding = GetTflitePadding(inst, op.padding());
if (padding) {
pool_params.padding = *padding;
pool_params.activation = kTfLiteActNone;
pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
return pool_params;
}
return llvm::None;
}
namespace {
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
@ -375,9 +425,31 @@ class Translator {
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Builds custom operators.
// Templated on a) data type of custom_option to be stored into flatbuffer,
// and b) TFL custom op type.
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> BuildCustomOperator(
const CustomOptionType& custom_option, const std::string& opcode_name,
TFLOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>>
BuildConvolution2DTransposeBiasOperator(
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>> BuildMaxPoolingWithArgMax2DOperator(
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>> BuildMaxUnpooling2DOperator(
Operation* inst, mlir::TFL::MaxUnpooling2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
@ -504,7 +576,7 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
Value value, const std::string& name, unsigned buffer_idx) {
auto type = value->getType().cast<TensorType>();
auto type = value.getType().cast<TensorType>();
// TFLite requires tensor shape only for the inputs and constants.
// However, we output all known shapes for better round-tripping
@ -516,7 +588,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
return mlir::emitError(
value->getLoc(),
value.getLoc(),
"result shape dimensions out of 32 bit int type range");
return mlir::success();
@ -528,7 +600,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
} else if (auto* inst = value->getDefiningOp()) {
} else if (auto* inst = value.getDefiningOp()) {
if (IsConst(inst)) {
// Const op can have a result of dynamic shaped type (e.g. due to constant
// folding), but we can still derive the shape of a constant tensor for
@ -571,7 +643,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
// marked as a stateful. If so, set the tensor's is_variable as true
// This is v1 ref variable semantics in the TFLite runtime.
bool is_variable = false;
for (auto& use : value->getUses()) {
for (auto& use : value.getUses()) {
is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
if (is_variable) {
break;
@ -615,19 +687,72 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
builtin_options);
}
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
const CustomOptionType& custom_option, const std::string& opcode_name,
TFLOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
std::vector<uint8_t> custom_option_vector(sizeof(CustomOptionType));
memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType));
auto opcode_index =
GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM);
return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
/*builtin_options=*/0,
builder_.CreateVector<uint8_t>(custom_option_vector),
tflite::CustomOptionsFormat_FLEXBUFFERS);
}
BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
float tolerance = op.tolerance().convertToFloat();
std::vector<uint8_t> custom_options(sizeof(float));
memcpy(custom_options.data(), &tolerance, sizeof(float));
auto opcode_index =
GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
/*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_options),
tflite::CustomOptionsFormat_FLEXBUFFERS);
return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results);
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildConvolution2DTransposeBiasOperator(
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
TfLiteTransposeConvParams conv_params;
conv_params.stride_height = op.stride_h().getSExtValue();
conv_params.stride_width = op.stride_w().getSExtValue();
const auto padding = GetTflitePadding(inst, op.padding());
if (padding) {
conv_params.padding = *padding;
return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op,
operands, results);
}
return llvm::None;
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxPoolingWithArgMax2DOperator(
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
const auto pool_params = GetTflitePoolParams(inst, op);
if (pool_params) {
return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op,
operands, results);
}
return llvm::None;
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxUnpooling2DOperator(Operation* inst,
mlir::TFL::MaxUnpooling2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
const auto pool_params = GetTflitePoolParams(inst, op);
if (pool_params) {
return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands,
results);
}
return llvm::None;
}
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
@ -769,6 +894,20 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
return BuildNumericVerifyOperator(verify_op, operands, results);
}
if (auto conv_transpose_bias_op =
dyn_cast<mlir::TFL::Convolution2DTransposeBiasOp>(inst)) {
return BuildConvolution2DTransposeBiasOperator(
inst, conv_transpose_bias_op, operands, results);
}
if (auto max_pooling_with_arg_max_op =
dyn_cast<mlir::TFL::MaxPoolingWithArgMax2DOp>(inst)) {
return BuildMaxPoolingWithArgMax2DOperator(
inst, max_pooling_with_arg_max_op, operands, results);
}
if (auto max_unpooling_op = dyn_cast<mlir::TFL::MaxUnpooling2DOp>(inst)) {
return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
results);
}
inst->emitOpError("is not a supported TFLite op");
return llvm::None;
}
@ -923,7 +1062,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
// on failure.
auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
// NoneType represents optional and may be skipped here.
if (value->getType().isa<NoneType>()) {
if (value.getType().isa<NoneType>()) {
return true;
}
@ -936,7 +1075,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
// make the Buffer empty apart from setting the buffer_idx=0 in the Tensor.
// This does not seem to affect runtime behavior for RNN/LSTM, but would be
// good for reducing memory footprint.
if (auto* inst = value->getDefiningOp()) {
if (auto* inst = value.getDefiningOp()) {
auto buffer_or = BuildBuffer(inst);
if (!buffer_or) return false;
buffers_.push_back(*buffer_or);
@ -976,7 +1115,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
std::vector<int32_t> operands;
operands.reserve(inst.getNumOperands());
for (auto operand : inst.getOperands()) {
if (operand->getType().isa<NoneType>())
if (operand.getType().isa<NoneType>())
operands.push_back(kTfLiteOptionalTensor);
else
operands.push_back(tensor_index_map.lookup(operand));

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
namespace tflite {

View File

@ -25,17 +25,17 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
@ -304,11 +304,11 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
Value rhs) {
auto result_type =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!result_type)
emitError(result.location)
<< "non-broadcastable operands: " << lhs->getType() << " and "
<< rhs->getType();
<< "non-broadcastable operands: " << lhs.getType() << " and "
<< rhs.getType();
result.addOperands({lhs, rhs});
// Comparison binary ops always return i1 tensor.
if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
@ -324,12 +324,12 @@ void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
Value lhs, Value rhs,
StringAttr fused_activation_function) {
auto result_type =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!result_type)
emitError(result.location)
<< "non-broadcastable operands: " << lhs->getType() << " and "
<< rhs->getType();
<< "non-broadcastable operands: " << lhs.getType() << " and "
<< rhs.getType();
result.addOperands({lhs, rhs});
result.addAttribute("fused_activation_function", fused_activation_function);
@ -358,7 +358,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
namespace {
int64_t GetConcatenationOpAxis(ConcatenationOp op) {
auto output_type = op.output()->getType().cast<RankedTensorType>();
auto output_type = op.output().getType().cast<RankedTensorType>();
int64_t axis = op.axis().getSExtValue();
if (axis < 0) axis += output_type.getRank();
return axis;
@ -452,7 +452,7 @@ LogicalResult VerifyConcatenationOpTypes(Operation *op,
}
LogicalResult Verify(ConcatenationOp op) {
auto output_type = op.output()->getType().dyn_cast<RankedTensorType>();
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
// If the output type is unranked, there is nothing else to be verified.
if (!output_type) return success();
@ -463,7 +463,7 @@ LogicalResult Verify(ConcatenationOp op) {
SmallVector<TensorType, 4> operand_types;
for (Value operand : op.values())
operand_types.push_back(operand->getType().cast<TensorType>());
operand_types.push_back(operand.getType().cast<TensorType>());
return VerifyConcatenationOpTypes(op.getOperation(), output_type,
operand_types, axis);
@ -520,7 +520,7 @@ DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
if (fused_activation_function() == "NONE") {
if (auto output_type = output()->getType().dyn_cast<RankedTensorType>()) {
if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
const int64_t axis = GetConcatenationOpAxis(*this);
if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
return ConstFoldConcatenateOpDense(operands, output_type, axis);
@ -530,7 +530,7 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
// Remove all empty values.
SmallVector<Value, 4> non_empty_values;
for (Value value : this->values()) {
const auto shaped_type = value->getType().cast<ShapedType>();
const auto shaped_type = value.getType().cast<ShapedType>();
if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
continue;
}
@ -559,8 +559,8 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
LogicalResult Verify(FullyConnectedOp op) {
ShapedType input_type = op.input()->getType().cast<ShapedType>();
ShapedType filter_type = op.filter()->getType().cast<ShapedType>();
ShapedType input_type = op.input().getType().cast<ShapedType>();
ShapedType filter_type = op.filter().getType().cast<ShapedType>();
if (filter_type.hasRank() && filter_type.getRank() != 2) {
return op.emitOpError("expect 2d filter, got ") << filter_type;
}
@ -582,7 +582,7 @@ LogicalResult Verify(FullyConnectedOp op) {
// format.
if (op.weights_format() == "DEFAULT") {
ShapedType output_type =
(*op.output().begin())->getType().cast<ShapedType>();
(*op.output().begin()).getType().cast<ShapedType>();
if (!output_type.hasStaticShape()) {
return mlir::success();
}
@ -610,8 +610,8 @@ LogicalResult Verify(FullyConnectedOp op) {
static void BuildGatherOp(Builder *builder, OperationState &result,
Value params, Value indices, IntegerAttr axis) {
auto params_type = params->getType().cast<TensorType>();
auto indices_type = indices->getType().cast<TensorType>();
auto params_type = params.getType().cast<TensorType>();
auto indices_type = indices.getType().cast<TensorType>();
// If params/indices is unranked, then output is unranked.
if (!params_type.hasRank() || !indices_type.hasRank())
@ -705,7 +705,7 @@ static LogicalResult Verify(PackOp op) {
return op.emitOpError("input count should match 'values_count' attribute");
Value operand0 = op.getOperand(0);
auto input_type = operand0->getType().cast<ShapedType>();
auto input_type = operand0.getType().cast<ShapedType>();
// Check axis bounds.
if (input_type.hasRank()) {
@ -718,7 +718,7 @@ static LogicalResult Verify(PackOp op) {
// Make sure all inputs have the same shape and element type.
// TODO(rahulsp): Simplify once b/135032064 is fixed.
for (Value operand : op.getOperands()) {
auto other_type = operand->getType().cast<ShapedType>();
auto other_type = operand.getType().cast<ShapedType>();
if (input_type != other_type)
return op.emitOpError("operands should be of the same type. got ")
<< input_type << ", " << other_type;
@ -732,9 +732,9 @@ static LogicalResult Verify(PackOp op) {
//===----------------------------------------------------------------------===//
static LogicalResult Verify(PReluOp op) {
auto input_type = op.input()->getType().cast<ShapedType>();
auto alpha_type = op.alpha()->getType().cast<ShapedType>();
auto output_type = op.output()->getType().cast<ShapedType>();
auto input_type = op.input().getType().cast<ShapedType>();
auto alpha_type = op.alpha().getType().cast<ShapedType>();
auto output_type = op.output().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
if (input_type.getRank() != alpha_type.getRank() + 1) {
@ -783,13 +783,13 @@ struct RemoveAdjacentReshape : public RewritePattern {
PatternMatchResult match(Operation *op) const override {
auto thisOp = cast<ReshapeOp>(op);
auto prevOp = thisOp.getOperand(0)->getDefiningOp();
auto prevOp = thisOp.getOperand(0).getDefiningOp();
return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
auto thisOp = cast<ReshapeOp>(op);
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0)->getDefiningOp());
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
// Replace
// %1 = "tfl.reshape"(%0, %shape0)
@ -807,7 +807,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
// Remove identity reshape with both static result and input shape.
auto result_type = getType().cast<ShapedType>();
auto input_type = getOperand(0)->getType().cast<ShapedType>();
auto input_type = getOperand(0).getType().cast<ShapedType>();
if (result_type.hasStaticShape() && result_type == input_type) {
return getOperand(0);
}
@ -865,7 +865,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TFL::PackOp pack_op = cast<TFL::PackOp>(op);
Operation *first_input = pack_op.getOperand(0)->getDefiningOp();
Operation *first_input = pack_op.getOperand(0).getDefiningOp();
if (!first_input) return matchFailure();
auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
if (!input_unpack_op) return matchFailure();
@ -905,9 +905,9 @@ void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SliceOp op) {
auto input_type = op.input()->getType().cast<ShapedType>();
auto begin_type = op.begin()->getType().cast<ShapedType>();
auto size_type = op.size()->getType().cast<ShapedType>();
auto input_type = op.input().getType().cast<ShapedType>();
auto begin_type = op.begin().getType().cast<ShapedType>();
auto size_type = op.size().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
size_type.hasStaticShape()) {
if (input_type.getRank() != begin_type.getNumElements()) {
@ -995,7 +995,7 @@ static void BuildTopKOp(Builder *builder, OperationState &result, Value input,
// TODO(jpienaar): This should use a helper function.
const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
auto val_type = input->getType().cast<TensorType>();
auto val_type = input.getType().cast<TensorType>();
// If value is unranked, then so is results.
if (!val_type.hasRank())
return TFL::TopKV2Op::build(
@ -1035,7 +1035,7 @@ struct DropFakeQuant : public RewritePattern {
// If all the users of this op have valid "minmax" attributes, it is matched
// and can be removed.
auto fakeQuantOp = cast<FakeQuantOp>(op);
for (auto *operand : fakeQuantOp.getResult()->getUsers())
for (auto *operand : fakeQuantOp.getResult().getUsers())
if (!HasValidMinMaxAttribute(operand)) return matchFailure();
return matchSuccess();
@ -1102,7 +1102,7 @@ static LogicalResult VerifySplitOpOutputTypes(
for (int64_t i = 0; i < num_splits; ++i) {
auto expected_output_type = get_expected_output_type(i);
Value output = op->getResult(i);
auto output_type = output->getType().dyn_cast<RankedTensorType>();
auto output_type = output.getType().dyn_cast<RankedTensorType>();
if (!output_type || output_type != expected_output_type)
return op->emitOpError()
<< "output #" << i << " should be " << expected_output_type;
@ -1121,7 +1121,7 @@ static LogicalResult Verify(SplitOp op) {
if (!split_dim_opt) return success();
// If 'input' is not a ranked tensor, there are no other checks.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t split_dim = split_dim_opt.getValue();
@ -1157,7 +1157,7 @@ static LogicalResult Verify(SplitVOp op) {
if (!split_dim_opt) return success();
// If 'input' is not a ranked tensor, there are no other checks.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t split_dim = split_dim_opt.getValue();
@ -1177,8 +1177,7 @@ static LogicalResult Verify(SplitVOp op) {
return success();
if (size_splits_attr.getNumElements() != num_splits) {
auto size_splits_type =
op.size_splits()->getType().cast<RankedTensorType>();
auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
RankedTensorType expected_size_splits_type =
RankedTensorType::get({num_splits}, size_splits_type.getElementType());
return op.emitOpError("'size_splits' should be ")
@ -1414,7 +1413,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
}
// Also fold if `input` has a known rank.
auto input_type = input()->getType().cast<ShapedType>();
auto input_type = input().getType().cast<ShapedType>();
// Do not fold if rank is zero because the TFLite converter doesn't
// distinguish between unranked input and scalar input due to b/138865275.
// TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
@ -1445,18 +1444,18 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
static void BuildSelectV2Op(Builder *builder, OperationState &result,
Value cond, Value x, Value y) {
auto operand_type =
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
OpTrait::util::getBroadcastedType(x.getType(), y.getType());
if (!operand_type)
emitError(result.location) << "non-broadcastable operands: " << x->getType()
<< " and " << y->getType();
emitError(result.location) << "non-broadcastable operands: " << x.getType()
<< " and " << y.getType();
bool has_static_cond_shape = false;
bool has_static_operand_shape = false;
ArrayRef<int64_t> cond_shape;
ArrayRef<int64_t> operand_shape;
if (auto shaped_type = cond->getType().dyn_cast<ShapedType>()) {
if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
if (shaped_type.hasStaticShape()) {
has_static_cond_shape = true;
cond_shape = shaped_type.getShape();
@ -1474,12 +1473,12 @@ static void BuildSelectV2Op(Builder *builder, OperationState &result,
!OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
broadcastedShape)) {
emitError(result.location) << "non-broadcastable operands: " << operand_type
<< " and " << cond->getType();
<< " and " << cond.getType();
}
result.addOperands({cond, x, y});
auto elementType = x->getType().dyn_cast<ShapedType>().getElementType();
auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
if (has_static_cond_shape && has_static_operand_shape) {
result.types.push_back(
RankedTensorType::get(broadcastedShape, elementType));
@ -1571,9 +1570,8 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TransposeConvOp op) {
ShapedType output_type = op.output()->getType().cast<ShapedType>();
ShapedType output_shape_type =
op.output_shape()->getType().cast<ShapedType>();
ShapedType output_type = op.output().getType().cast<ShapedType>();
ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
if (output_type.getRank() != output_shape_type.getDimSize(0)) {
return op.emitOpError(llvm::formatv(
@ -1679,9 +1677,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
}
static LogicalResult Verify(TransposeOp op) {
auto input_type = op.x()->getType().cast<ShapedType>();
auto perm_type = op.perm()->getType().cast<ShapedType>();
auto output_type = op.y()->getType().cast<ShapedType>();
auto input_type = op.x().getType().cast<ShapedType>();
auto perm_type = op.perm().getType().cast<ShapedType>();
auto output_type = op.y().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
if (perm_type.getNumElements() != input_type.getRank()) {
return op.emitOpError(

View File

@ -18,15 +18,15 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/Traits.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Dialect.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

View File

@ -135,7 +135,7 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
//===----------------------------------------------------------------------===//
class TFL_OperandIsUnrankedPred<int n> :
CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;
CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
// TODO: Some of these could be generalized and/or moved to more general
// location.
@ -144,38 +144,38 @@ class TFL_OperandHasRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() == " # m>]>>;
").getType().cast<ShapedType>().getRank() == " # m>]>>;
// Returns true if the n-th operand is ranked and has rank dim.
class TFL_OperandHasKnownRank<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() == "
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() == "
# dim>]>;
// True if operand n is ranked and has a rank > dim.
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
# dim>]>;
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
TFL_OperandIsRankedAndHasDimPred<n, dim>,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
".getShape()[" # dim # " ] == " # size>]>;
// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
Or<[CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() >= " # m>]>>;
").getType().cast<ShapedType>().getRank() >= " # m>]>>;
class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
CPred<"$_op.getOperand(" # x #
")->getType().cast<ShapedType>().getRank() == "
").getType().cast<ShapedType>().getRank() == "
"$_op.getOperand(" # y #
")->getType().cast<ShapedType>().getShape()[0]">>;
").getType().cast<ShapedType>().getShape()[0]">>;
class TFL_Operand0DOr1ElementTensor<int x> :
PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element",
@ -195,7 +195,7 @@ class TFL_OperandHasRankLessThan<int n, int m> :
PredOpTrait<"operand " # n # " is maximum " # m # "-D",
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() <= " # m>]>>;
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
// This is a quantization-aware version of TCresVTEtIsSameAsOp
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
@ -227,7 +227,7 @@ def TFL_BroadcastableBinaryBuilder : OpBuilder<
"Builder *builder, OperationState &result, Value lhs, Value rhs",
[{
auto resultType =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!resultType)
mlir::emitError(result.location, "non-broadcastable operands");
result.addOperands({lhs, rhs});
@ -427,6 +427,33 @@ def TFL_TransposeConvOp:
let verifier = [{ return Verify(*this); }];
}
def TFL_Convolution2DTransposeBiasOp :
Op<TFL_Dialect, "convolution_2d_transpose_bias", [NoSideEffect]> {
let summary = " Transpose convolution with bias operator";
let description = [{
Performs transpose convolution operation on inputs,
with the option of adding a bias.
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the filter weight tensor
`inputs[2]`: optional: the bias tensor
}];
let arguments = (
ins AnyTensor:$input,
AnyTensor:$filter,
TFL_TensorOfOrNone<[AnyType]>:$bias,
TFL_PaddingAttr:$padding,
I32Attr:$stride_h,
I32Attr:$stride_w
);
let results = (outs AnyTensor:$output);
}
def TFL_AveragePool2DOp:
TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Average_pool_2d operator";
@ -471,7 +498,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
let hasOptions = 1;
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType().
return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
@ -500,7 +527,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
let hasOptions = 1;
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType().
return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
@ -1181,7 +1208,8 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
let builders = [TFL_BroadcastableBinaryBuilder];
}
def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
def TFL_GreaterOp : TFL_Op<"greater", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
let summary = "Greater operator";
let description = [{
@ -1194,6 +1222,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
let results = (outs AnyTensor:$output);
let builders = [TFL_ComparisonBinaryBuilder];
let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
@ -1260,7 +1290,8 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
let hasOptions = 0b1;
}
def TFL_LessOp : TFL_Op<"less", [NoSideEffect, NoQuantizableResult]> {
def TFL_LessOp : TFL_Op<"less", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
let summary = "Less operator";
let description = [{
@ -1427,6 +1458,63 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
let customOption = "Pool2DOptions";
}
def TFL_MaxPoolingWithArgMax2DOp :
Op<TFL_Dialect, "max_pooling_with_argmax_2d", [NoSideEffect]> {
let summary = "Max Pool 2D with argmax op";
let description = [{
Performs max pooling on the input and outputs both max values and indices.
Each index is a flatten index in a sub-array of "filter_w" x "filter_h" size
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
}];
let arguments = (
ins AnyTensor:$input,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_w,
I32Attr:$filter_h
);
let results = (outs
AnyTensor:$value,
AnyTensor:$indices
);
}
def TFL_MaxUnpooling2DOp :
Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect]> {
let summary = "Max Unpool 2D";
let description = [{
Performs max unpool operation.
To some extent this is the reverse operation of max pooling:
the elements in the input activation tensor is stored into the position
specified by the input indices.
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the input indices
}];
let arguments = (
ins AnyTensor:$input,
AnyTensor:$indices,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_w,
I32Attr:$filter_h
);
let results = (outs AnyTensor:$outputs);
}
def TFL_MaximumOp : TFL_Op<"maximum", [
Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
@ -1996,7 +2084,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
let results = (outs AnyTensor:$output);
DerivedTypeAttr out_type = DerivedTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType();
return getResult().getType().cast<TensorType>().getElementType();
}]>;
let hasOptions = 1;
@ -2039,7 +2127,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
Args:
tensor: A Tensor. Must be one of the following types:
int16, int32, int64, float32 Up to 8-D.
uint8, int16, int32, int64, float32, bool Up to 8-D.
axis: A Tensor. Must be one of the following types: int32, int64.
with only 1 element which is the axis index.
@ -2048,12 +2136,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
let arguments = (
ins
TensorOf<[F32, I16, I32, I64]>:$input,
TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input,
TensorOf<[I32, I64]>:$axis
);
let results = (outs
TensorOf<[F32, I16, I32, I64, I8]>:$output
TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output
);
}
@ -2083,7 +2171,7 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
"Value condition, Value x, Value y",
[{
auto resultType = x->getType();
auto resultType = x.getType();
result.addOperands({condition, x, y});
result.types.push_back(resultType);
}]>];
@ -2733,7 +2821,7 @@ in the unique output `y`. In other words:
);
DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
return getResult(1)->getType().cast<TensorType>().getElementType().
return getResult(1).getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;

View File

@ -19,7 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:llvm-project
namespace mlir {
namespace OpTrait {

View File

@ -30,10 +30,10 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Parser.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Parser.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
#include "tensorflow/core/platform/init_main.h"

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h" // TF:local_config_mlir
#include "mlir/TableGen/Attribute.h" // TF:llvm-project
using llvm::DefInit;
using llvm::dyn_cast;

View File

@ -28,10 +28,10 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:Support",
"@local_config_mlir//:ViewOpGraph",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)

View File

@ -19,11 +19,11 @@ limitations under the License.
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/Transforms/ViewOpGraph.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
@ -107,9 +107,6 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
if (toco_flags.output_format()) {
LOG(WARNING) << "Ignored output_format.";
}
if (toco_flags.default_ranges_min() || toco_flags.default_ranges_max()) {
LOG(WARNING) << "Ignored default_ranges_stats.";
}
if (toco_flags.drop_control_dependency()) {
LOG(WARNING) << "Ignored drop_control_dependency.";
}
@ -242,6 +239,13 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
// Other flags.
if (toco_flags.has_default_ranges_min()) {
quant_specs.default_ranges.first = toco_flags.default_ranges_min();
}
if (toco_flags.has_default_ranges_max()) {
quant_specs.default_ranges.second = toco_flags.default_ranges_max();
}
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();

View File

@ -13,7 +13,7 @@ package(
package_group(
name = "friends",
includes = ["@local_config_mlir//:subpackages"],
includes = ["//third_party/mlir:subpackages"],
packages = ["//tensorflow/compiler/mlir/..."],
)
@ -26,8 +26,8 @@ filegroup(
name = "quantization_td_files",
srcs = [
"quantization.td",
"@local_config_mlir//:OpBaseTdFiles",
"@local_config_mlir//:QuantizationOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:QuantizationOpsTdFiles",
],
)
@ -53,13 +53,13 @@ cc_library(
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
@ -75,11 +75,11 @@ cc_library(
],
deps = [
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(fengliuai): remove this dependence.
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/core:lib_proto_parsing",
@ -97,7 +97,7 @@ cc_library(
deps = [
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)
@ -107,8 +107,8 @@ tf_native_cc_binary(
"tools/op_quant_spec_getters_gen.cc",
],
deps = [
"@llvm//:support",
"@llvm//:tablegen",
"@local_config_mlir//:TableGen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//mlir:TableGen",
],
)

View File

@ -23,18 +23,18 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/AffineExpr.h" // TF:local_config_mlir
#include "mlir/IR/AffineMap.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
#include "mlir/IR/AffineMap.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
@ -78,8 +78,8 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
bool IsQuantizableResult(Operation *op, int index) {
if (index < 0 || index >= op->getNumResults()) return false;
Value res = op->getResult(index);
return res->getType().isa<ShapedType>() &&
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
return res.getType().isa<ShapedType>() &&
res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
}
// A method to retrieve the name for the given op.
@ -123,7 +123,7 @@ void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
IntegerAttr axis) {
auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
layer_stats, axis_stats, axis);
res->replaceAllUsesWith(stats_op);
res.replaceAllUsesWith(stats_op);
stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
}

View File

@ -9,7 +9,7 @@ package(
package_group(
name = "friends",
includes = ["@local_config_mlir//:subpackages"],
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//tensorflow/lite/...",
@ -36,9 +36,9 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
)
@ -53,6 +53,6 @@ tf_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)

View File

@ -17,11 +17,11 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "tensorflow/core/framework/types.pb.h"
@ -64,6 +65,10 @@ struct QuantizationSpecs {
// quantization aware training or calibration, for the remaining tensors.
std::vector<std::pair<double, double>> input_ranges;
// The default ranges can be used when a tensor doesn't have quantization
// parameters and couldn't be quantized. Used only for latency tests.
std::pair<llvm::Optional<double>, llvm::Optional<double>> default_ranges;
// A serialized "QuantizationInfo" object to specify value ranges for some of
// the tensors with known names.
std::string serialized_quant_stats = "";

View File

@ -23,17 +23,17 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
@ -146,14 +146,14 @@ class QuantizationDriver {
// Adds all the users of index-th result of op to the work list.
void AddUserToList(Operation *op, int index) {
for (auto *user : op->getResult(index)->getUsers()) {
for (auto *user : op->getResult(index).getUsers()) {
work_list_.push_back(user);
}
}
// Adds the defining op of index-th operand of op to the work list.
void AddOperandToList(Operation *op, int index) {
if (auto *inst = op->getOperand(index)->getDefiningOp()) {
if (auto *inst = op->getOperand(index).getDefiningOp()) {
work_list_.push_back(inst);
}
}
@ -248,7 +248,7 @@ class QuantizationDriver {
return;
}
QuantParams params =
quant::QuantizedType::getQuantizedElementType(in->getType());
quant::QuantizedType::getQuantizedElementType(in.getType());
bool immutable = !EmptyParams(params);
int next_state_index = states_.size();
states_.push_back({params, immutable});
@ -338,7 +338,7 @@ bool QuantizationDriver::IsQuantized(Operation *op) {
int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
bool as_result) {
QuantParams params =
quant::QuantizedType::getQuantizedElementType(val->getType());
quant::QuantizedType::getQuantizedElementType(val.getType());
bool immutable = !EmptyParams(params);
int next_state_index = states_.size();
states_.push_back({params, immutable});
@ -447,13 +447,13 @@ void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
}
void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
builder_.setInsertionPointToStart(arg->getOwner());
builder_.setInsertionPointToStart(arg.getOwner());
QuantizeValue(arg, params, builder_.getUnknownLoc());
}
void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
Location loc) {
Type expressed_type = value->getType();
Type expressed_type = value.getType();
Type new_type = params.castFromExpressedType(expressed_type);
// This value isn't an expressed type (float), skip.
if (!new_type) return;
@ -465,7 +465,7 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
quantize.output());
// `original_result` has a use to `quantize`, so this will replace that use
// by the result of `dequantize`. Remember to reset that use afterwards
value->replaceAllUsesWith(dequantize);
value.replaceAllUsesWith(dequantize);
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}
@ -475,7 +475,7 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
builder_.setInsertionPointAfter(op);
Value value = op->getResult(index);
if (state->pos == RequantizeState::ON_OUTPUT) {
Operation *user = value->getUses().begin().getUser();
Operation *user = value.getUses().begin().getUser();
if (llvm::isa<TFL::QuantizeOp>(user)) {
// The requantize op is inserted between `quantize` and `dequantize` ops.
value = user->getResult(0);
@ -488,12 +488,12 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
void QuantizationDriver::RequantizeArg(BlockArgument arg,
RequantizeState *state) {
Value value = arg;
builder_.setInsertionPointToStart(arg->getOwner());
if (value->hasOneUse()) {
auto user = value->use_begin().getUser();
builder_.setInsertionPointToStart(arg.getOwner());
if (value.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
builder_.setInsertionPoint(arg->getOwner(), ++Block::iterator(user));
builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
}
}
RequantizeValue(value, state, builder_.getUnknownLoc());
@ -503,13 +503,13 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
Location loc) {
Type new_type;
if (state->pos == RequantizeState::ON_INPUT) {
Type expressed_type = value->getType();
Type expressed_type = value.getType();
// The value needs to be requantized. A Quantize op will be created to use
// it as the operand and replace its uses.
new_type = state->params.castFromExpressedType(expressed_type);
} else {
Type expressed_type =
quant::QuantizedType::castToExpressedType(value->getType());
quant::QuantizedType::castToExpressedType(value.getType());
if (!expressed_type) return;
// The value needs to be requantized. A Quantize op will be created to use
@ -522,7 +522,7 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
TypeAttr type_attr = TypeAttr::get(new_type);
auto requantize_op =
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
value->replaceAllUsesWith(requantize_op);
value.replaceAllUsesWith(requantize_op);
requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}
@ -603,7 +603,7 @@ void QuantizationDriver::PreprocessConstantOps() {
Value value = cst.getResult();
SmallVector<std::pair<Operation *, int>, 4> bias_users;
bool used_as_weight = false;
for (auto &use : value->getUses()) {
for (auto &use : value.getUses()) {
auto spec = GetQuantSpec(use.getOwner());
auto biases = spec->biases_params;
Operation *user = use.getOwner();
@ -649,8 +649,8 @@ void QuantizationDriver::SetupAllStates() {
args_.push_back(arg);
Value value = arg;
// If the argument is quantized, it should only has one user.
if (arg->hasOneUse()) {
auto user = value->use_begin().getUser();
if (arg.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
}
@ -666,7 +666,7 @@ void QuantizationDriver::SetupAllStates() {
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
auto operand = op->getOperand(i);
if (auto *inst = operand->getDefiningOp()) {
if (auto *inst = operand.getDefiningOp()) {
// If the operand comes from a tfl.dequantize op, we use the quantized
// input of this tfl.dequantize op to set the state.
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
@ -677,12 +677,12 @@ void QuantizationDriver::SetupAllStates() {
}
for (int res = 0, e = op->getNumResults(); res != e; ++res) {
auto result = op->getResult(res);
Value result = op->getResult(res);
// If the result has been quantized, it should only be used by a
// tfl.quantize op. For this case, we uses the quantized result to
// create the state and mark it immutable.
if (result->hasOneUse()) {
auto user = result->use_begin().getUser();
if (result.hasOneUse()) {
auto user = result.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
result = q.output();
}

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
namespace mlir {
namespace quant {

View File

@ -18,8 +18,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
namespace mlir {
namespace OpTrait {
@ -70,7 +70,7 @@ class FixedResultUniformScale {
QuantizedType GetResultQuantizedType(int index) {
auto op = this->getOperation();
auto result_type =
op->getResult(index)->getType().template cast<TensorType>();
op->getResult(index).getType().template cast<TensorType>();
Builder builder(op->getContext());
IntegerType storage_type = builder.getIntegerType(BitWidth);
const double scale = static_cast<double>(ScaleMantissa) *

View File

@ -21,15 +21,15 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
namespace mlir {
@ -367,7 +367,7 @@ ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
static bool PreferResultScale(Operation* op) {
int float_operands = 0;
for (auto operand : op->getOperands()) {
if (auto operand_type = operand->getType().dyn_cast<ShapedType>()) {
if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
if (operand_type.getElementType().isa<FloatType>()) {
if (float_operands++ > 1) return true;
}
@ -400,22 +400,22 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
quant::StatisticsOp stats_op = all_stats_ops.back();
all_stats_ops.pop_back();
if (auto def = stats_op.arg()->getDefiningOp()) {
if (auto def = stats_op.arg().getDefiningOp()) {
if (IsStatsRedundant(def, op_quant_spec_getter)) {
redundant_stats_ops.insert(stats_op);
}
}
for (auto user : stats_op.getResult()->getUsers()) {
for (auto user : stats_op.getResult().getUsers()) {
// We don't propagate this parameter down if it has multiple operands.
// We want to use the result parameter scales instead.
if (user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
!PreferResultScale(user)) {
for (Value res : user->getResults()) {
if (res->hasOneUse()) {
if (res.hasOneUse()) {
if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
*res->getUsers().begin())) {
*res.getUsers().begin())) {
// quantization parameters can be propagated to next_stats
redundant_stats_ops.insert(next_stats);
// add next_stats to the work list so propagation can
@ -440,12 +440,12 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
quant::StatisticsOp stats_op = all_stats_ops.back();
all_stats_ops.pop_back();
if (auto def = stats_op.arg()->getDefiningOp()) {
if (auto def = stats_op.arg().getDefiningOp()) {
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
PreferResultScale(def)) {
for (auto input : def->getOperands()) {
if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
input->getDefiningOp())) {
input.getDefiningOp())) {
redundant_stats_ops.insert(next_stats);
all_stats_ops.push_back(next_stats);
}
@ -458,7 +458,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
for (auto it : redundant_stats_ops) {
if (!llvm::isa<quant::StatisticsOp>(it)) return true;
auto stats_op = llvm::cast<quant::StatisticsOp>(it);
stats_op.getResult()->replaceAllUsesWith(stats_op.arg());
stats_op.getResult().replaceAllUsesWith(stats_op.arg());
stats_op.erase();
}

View File

@ -23,18 +23,18 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir {
@ -116,7 +116,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
TypeAttr::get(result_type));
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
op.getResult()->replaceAllUsesWith(dq);
op.getResult().replaceAllUsesWith(dq);
q.getOperation()->replaceUsesOfWith(dq, op.arg());
op.erase();
@ -162,7 +162,7 @@ struct QuantizationPattern : public RewritePattern {
return matchFailure();
}
Value quantized_value = op->getResult(0);
for (Operation* quantized_op : quantized_value->getUsers()) {
for (Operation* quantized_op : quantized_value.getUsers()) {
// If it is requantize op, we shouldn't rewrite this op.
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
return matchFailure();
@ -179,14 +179,14 @@ struct QuantizationPattern : public RewritePattern {
SmallVector<Value, 4> inputs;
inputs.reserve(quantized_op->getNumOperands());
for (auto operand : quantized_op->getOperands()) {
Type operand_type = operand->getType();
Type operand_type = operand.getType();
if (operand_type.isa<NoneType>()) {
inputs.push_back(operand);
continue;
}
auto ele_type = operand->getType().cast<TensorType>().getElementType();
if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
auto ele_type = operand.getType().cast<TensorType>().getElementType();
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
inputs.push_back(op_inst.input());
} else if (ele_type.isa<IntegerType>()) {
// If the operand is an integer tensor, then it doesn't require the
@ -207,7 +207,7 @@ struct QuantizationPattern : public RewritePattern {
for (auto enumerated_result :
llvm::enumerate(quantized_op->getResults())) {
Value result = enumerated_result.value();
Type result_type = result->getType();
Type result_type = result.getType();
// Add this to the test coverage once we create test ops with none type
// results.
if (result_type.isa<NoneType>()) {
@ -216,20 +216,20 @@ struct QuantizationPattern : public RewritePattern {
continue;
}
Type result_ele_type =
result->getType().cast<TensorType>().getElementType();
result.getType().cast<TensorType>().getElementType();
// If the user is the Quantize op, it must be the only user.
if (result->hasOneUse() && llvm::isa<Q>(*result->user_begin())) {
auto user = llvm::cast<Q>(*result->user_begin());
if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
auto user = llvm::cast<Q>(*result.user_begin());
outputs_replaced.insert({user.output(), enumerated_result.index()});
output_types.push_back(user.getType());
} else if (result_ele_type.template isa<IntegerType>()) {
// If the result is an integer tensor, then it doesn't require the
// D op in the pattern.
outputs_replaced.insert({result, enumerated_result.index()});
output_types.push_back(result->getType());
output_types.push_back(result.getType());
} else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
outputs_replaced.insert({result, enumerated_result.index()});
output_types.push_back(result->getType());
output_types.push_back(result.getType());
} else {
return matchFailure();
}
@ -241,7 +241,7 @@ struct QuantizationPattern : public RewritePattern {
output_types, quantized_op->getAttrs());
Operation* new_op = rewriter.createOperation(new_state);
for (auto output : outputs_replaced) {
output.getFirst()->replaceAllUsesWith(
output.getFirst().replaceAllUsesWith(
new_op->getResult(output.getSecond()));
}
@ -252,7 +252,7 @@ struct QuantizationPattern : public RewritePattern {
// For constant operands, the floating-point constant is duplicated in
// case it is quantized.
for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
auto def = new_op->getOperand(i)->getDefiningOp();
auto def = new_op->getOperand(i).getDefiningOp();
if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
DenseFPElementsAttr attr;
if (!matchPattern(q.input(), m_Constant(&attr))) {
@ -265,7 +265,7 @@ struct QuantizationPattern : public RewritePattern {
for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
if (!quantized_op->getResult(i)
->getType()
.getType()
.cast<ShapedType>()
.getElementType()
.isa<FloatType>()) {
@ -283,13 +283,13 @@ struct QuantizationPattern : public RewritePattern {
// Find the Dequantize/Dequantize users of the new op results, and
// replace the usage. Then all the floating-point ops are connected.
// N.B. the return op will use this floating-point result.
for (auto user : new_op->getResult(i)->getUsers()) {
for (auto user : new_op->getResult(i).getUsers()) {
// Skip the Requantize op, and we know it has a single user.
if (llvm::isa<Q>(user)) {
user = *user->getResult(0)->getUsers().begin();
user = *user->getResult(0).getUsers().begin();
}
if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
dequantize.getResult()->replaceAllUsesWith(
dequantize.getResult().replaceAllUsesWith(
quantized_op->getResult(i));
}
}
@ -316,7 +316,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
PatternMatchResult matchAndRewrite(Q op,
PatternRewriter& rewriter) const override {
Type output_type = op.output()->getType();
Type output_type = op.output().getType();
auto qtype = QType::getQuantizedElementType(output_type);
if (!qtype || qtype.isSigned()) return this->matchFailure();

View File

@ -4,7 +4,7 @@ package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
@ -14,6 +14,6 @@ filegroup(
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm//:FileCheck",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Operator.h" // TF:local_config_mlir
#include "mlir/TableGen/Operator.h" // TF:llvm-project
using llvm::LessRecord;
using llvm::raw_ostream;

View File

@ -4,7 +4,7 @@ package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
@ -14,6 +14,6 @@ filegroup(
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm//:FileCheck",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -7,7 +7,7 @@ glob_lit_tests(
":debug_info_files",
":test_utilities",
],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = [
"pbtxt",
# TODO(fengliuai): reenable these tests after the fused loc is
@ -33,8 +33,8 @@ filegroup(
":saved_model_error",
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm//:FileCheck",
"@llvm//:not",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

View File

@ -0,0 +1,89 @@
// RUN: tf-opt %s --tfl-default-quant --tfl-quantize | FileCheck %s
// CHECK-LABEL: hardcode_all
func @hardcode_all(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// Quantized tfl.add
// CHECK: %[[add:.*]] = "tfl.add"(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}
// CHECK-LABEL: hardcode_input
func @hardcode_input(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
%1 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
%4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %4 : tensor<2x2xf32>
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>}
// CHECK: %[[add:.*]] = "tfl.add"(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}
// CHECK-LABEL: hardcode_input_deq
func @hardcode_input_deq(%arg0: tensor<2x2x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%1 = "tfl.dequantize"(%arg0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0>>) -> tensor<2x2xf32>
%4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %4 : tensor<2x2xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[add:.*]] = "tfl.add"(%arg0, %[[q]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}
// CHECK-LABEL: hardcode_output
func @hardcode_output(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
%1 = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x1xf32>) -> tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>
%2 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
%3 = "tfl.dequantize"(%1) : (tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x1xf32>
%4 = "tfl.add"(%2, %3) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %4 : tensor<2x2xf32>
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.000000e+00:128>>}
// CHECK: %[[add:.*]] = "tfl.add"(%[[q0]], %[[q1]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}
// CHECK-LABEL: test_conv_2d_add
func @test_conv_2d_add(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>, %arg2: tensor<32x!quant.uniform<i32:f32, 1.0>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>> {
%0 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x224x224x3xf32>
%1 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
%2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<i32:f32, 1.0>>) -> tensor<32xf32>
%3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%4 = "tfl.pseudo_qconst"() {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>, value = dense<1> : tensor<1x112x112x32xi8>} : () -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
%5 = "tfl.dequantize"(%4) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
%6 = "tfl.add"(%3, %5) {fused_activation_function="NONE"}: (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
%7 = "tfl.quantize"(%6) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
return %7 : tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %arg1, %arg2)
// CHECK-SAME: -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"()
// CHECK: %[[add:.*]] = "tfl.add"(%[[conv]], %[[cst]])
// CHECK-SAME: -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: return %[[add]]
}
// CHECK-LABEL: test_conv_2d_activation_and_bias
func @test_conv_2d_activation_and_bias(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>, %arg2: tensor<32xf32>) -> tensor<1x112x112x32xf32> {
%0 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
%1 = "tfl.conv_2d"(%arg0, %0, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
return %1 : tensor<1x112x112x32xf32>
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg2) {qtype = tensor<32x!quant.uniform<i32:f32, 0.0078431372549019607>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%[[q1]], %arg1, %[[q0]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[conv]]) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}

View File

@ -7,7 +7,7 @@ glob_lit_tests(
":quant_stats_files",
":test_utilities",
],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = [
"pbtxt",
],
@ -20,7 +20,7 @@ filegroup(
data = [
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm//:FileCheck",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -8,7 +8,7 @@ glob_lit_tests(
":extra_files",
":test_utilities",
],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = [
"mlir",
"cc",
@ -24,7 +24,7 @@ filegroup(
":importer_test_min_max",
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
"@llvm//:FileCheck",
"@llvm-project//llvm:FileCheck",
],
)
@ -51,7 +51,7 @@ tf_native_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)
@ -67,6 +67,6 @@ tf_native_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)

View File

@ -11,6 +11,8 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
%3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
// CHECK: %[[EXP:.*]] = "tfl.exp"
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
// tfl.neg should not be pruned
// CHECK: %[[NEG:.*]] = "tfl.neg"
%5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
return %5 : tensor<4xf32>

View File

@ -0,0 +1,19 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -output-arrays=mul,exp,div --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// Confirm graph pruning.
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
// CHECK: %[[MUL:.*]] = tfl.mul
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
// CHECK: %[[DIV:.*]] = tfl.div
%3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
// CHECK: %[[EXP:.*]] = "tfl.exp"
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
// tfl.neg should be pruned
// CHECK-NOT: "tfl.neg"
%5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
return %5 : tensor<4xf32>
}

View File

@ -4,7 +4,7 @@ licenses(["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
@ -15,7 +15,7 @@ filegroup(
data = [
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
"@llvm//:FileCheck",
"@llvm//:not",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

View File

@ -0,0 +1,76 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "Convolution2DTransposeBias"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 32, 4, 4, 128 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 42, 128 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "arg2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 64, 84, 32 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tfl.convolution_2d_transpose_bias",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}
// MLIR-LABEL: func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>)
// MLIR-SAME: -> tensor<1x64x84x32xf32>
// MLIR: %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2)
// MLIR-SAME: {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32}
// MLIR-SAME: (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
// MLIR-NEXT: return %0 : tensor<1x64x84x32xf32>
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
return %0 : tensor<1x64x84x32xf32>
}

View File

@ -0,0 +1,39 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s
// CHECK: {
// CHECK: version: 3,
// CHECK: operator_codes: [ {
// CHECK: builtin_code: CUSTOM,
// CHECK: custom_code: "HashTableV2"
// CHECK: } ],
// CHECK: subgraphs: [ {
// CHECK: tensors: [ {
// CHECK: shape: [ ],
// CHECK: type: INT32,
// CHECK: buffer: 1,
// CHECK: name: "tf.HashTableV2",
// CHECK: quantization: {
// CHECK-EMPTY
// CHECK: }
// CHECK: } ],
// CHECK: inputs: [ ],
// CHECK: outputs: [ 0 ],
// CHECK: operators: [ {
// CHECK: inputs: [ ],
// CHECK: outputs: [ 0 ],
// CHECK: custom_options:
// CHECK: name: "main"
// CHECK: } ],
// CHECK: description: "MLIR Converted.",
// CHECK: buffers: [ {
// CHECK-EMPTY
// CHECK: }, {
// CHECK-EMPTY
// CHECK: } ]
// CHECK: }
func @main() -> tensor<*x!tf.resource> {
%0 = "tf.HashTableV2"() {container = "" , shared_name= "table", use_node_name_sharing = false, key_dtype = i32, value_dtype = i32 } : () -> tensor<*x!tf.resource>
return %0 : tensor<*x!tf.resource>
}

View File

@ -0,0 +1,65 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "MaxPoolingWithArgmax2D"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 64, 64, 32 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d:1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1, 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1, 2 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}
// MLIR-LABEL: func @main(%arg0: tensor<1x64x64x32xf32>)
// MLIR-SAME: -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
// MLIR: %value, %indices = "tfl.max_pooling_with_argmax_2d"(%arg0)
// MLIR-SAME: {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32}
// MLIR-SAME: (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
// MLIR-NEXT: return %value, %indices : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
%0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
}

View File

@ -0,0 +1,65 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "MaxUnpooling2D"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.max_unpooling_2d",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}
// MLIR-LABEL: func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>)
// MLIR-SAME: -> tensor<1x8x8x128xf32>
// MLIR: %0 = "tfl.max_unpooling_2d"(%arg0, %arg1)
// MLIR-SAME: {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32}
// MLIR-SAME: (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32>
// MLIR-NEXT: return %0 : tensor<1x8x8x128xf32>
%0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
return %0 : tensor<1x8x8x128xf32>
}

View File

@ -518,6 +518,20 @@ func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform<i9:f32
// -----
func @testMaxPoolingWithArgMax2D(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
%0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
}
// -----
func @testMaxUnpooling2D(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
%0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
return %0 : tensor<1x8x8x128xf32>
}
// -----
// CHECK-LABEL: testLogistic
func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
^bb0(%arg0: tensor<1x2x3x4x5xbf16>):
@ -1942,6 +1956,13 @@ func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %ar
// -----
func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
return %0 : tensor<1x64x84x32xf32>
}
// -----
func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> {
// expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}}
%0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32>

View File

@ -1,4 +1,7 @@
// Run optimize pass only and check the results.
// RUN: tf-opt %s -tfl-optimize | FileCheck %s
// Run optimize pass and then canonicalize pass, and make sure some folding is applied.
// RUN: tf-opt %s -tfl-optimize -canonicalize | FileCheck --check-prefix=FOLD %s
// CHECK-LABEL: fusedConv2dRelu
func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
@ -302,6 +305,58 @@ func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf
// CHECK: return %[[fc]]
}
// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst
// FOLD-LABEL: @FuseFullyConnectedReshapeAddConst
func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant dense<3.0> : tensor<40x40xf32>
%cst2 = constant dense<2.0> : tensor<40xf32>
%shape1 = constant dense<[1, 40, 40]> : tensor<3xi32>
%shape2 = constant dense<[40, 40]> : tensor<2xi32>
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
%1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>
%3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>
return %3 : tensor<40x40xf32>
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%[[fc]]
// CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
// CHECK: return %[[rs2]]
// FOLD: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
// FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// FOLD: return %[[fc]]
}
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastable
func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
%cst = constant dense<2.0> : tensor<40xf32>
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x10x4xf32>, tensor<2xi32>) -> tensor<40x40xf32>
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
return %2 : tensor<40x40xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
%cst = constant dense<2.0> : tensor<1x40xf32>
%shape = constant dense<[40, 40]> : tensor<2xi32>
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<1x40xf32>) -> tensor<40x40xf32>
return %2 : tensor<40x40xf32>
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @FuseFullyConnectedRelu
func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>

View File

@ -15,11 +15,11 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -43,6 +43,16 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
if (quant_specs.default_ranges.first.hasValue() ||
quant_specs.default_ranges.second.hasValue()) {
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0)));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
}
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
namespace tensorflow {

View File

@ -20,11 +20,11 @@ limitations under the License.
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
@ -103,7 +103,7 @@ static int PrintFunctionResultMapping(const std::string &result,
i = 0;
for (auto output : *subgraph->outputs()) {
print_buffer(*subgraph, i, output, [&](int i) {
return terminator ? terminator->getOperand(i)->getLoc() : unknown_loc;
return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
});
}
}

View File

@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Parser.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Parser.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

View File

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/stream_executor/lib/statusor.h"

View File

@ -0,0 +1,234 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
//===----------------------------------------------------------------------===//
// The Pass to add default quantization parameters for the activations which
// don't have quantization information. These default parameters are usually
// not from real measurement, so this pass is only for test purpose.
namespace mlir {
namespace TFL {
// Includs an auto-generated function, which can retrieve the quantization
// specification for an TFL operation. The signature of the function is
// std::unique_pointer<OpQuantSpec> TFL::GetOpQuantSpec(Operation *)
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
namespace {
class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
public:
explicit DefaultQuantParamsPass(double default_min, double default_max)
: default_min_(default_min), default_max_(default_max) {}
void runOnFunction() override;
private:
// Whether the value is used as a bias input of another op. Here we assume
// bias is used immediately by the user. This assumption is always correct
// after constant folding.
bool UsedAsBias(Value value) {
for (auto &use : value.getUses()) {
auto biases = TFL::GetOpQuantSpec(use.getOwner())->biases_params;
if (biases.find(use.getOperandNumber()) != biases.end()) return true;
}
return false;
}
// Uses `quant_params` to quantize `value` and inserting a pair of
// tfl.quantize and tfl.dequantize ops for this `value`.
void QuantizeValue(OpBuilder builder, Value value,
TFL::QuantParams quant_params);
// If the value hasn't been quantized, the functions adds it to `values`.
void AddToWorkListIfUnquantized(Value value, std::vector<Value> *values);
// Converts the default min/max to the default quantization parameters.
TFL::QuantParams GetDefaultQuantParams(Builder builder);
// Gets the quantization parameters for the bias of an operation by using the
// quantization parameters from the non-biases operands.
TFL::QuantParams GetQuantParamsForBias(Operation *op, int bias,
const std::vector<int> &non_biases,
TFL::AccumulatorScaleFunc func);
double default_min_;
double default_max_;
TFL::QuantParams default_quant_params_;
};
} // namespace
void DefaultQuantParamsPass::runOnFunction() {
FuncOp func = getFunction();
OpBuilder builder(func);
std::vector<Value> activation_values;
std::vector<Value> bias_values;
// First of all, collect all the values (block arguments and op results) which
// are required to be quantized.
for (auto arg : func.getBody().begin()->getArguments()) {
if (UsedAsBias(arg)) {
AddToWorkListIfUnquantized(arg, &bias_values);
} else {
AddToWorkListIfUnquantized(arg, &activation_values);
}
}
func.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>())
return;
for (auto res : op->getResults()) {
if (UsedAsBias(res)) {
AddToWorkListIfUnquantized(res, &bias_values);
} else {
AddToWorkListIfUnquantized(res, &activation_values);
}
}
});
// Apply the default quantization parameters for these activation values.
TFL::QuantParams default_params = GetDefaultQuantParams(builder);
for (Value value : activation_values) {
QuantizeValue(builder, value, default_params);
}
// Since all the non-biases operands have quantization parameters now, we
// should be able to propagate them to the bias operand.
for (Value bias : bias_values) {
Operation *op = *bias.user_begin();
auto spec = TFL::GetOpQuantSpec(op);
for (auto &it : spec->biases_params) {
TFL::QuantParams bias_params = GetQuantParamsForBias(
op, it.first, it.second.first, it.second.second);
if (!bias_params) continue;
QuantizeValue(builder, bias, bias_params);
}
}
}
void DefaultQuantParamsPass::AddToWorkListIfUnquantized(
Value value, std::vector<Value> *values) {
// If the result isn't with float type, this result is an integer tensor and
// doesn't require quantization.
auto tensor_type = value.getType().dyn_cast<TensorType>();
if (!tensor_type) {
// There are none type values.
return;
}
if (!tensor_type.getElementType().isF32()) return;
// If the result is consumed by a quantize op, it has been quantized.
if (value.hasOneUse() &&
llvm::isa<TFL::QuantizeOp>(*value.getUsers().begin()))
return;
// Add this result to the list to apply the default value.
values->push_back(value);
}
void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value,
TFL::QuantParams quant_params) {
Type expressed_type = value.getType();
Type new_type = quant_params.castFromExpressedType(expressed_type);
// This value isn't an expressed type (float), skip.
if (!new_type) return;
Block &block = value.getParentRegion()->front();
Operation *op = value.getDefiningOp();
if (op) {
builder.setInsertionPoint(&block, ++Block::iterator(op));
} else {
builder.setInsertionPointToStart(&block);
}
TypeAttr type_attr = TypeAttr::get(new_type);
auto quantize = builder.create<TFL::QuantizeOp>(value.getLoc(), new_type,
value, type_attr);
auto dequantize = builder.create<TFL::DequantizeOp>(
value.getLoc(), expressed_type, quantize.output());
value.replaceAllUsesWith(dequantize);
// `quantize` is using `dequantize` now, so we should set its operand to
// `value`.
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}
TFL::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
Operation *op, int bias, const std::vector<int> &non_biases,
TFL::AccumulatorScaleFunc func) {
std::vector<quant::QuantizedType> non_bias_types;
non_bias_types.reserve(non_biases.size());
for (int non_bias : non_biases) {
Operation *non_bias_define = op->getOperand(non_bias).getDefiningOp();
if (auto dequant = llvm::dyn_cast<TFL::DequantizeOp>(non_bias_define)) {
auto non_bias_type = dequant.input().getType().cast<TensorType>();
auto non_bias_ele_type =
non_bias_type.getElementType().cast<quant::QuantizedType>();
non_bias_types.push_back(non_bias_ele_type);
} else {
// The non-bias hasn't been quantized, let's skip this bias.
break;
}
}
// The non-bias hasn't been quantized, let's skip this bias.
if (non_bias_types.size() != non_biases.size()) return {};
return func(non_bias_types);
}
TFL::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
Builder builder) {
if (!default_quant_params_) {
default_quant_params_ = quant::fakeQuantAttrsToType(
builder.getUnknownLoc(),
/*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
builder.getF32Type());
}
return default_quant_params_;
}
// Creates an instance of the default quant parameters pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max) {
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
}
// Registers this pass with default values, only for test
static PassRegistration<DefaultQuantParamsPass> pass(
"tfl-default-quant",
"Apply quantization with default quantization parameter", [] {
return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
/*default_max=*/1.0);
});
} // namespace TFL
} // namespace mlir

View File

@ -21,26 +21,26 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -205,7 +205,7 @@ struct OphintCompositeOp {
Operation* current_identity_op = operand.ops.begin()->second;
Value input = current_identity_op->getOperand(0);
RankedTensorType input_type =
input->getType().cast<RankedTensorType>();
input.getType().cast<RankedTensorType>();
// The Reshape will be {1, (original_shape)}
SmallVector<int64_t, 4> reshape_op_shape;
reshape_op_shape.push_back(1);
@ -242,13 +242,13 @@ struct OphintCompositeOp {
}
// Find the first op that consumes the last value of the aggregated
// inputs.
Operation* first_use = *(packed_input_consumers.back()->user_begin());
Operation* first_use = *(packed_input_consumers.back().user_begin());
// The pack reshape will be {N, (original_shape)}
SmallVector<int64_t, 4> pack_shape;
pack_shape.push_back(pack_input_operands.size());
RankedTensorType type = operand.ops.at(0)
->getResult(0)
->getType()
.getType()
.cast<RankedTensorType>();
for (const auto& dim : type.getShape()) {
pack_shape.push_back(dim);
@ -290,7 +290,7 @@ struct OphintCompositeOp {
const int output_numer = operand.ops.size();
Value first_output = operand.ops.at(0)->getOperand(0);
RankedTensorType first_output_type =
first_output->getType().cast<RankedTensorType>();
first_output.getType().cast<RankedTensorType>();
// The aggregated output shape will be {N, original_shape}.
SmallVector<int64_t, 4> shape;
shape.push_back(output_numer);
@ -302,10 +302,10 @@ struct OphintCompositeOp {
} else if (operand.aggregation == kStrategyLast) {
Value last_output =
operand.ops.at(operand.ops.size() - 1)->getOperand(0);
aggregated_output_types[kv.first] = last_output->getType();
aggregated_output_types[kv.first] = last_output.getType();
} else {
Value first_output = operand.ops.at(0)->getOperand(0);
aggregated_output_types[kv.first] = first_output->getType();
aggregated_output_types[kv.first] = first_output.getType();
}
}
return aggregated_output_types;
@ -329,7 +329,7 @@ struct OphintCompositeOp {
Operation* first_output = operand.ops.at(0);
Location insert_loc = first_output->getLoc();
SmallVector<Type, 4> unpack_output_types(
output_number, first_output->getOperand(0)->getType());
output_number, first_output->getOperand(0).getType());
builder->setInsertionPoint(first_output);
Operation* unpack_op = builder->create<TFL::UnpackOp>(
@ -404,7 +404,7 @@ void PreprocessTopoSortGraph(
// should only count as one.
llvm::DenseSet<Operation*> input_ops;
for (int i = 0; i < op.getNumOperands(); ++i) {
Operation* input_op = op.getOperand(i)->getDefiningOp();
Operation* input_op = op.getOperand(i).getDefiningOp();
if (input_op) input_ops.insert(input_op);
}
if (input_ops.empty()) {
@ -515,7 +515,7 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
SmallVector<int, 4> input_indexes;
for (const auto& kv : inputs) {
Value input = kv.second;
input_types.push_back(input->getType());
input_types.push_back(input.getType());
input_values.push_back(input);
input_indexes.push_back(kv.first);
}
@ -589,7 +589,7 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
std::queue<Operation*> ops_queue;
for (auto& input_op : input_ops) {
for (Value value : input_op->getOperands()) {
Operation* op = value->getDefiningOp();
Operation* op = value.getDefiningOp();
if (op != nullptr) ops_queue.push(op);
}
}
@ -599,7 +599,7 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
ops_queue.pop();
reachable_ops.insert(current_op);
for (Value value : current_op->getOperands()) {
Operation* upstream_op = value->getDefiningOp();
Operation* upstream_op = value.getDefiningOp();
// Not visited, put it into the queue.
if (upstream_op != nullptr &&
!llvm::is_contained(reachable_ops, upstream_op)) {
@ -642,7 +642,7 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
aggregated_inputs, aggregated_output_types, builder, module_op);
for (const auto& kv : aggregated_inputs) {
Operation* op = kv.second->getDefiningOp();
Operation* op = kv.second.getDefiningOp();
if (op == nullptr) return failure();
op->moveBefore(fused_op);
}

View File

@ -15,23 +15,23 @@ limitations under the License.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
@ -103,7 +103,7 @@ LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op,
Value hidden_state = call_op.getOperand(4);
// Build Output.
auto output_type = call_op.getResult(0)->getType();
auto output_type = call_op.getResult(0).getType();
// Currently, ophinted RNN only supports time_major = True.
const bool time_major = true;
@ -170,11 +170,11 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
for (int i = 0; i < call_op.getNumResults() - 1; ++i) {
// This one should not be used.
Value unused_output = call_op.getResult(i);
if (!unused_output->use_empty()) return failure();
if (!unused_output.use_empty()) return failure();
}
}
output_types.push_back(
call_op.getResult(call_op.getNumResults() - 1)->getType());
call_op.getResult(call_op.getNumResults() - 1).getType());
// Prepare attributes.
SmallVector<NamedAttribute, 4> attributes;
@ -207,10 +207,10 @@ LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name,
composite_func_op, call_op, builder, &fused_op);
if (failed(build_fused_op_result)) return build_fused_op_result;
Value call_output = call_op.getResult(call_op.getNumResults() - 1);
if (call_output->getType() != fused_op->getResult(0)->getType()) {
if (call_output.getType() != fused_op->getResult(0).getType()) {
return failure();
}
call_output->replaceAllUsesWith(fused_op->getResult(0));
call_output.replaceAllUsesWith(fused_op->getResult(0));
} else { // If we support more fused op, we should add the conversion here.
return failure();
}

View File

@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
// Use the tensor type information from $0 and convert min $1, max $2 and
// numBits $3 and narrowRange $4 to a QuantizedType.
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
"GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
// Converts an integer attribute $0 to 32-bit with builder.
def convertIntAttrTo32Bit : NativeCodeCall<
@ -50,7 +50,7 @@ def ExtractSingleElementAsInteger : NativeCodeCall<
"ExtractSingleElementAsInteger($_self.cast<ElementsAttr>())">;
// Checks whether the given operation has static shapes and same shapes of all inputs.
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0->getDefiningOp())">;
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;

View File

@ -28,15 +28,15 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -72,8 +72,8 @@ bool HasSameStaticShapes(Operation* op) {
int index = 0;
ArrayRef<int64_t> shape;
for (Value value : values) {
auto shaped_type = value->getType().dyn_cast<ShapedType>();
if (!shaped_type && !shaped_type.hasStaticShape()) {
auto shaped_type = value.getType().dyn_cast<ShapedType>();
if (!shaped_type || !shaped_type.hasStaticShape()) {
return false;
}
if (index == 0) {
@ -122,7 +122,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
auto tf_concat_op = cast<TF::ConcatOp>(op);
auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output()->getType();
auto output_type = tf_concat_op.output().getType();
// Extract axis attribute from constant concat_dims tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
@ -141,7 +141,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output()->getType();
auto output_type = tf_concat_op.output().getType();
// Extract axis attribute from constant axis tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis)))
@ -167,7 +167,7 @@ PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
if (tf_matmul_op.transpose_a()) return matchFailure();
if (!tf_matmul_op.transpose_b()) return matchFailure();
Type output_type = tf_matmul_op.getResult()->getType();
Type output_type = tf_matmul_op.getResult().getType();
// TODO(jpienaar): Follow up post shuffle discussion.
auto no_input = rewriter.create<ConstantOp>(
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
@ -184,7 +184,7 @@ PatternMatchResult ConvertTFPackOp::matchAndRewrite(
auto tf_pack_op = cast<TF::PackOp>(op);
SmallVector<Value, 4> values(tf_pack_op.values());
auto output_type = tf_pack_op.output()->getType();
auto output_type = tf_pack_op.output().getType();
auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
// Axis can be negative.
auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue());
@ -201,7 +201,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
auto input = tf_reshape_op.tensor();
auto shape = tf_reshape_op.shape();
ShapedType shape_type = shape->getType().cast<ShapedType>();
ShapedType shape_type = shape.getType().cast<ShapedType>();
// The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
if (!shape_type.getElementType().isInteger(32)) {
auto new_shape = shape_type.getShape();
@ -213,7 +213,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
rewriter.getBoolAttr(false))
.y();
}
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output()->getType(),
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
input, shape);
return matchSuccess();
}
@ -222,7 +222,7 @@ PatternMatchResult ConvertTFSplitOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_split_op = cast<TF::SplitOp>(op);
auto output_types = functional::map([](Value v) { return v->getType(); },
auto output_types = functional::map([](Value v) { return v.getType(); },
tf_split_op.output());
// Number of splits cannot be negative.
auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());
@ -237,7 +237,7 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_splitv_op = cast<TF::SplitVOp>(op);
auto output_types = functional::map([](Value v) { return v->getType(); },
auto output_types = functional::map([](Value v) { return v.getType(); },
tf_splitv_op.output());
// Number of splits cannot be negative.
auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());
@ -254,7 +254,7 @@ Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
DenseIntElementsAttr dense_elem_attr;
SmallVector<int32_t, 8> padded_val;
auto ranked_attr_type = attribute->getType().dyn_cast<RankedTensorType>();
auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
if (!ranked_attr_type ||
!matchPattern(attribute, m_Constant(&dense_elem_attr))) {
// If the input attribute is neither ranked type nor constant, we
@ -280,14 +280,14 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
auto ranked_input_type =
tf_strided_slice_op.input()->getType().dyn_cast<RankedTensorType>();
tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
if (!ranked_input_type) {
// If input is not a ranked tensor, we can't deduce the padding dimensions
// from it, so we just do a plain conversion here.
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
op, tf_strided_slice_op.output()->getType(),
tf_strided_slice_op.input(), tf_strided_slice_op.begin(),
tf_strided_slice_op.end(), tf_strided_slice_op.strides(),
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
tf_strided_slice_op.strides(),
rewriter.getI32IntegerAttr(
tf_strided_slice_op.begin_mask().getSExtValue()),
rewriter.getI32IntegerAttr(
@ -318,7 +318,7 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
Value padded_strides = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
op, tf_strided_slice_op.output()->getType(), tf_strided_slice_op.input(),
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
padded_begin, padded_end, padded_strides,
rewriter.getI32IntegerAttr(begin_mask),
rewriter.getI32IntegerAttr(end_mask),
@ -336,7 +336,7 @@ PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
auto tf_unpack_op = cast<TF::UnpackOp>(op);
auto input = tf_unpack_op.value();
auto output_types = functional::map([](Value v) { return v->getType(); },
auto output_types = functional::map([](Value v) { return v.getType(); },
tf_unpack_op.output());
auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
// Axis can be negative.
@ -360,7 +360,7 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;
auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
auto output_type = tf_matrix_diag_v2_or_v3_op.output()->getType();
auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType();
// Extract k constant tensor and check value = 0.
ElementsAttr k;
@ -500,7 +500,7 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
auto status_or_const_op = CreateConstOpWithSingleValue(
&rewriter, op->getLoc(),
tf_reciprocal_op.x()->getType().cast<ShapedType>(), 1);
tf_reciprocal_op.x().getType().cast<ShapedType>(), 1);
if (!status_or_const_op.ok()) {
return matchFailure();
}

View File

@ -19,11 +19,11 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -71,7 +71,7 @@ struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
Type expressed_type =
lstm.input()->getType().cast<ShapedType>().getElementType();
lstm.input().getType().cast<ShapedType>().getElementType();
Type int8_storage_type = builder->getIntegerType(8);
Type int16_storage_type = builder->getIntegerType(16);
auto flag = quant::QuantizationFlags::FlagValue::Signed;
@ -88,8 +88,8 @@ void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
auto any_int16 = quant::AnyQuantizedType::get(
flag, int16_storage_type, expressed_type, int16_min, int16_max);
int8 = any_int8.castFromExpressedType(lstm.input()->getType());
int16 = any_int16.castFromExpressedType(lstm.input()->getType());
int8 = any_int8.castFromExpressedType(lstm.input().getType());
int16 = any_int16.castFromExpressedType(lstm.input().getType());
}
Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in,

View File

@ -29,28 +29,28 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -196,13 +196,13 @@ struct ConvertTensorListSetItem : public ConversionPattern {
// Calculate `index` + 1, which is used to generate the start position for
// the second slice op.
auto suffix_start =
rewriter.create<TF::AddOp>(loc, index->getType(), index,
rewriter.create<TF::AddOp>(loc, index.getType(), index,
CreateI32SplatConst(loc, &rewriter, {}, 1));
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
// Create two slice ops.
Type element_type = input->getType().cast<TensorType>().getElementType();
Type element_type = input.getType().cast<TensorType>().getElementType();
UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
TF::SliceOp slice1 =
@ -225,7 +225,7 @@ struct ConvertTensorListSetItem : public ConversionPattern {
// Concatenate three parts together to generate the final result.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
op, input->getType(), scalar_zero,
op, input.getType(), scalar_zero,
ArrayRef<Value>({slice1, expanded_item, slice2}));
return matchSuccess();
}
@ -264,7 +264,7 @@ struct ConvertTensorListInitOp : public ConversionPattern {
}
Value element_shape = operands[0];
Type shape_dtype = getElementTypeOrSelf(element_shape->getType());
Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
DenseIntElementsAttr dense_elem_attr;
if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
@ -297,11 +297,10 @@ struct ConvertTensorListInitOp : public ConversionPattern {
new_element_shape_values.push_back(dim_value);
}
auto attr =
DenseIntElementsAttr::get(element_shape->getType().cast<ShapedType>(),
new_element_shape_values);
auto attr = DenseIntElementsAttr::get(
element_shape.getType().cast<ShapedType>(), new_element_shape_values);
auto new_element_shape = rewriter.create<ConstantOp>(
op.getLoc(), element_shape->getType(), attr);
op.getLoc(), element_shape.getType(), attr);
element_shape = new_element_shape;
}
@ -355,7 +354,7 @@ struct ConvertTensorListReserve
Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
PatternRewriter *rewriter) const override {
Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
Type shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
Value num_elements = operands[1];
return rewriter->create<TF::ExpandDimsOp>(
op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
@ -392,14 +391,14 @@ struct ConvertTensorListPushBack : public ConversionPattern {
// Expand the shape of the item so that it will have rank same as the input
// tensor and it is compatible for the Concat Op.
Type expanded_item_type =
PrependLeadingDimIfRanked(1, item->getType(), &rewriter);
PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), expanded_item_type, item, scalar_zero);
Type elem_type = getElementTypeOrSelf(item);
auto handle_dtype =
getElementTypeOrSelf(push_back_op.output_handle()->getType())
getElementTypeOrSelf(push_back_op.output_handle().getType())
.cast<TF::VariantType>();
Type result_type =
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
@ -446,7 +445,7 @@ struct ConvertTensorListResize : public ConversionPattern {
// Infer result type of this op based on TF's shape inference result.
Type elem_type = getElementTypeOrSelf(input_handle);
auto handle_dtype =
getElementTypeOrSelf(resize_op.output_handle()->getType())
getElementTypeOrSelf(resize_op.output_handle().getType())
.cast<TF::VariantType>();
Type result_type =
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
@ -463,8 +462,8 @@ struct ConvertTensorListResize : public ConversionPattern {
auto input_shape = rewriter.create<TF::ShapeOp>(
loc, RankedTensorType::get({-1}, shape_dtype), input_handle);
Type branch_args_type[] = {input_handle->getType(), input_shape.getType(),
size_diff.getType(), size->getType()};
Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
size_diff.getType(), size.getType()};
Type branch_result_type[] = {result_type};
auto func_type = FunctionType::get(branch_args_type, branch_result_type,
rewriter.getContext());
@ -524,7 +523,7 @@ struct ConvertTensorListResize : public ConversionPattern {
loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
slice_size);
auto extended_part = rewriter->create<TF::TensorListReserveOp>(
loc, resize_op.output_handle()->getType(), elem_shape, size_diff);
loc, resize_op.output_handle().getType(), elem_shape, size_diff);
// `ConcatOp` expects non-variant-typed input. Insert a
// `TensorListStackOp` here to convert type from variant to non-variant.
// Note that we are using the same `result_type` for both the
@ -627,7 +626,7 @@ struct ConvertTensorListStack : public ConversionPattern {
// trivial Reshape op (that doesn't actually change the input's shape) and
// also populate the shape info to the op result. The shape of the
// tensorlist is inferred from `num_elements` and `element_shape`.
auto ranked_type = element_shape->getType().dyn_cast<RankedTensorType>();
auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
DenseIntElementsAttr dense_elem_attr;
if ((ranked_type && ranked_type.getRank() == 0) ||
!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
@ -659,7 +658,7 @@ struct ConvertIdentity : public ConversionPattern {
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::IdentityOp>(operation);
Value input = operands[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input->getType(), operands,
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
op.getAttrs());
return matchSuccess();
}
@ -687,7 +686,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
Type arg_type = func_type.getInput(i);
if (getElementTypeOrSelf(arg_type).isa<TF::VariantType>()) {
arg_type = UnrankedTensorType::get(
getElementTypeOrSelf(op.getOperand(i)->getType()));
getElementTypeOrSelf(op.getOperand(i).getType()));
}
updated_argument_types.push_back(arg_type);
}
@ -703,7 +702,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
// from the corresponding input operand. This is correct because while
// body's inputs and results have the same type.
result_type = UnrankedTensorType::get(
getElementTypeOrSelf(op.getOperand(i)->getType()));
getElementTypeOrSelf(op.getOperand(i).getType()));
}
updated_result_types.push_back(result_type);
}
@ -717,7 +716,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
// Change the argument type for the first block.
Block &body_first_bb = func.front();
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
body_first_bb.getArgument(i)->setType(updated_argument_types[i]);
body_first_bb.getArgument(i).setType(updated_argument_types[i]);
}
}
return success();
@ -735,12 +734,12 @@ struct ConvertWhile : public ConversionPattern {
llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumOperands());
for (int i = 0, e = operands.size(); i != e; ++i) {
Type result_ty = op.getResult(i)->getType();
Type result_ty = op.getResult(i).getType();
// If we notice the result type is a DT_VARIANT, we change the
// corresponding result type to unranked tensor type.
if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
Type element_ty = getElementTypeOrSelf(operands[i]->getType());
Type element_ty = getElementTypeOrSelf(operands[i].getType());
result_ty = UnrankedTensorType::get(element_ty);
}
result_types.push_back(result_ty);

View File

@ -16,6 +16,7 @@ limitations under the License.
// This transformation pass takes operations in TensorFlowLite dialect and
// optimizes them to resulting operations in TensorFlowLite dialect.
#include <algorithm>
#include <climits>
#include <cstdint>
#include <functional>
@ -30,14 +31,14 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
@ -51,15 +52,15 @@ namespace TFL {
namespace {
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
if (sq_op->getType().cast<ShapedType>().getRank() - 1 ==
if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
*axis.getValues<int>().begin() ||
*axis.getValues<int>().begin() == -1) {
return true;
}
if (sq_op->getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
return false;
}
auto shape = sq_op->getType().cast<ShapedType>();
auto shape = sq_op.getType().cast<ShapedType>();
SmallVector<int, 4> elems{axis.getValues<int>().begin(),
axis.getValues<int>().end()};
for (int i = 0; i < shape.getRank(); ++i) {
@ -80,6 +81,18 @@ bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
return OpTrait::util::getBroadcastedType(a, b) != Type();
}
// Returns whether if `type1` dimensions are the same as the ending dimensions
// of `type2`. This is more restricted than broadcastable.
bool IsTailOfShape(Type type1, Type type2) {
auto tail_type = type1.dyn_cast<ShapedType>();
auto full_type = type2.dyn_cast<ShapedType>();
if (!tail_type || !full_type || tail_type.getRank() > full_type.getRank())
return false;
auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend();
auto i2 = full_type.getShape().rbegin();
return std::equal(i1, e1, i2);
}
bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
bool is_depthwise) {
// Make sure the val tensor has shape where all dimensions are 1 except
@ -143,7 +156,7 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
// Returns shape of a ranked tensor.
// Precondition: output_val's is ranked tensor.
DenseElementsAttr GetShape(Value output_val) {
auto output_type = output_val->getType().cast<RankedTensorType>();
auto output_type = output_val.getType().cast<RankedTensorType>();
auto shape_vector = output_type.getShape();
std::vector<int32_t> shape(shape_vector.size());
for (int i = 0; i < shape_vector.size(); ++i) {
@ -152,7 +165,7 @@ DenseElementsAttr GetShape(Value output_val) {
return mlir::DenseElementsAttr::get(
RankedTensorType::get(
{static_cast<int>(shape.size())},
mlir::IntegerType::get(32, output_val->getContext())),
mlir::IntegerType::get(32, output_val.getContext())),
llvm::makeArrayRef(shape));
}
@ -173,13 +186,13 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
// Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs()->getDefiningOp());
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
if (!fc_op) return matchFailure();
Value filter = fc_op.filter();
Value bias = fc_op.bias();
ElementsAttr bias_value;
const bool is_none_bias = bias->getType().isa<NoneType>();
const bool is_none_bias = bias.getType().isa<NoneType>();
if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
return matchFailure();
if (fc_op.fused_activation_function() != "NONE") return matchFailure();
@ -213,7 +226,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op,
PatternRewriter &rewriter) const override {
Operation *input = relu_op.getOperand()->getDefiningOp();
Operation *input = relu_op.getOperand().getDefiningOp();
if (!isa_and_nonnull<FullyConnectedOp>(input)) return matchFailure();
auto fully_connected_op = cast<FullyConnectedOp>(input);
if (fully_connected_op.fused_activation_function() != "NONE")
@ -247,13 +260,13 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
// Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
if (!fc_op) return matchFailure();
Value filter = fc_op.filter();
Value bias = fc_op.bias();
ElementsAttr cst_tmp;
if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
if (!bias->getType().isa<NoneType>() &&
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&cst_tmp)))
return matchFailure();
if (fc_op.fused_activation_function().equals("None")) return matchFailure();
@ -262,7 +275,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
// filter input. We only support broadcasting the operand along the depth
// dimension, when the operand's depth is 1.
Value new_const_val = constant_val;
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
auto original_shape = cst.getType().getShape();
llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
original_shape.end());
@ -270,7 +283,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
auto new_cst = cst.reshape(RankedTensorType::get(
normalized_shape, cst.getType().getElementType()));
Type new_type = new_cst.getType();
if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
return matchFailure();
}
auto new_op =
@ -285,7 +298,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
auto new_filter =
rewriter.create<TF::MulOp>(loc, filter, new_const_val).z();
// If bias isn't None, it needs to be multiplied as well.
if (!bias->getType().isa<NoneType>()) {
if (!bias.getType().isa<NoneType>()) {
bias = rewriter.create<TF::MulOp>(loc, bias, constant_val).z();
}
@ -311,7 +324,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
PatternMatchResult matchAndRewrite(AffineOpType fc_op,
PatternRewriter &rewriter) const override {
// Binary op.
Operation *binary_op = fc_op.input()->getDefiningOp();
Operation *binary_op = fc_op.input().getDefiningOp();
if (!binary_op || binary_op->getNumOperands() != 2)
return this->matchFailure();
// We only handle the cases the RHS is a scalar.
@ -330,15 +343,15 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
DenseFPElementsAttr filter_cst, bias_cst;
if (!matchPattern(filter, m_Constant(&filter_cst))) {
// The filter maybe quantized, then we should set it to the real constant.
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter->getDefiningOp());
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
if (!dq) return this->matchFailure();
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input()->getDefiningOp());
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
return this->matchFailure();
}
filter = q.input();
}
if (!bias->getType().isa<NoneType>() &&
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&bias_cst)))
return this->matchFailure();
ShapedType filter_type = filter_cst.getType();
@ -362,7 +375,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
// The new bias should be a 1-D tensor with length equals to the bias
// dimension of the weight.
SmallVector<APFloat, 4> new_bias_values;
if (bias->getType().isa<NoneType>()) { // none bias, a list of zeros
if (bias.getType().isa<NoneType>()) { // none bias, a list of zeros
new_bias_values.resize(bias_size, APFloat(0.0));
} else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it
new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
@ -401,12 +414,12 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
// We recreate the constant op in case it is shared by the other ops. This
// might increase the model size.
auto new_filter_op = rewriter.create<ConstOp>(
fc_op.getLoc(), filter->getType(), new_filter);
fc_op.getLoc(), filter.getType(), new_filter);
fc_op.setOperand(0, binary_op->getOperand(0));
if (fc_op.filter() != filter) {
// This filter goes through quantize and dequantize ops. Then we just
// need to update the weight to the quantize op.
filter->replaceAllUsesWith(new_filter_op);
filter.replaceAllUsesWith(new_filter_op);
} else {
// This filter doesn't go through quantize and dequantize ops, Then
// we update the weight of the affine op directly.

View File

@ -17,15 +17,15 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {

View File

@ -55,7 +55,7 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0->hasOneUse()">>;
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// If we see a binary op (add, sub) op adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
@ -161,7 +161,7 @@ def EqualOperands : Constraint<CPred<"$0 == $1">>;
// Checks if the operand has rank == n
class OperandHasRank<int n> : Constraint<
CPred<"$0->getType().cast<ShapedType>().getRank() == " # n>>;
CPred<"$0.getType().cast<ShapedType>().getRank() == " # n>>;
// Matching HardSwish
def : Pat<
@ -255,8 +255,16 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;
//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
def AreBroadcastableTypes : Constraint<CPred<
"TFL::IsBroadcastableElementsAttrAndType($0->getType(), $1->getType())">>;
"TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;
def IsTailOfShape : Constraint<CPred<
"TFL::IsTailOfShape($0->getType(), $1->getType())">>;
def HaveSameType : Constraint<CPred<"$0->getType(), $1->getType()">>;
// Pattern for skipping Tile if it is mainly for broadcasting and the
// Op is already supporting broadcasting.
@ -272,13 +280,72 @@ multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
[(AreBroadcastableTypes $operand, $input)]>;
}
foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
in defm : FuseTileBroadcastIntoFollowingBinary<BroadcastingOp>;
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu6Op, TFL_AF_Relu6],
[TFL_Relu1Op, TFL_AF_Relu1]] in {
def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
(BinaryOp $lhs, $rhs, actFnPair[1])>;
}
}
foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
defm : FuseTileBroadcastIntoFollowingBinary<BinaryOp>;
// Instantiated FusedBinary patterns for the from-to pairs of ops.
defm : FusedBinaryActivationFuncOpPat<BinaryOp>;
// Move binary op before reshape: reshape -> binary => binary -> reshape.
// This is valid only when the binary operand is constant and the shape is the
// tail of the other operand and the intermediate result isn't used by other
// ops.
// $rhs is required to be the tail shape of $lhs, so after transformation the
// shape of the binary op result is valid. For example, assume the shapes of
// $input, $lhs and $rhs are [1600], [1,40,40] and [40x1]. After the
// transformation, the shape of the binary op result is [40x1600], which
// couldn't be reshaped to [1,40,40]. `IsTailOfShape` constraint is added to
// make sure $rhs is the tail shape of $lhs.
def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
(ConstantOp:$rhs $a), TFL_AF_None),
(TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
// The broadcasting of "BinaryOp" only happens in the lower
// dimensions, and the higher dimensions are same.
[(IsTailOfShape $rhs, $lhs),
(HasOneUse $lhs),
// the two operands of the binary op is broadcastable
(AreBroadcastableTypes $rhs, $input)]>;
}
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
TFL_GreaterEqualOp] in {
// Move binary op before reshape: reshape -> binary => binary -> reshape.
// This is valid only when the binary operand is constant and the shape is the
// tail of the other operand and the intermediate result isn't used by other
// ops.
// $rhs is required to be the tail shape of $lhs, so after transformation the
// shape of the binary op result is valid. For example, assume the shapes of
// $input, $lhs and $rhs are [1600], [1,40,40] and [40x1]. After the
// transformation, the shape of the binary op result is [40x1600], which
// couldn't be reshaped to [1,40,40]. `IsTailOfShape` constraint is added to
// make sure $rhs is the tail shape of $lhs.
def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
(ConstantOp:$rhs $a)),
(TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
// The broadcasting of "BinaryOp" only happens in the lower
// dimensions, and the higher dimensions are same.
[(IsTailOfShape $rhs, $lhs),
(HasOneUse $lhs),
// the two operands of the binary op is broadcastable
(AreBroadcastableTypes $rhs, $input)]>;
}
// Returns shape of a ranked tensor.
// if called without a ranked tensor it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;
// Convert squeeze to reshape
def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
(TFL_ReshapeOp $input,
(ConstantOp (GetShape $squeeze_op))),
@ -300,21 +367,6 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
(TFL_Relu1Op $input),
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu6Op, TFL_AF_Relu6],
[TFL_Relu1Op, TFL_AF_Relu1]] in {
def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
(BinaryOp $lhs, $rhs, actFnPair[1])>;
}
}
// Instantiated FusedBinary patterns for the from-to pairs of ops.
foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
TFL_MulOp, TFL_SubOp] in
defm : FusedBinaryActivationFuncOpPat<BinaryOps>;
// The constant folding in this pass might produce constant in the tf dialect.
// This rule is to legalize these constant to the tfl dialect.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

View File

@ -73,6 +73,10 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass();
std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass();
std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass();
// Creates an instance pass to add default quantization parameters.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max);
} // namespace TFL
} // namespace mlir

View File

@ -16,8 +16,8 @@ limitations under the License.
// This transformation pass applies some clean up steps after quantization.
#include "llvm/Support/Casting.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -71,29 +71,29 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
auto remove_quantize_op = [&](QuantizeOp quantize_op) {
auto quantize_output = quantize_op.output();
auto quantize_type = quantize_output->getType();
auto quantize_type = quantize_output.getType();
input_types.push_back(quantize_type);
auto new_arg = bb.addArgument(quantize_type);
quantize_output->replaceAllUsesWith(new_arg);
quantize_output.replaceAllUsesWith(new_arg);
quantize_op.erase();
arg->dropAllUses();
arg.dropAllUses();
bb.eraseArgument(0);
};
// This is looking for a pattern: arg -> tfl.quantize
if (arg->hasOneUse() && llvm::isa<QuantizeOp>(*arg->user_begin())) {
auto quantize_op = llvm::cast<QuantizeOp>(*arg->user_begin());
if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
remove_quantize_op(quantize_op);
continue;
}
// Make a copy of current argument and append it to the end of the list if
// the pattern isn't found.
Type arg_type = arg->getType();
Type arg_type = arg.getType();
input_types.push_back(arg_type);
auto new_arg = bb.addArgument(arg_type);
arg->replaceAllUsesWith(new_arg);
arg->dropAllUses();
arg.replaceAllUsesWith(new_arg);
arg.dropAllUses();
bb.eraseArgument(0);
}
@ -103,15 +103,15 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
output_types.reserve(num_return_operands);
for (int i = 0; i != num_return_operands; ++i) {
auto returned_value = terminator->getOperand(i);
Operation* returned_op = returned_value->getDefiningOp();
Operation* returned_op = returned_value.getDefiningOp();
if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
Value dequantized_result = dequantize_op.input();
output_types.push_back(dequantized_result->getType());
output_types.push_back(dequantized_result.getType());
terminator->setOperand(i, dequantized_result);
returned_op->erase();
} else {
output_types.push_back(returned_value->getType());
output_types.push_back(returned_value.getType());
}
}
auto new_func_type = builder.getFunctionType(input_types, output_types);

View File

@ -22,19 +22,19 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Identifier.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"

View File

@ -135,10 +135,10 @@ def : Pat<(TF_ReshapeOp
// Casts result type of $1 to a quantized type by using the quantization
// parameters from the type in $0.
class UpdateShapeWithAxis<int i> : NativeCodeCall<
"CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1->getType(), " # i # ")">;
"CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">;
class UsedBy<string op> : Constraint<
CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0->getUsers().begin())">>;
CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0.getUsers().begin())">>;
// When the op is passing-through, the output types of the quantized ops need
// to be updated as well. Since the quantize op manages its own type by the

View File

@ -21,10 +21,10 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -153,7 +153,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
params);
auto dq_op =
builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
arg->replaceAllUsesWith(dq_op.output());
arg.replaceAllUsesWith(dq_op.output());
q_op.setOperand(arg);
}
}
@ -161,8 +161,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
BlockArgument arg = func.getArgument(i);
auto* arg_block = arg->getOwner();
add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
auto* arg_block = arg.getOwner();
add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
std::next(arg_block->begin(), i), arg, i);
}

View File

@ -38,17 +38,17 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -115,7 +115,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
PatternRewriter &rewriter) const override {
// We don't want to insert quantize/dequantize if the quantize op exists.
auto res = tf_op.outputs();
if (!res->hasOneUse() || isa<QuantizeOp>(*res->user_begin()))
if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin()))
return this->matchFailure();
// Extract the min/max constant values from the operands. We also consider
@ -123,9 +123,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
// constants and the tf.FakeQuantWithMinMaxVarsOp.
Value min = tf_op.min(), max = tf_op.max();
DenseFPElementsAttr min_value, max_value;
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min->getDefiningOp()))
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
min = id1.input();
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max->getDefiningOp()))
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
max = id2.input();
if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
@ -133,7 +133,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
int quant_dim = -1;
if (PerAxis) {
// This is a special case that the quant_dim is the last dimensions.
quant_dim = res->getType().template cast<ShapedType>().getRank() - 1;
quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
}
// Use the min/max from the operands and the num_bits and narrow_range
// attribute to create the quantization parameter for the new quantize op.
@ -155,7 +155,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
tf_op.getLoc(), qtype.getValue(), value, qtype);
auto dequantize = rewriter.create<TFL::DequantizeOp>(
tf_op.getLoc(), res_type, quantize.output());
value->replaceAllUsesWith(dequantize);
value.replaceAllUsesWith(dequantize);
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
return this->matchSuccess();
@ -240,7 +240,7 @@ struct ConvertTFConvOp : public RewritePattern {
// that we can extract info from the shape (e.g., for constructing bias
// tensor, for setting depth_multiplier attribute, etc.).
auto filter_type =
tf_op.filter()->getType().template dyn_cast<RankedTensorType>();
tf_op.filter().getType().template dyn_cast<RankedTensorType>();
if (filter_type && filter_type.getRank() == 4)
return matchSuccess(std::move(state));
@ -262,7 +262,7 @@ struct ConvertTFConvOp : public RewritePattern {
// Get a splat zero tensor with the expected dimension for the bias tensor
auto filter = tf_op.filter();
auto filter_type = filter->getType().template cast<RankedTensorType>();
auto filter_type = filter.getType().template cast<RankedTensorType>();
auto elem_type = filter_type.getElementType();
auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
filter_type.getShape());
@ -323,7 +323,7 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
// Create tensor type for the transpose result.
auto filter_type = filter->getType().cast<RankedTensorType>();
auto filter_type = filter.getType().cast<RankedTensorType>();
auto result_shape = functional::map(
[filter_type](int64_t dim) { return filter_type.getDimSize(dim); },
perm);
@ -356,7 +356,7 @@ class ConvertTFDepthwiseConv2dNative
// have a corresponding 'depth_multiplier' attribute; the multiplier is the
// fourth dimension in the 4-D filter tensor. We query the multiplier from
// tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
auto multiplier = filter->getType().cast<RankedTensorType>().getDimSize(3);
auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);
filter = legalizeFilter(rewriter, loc, filter);
return rewriter.create<TFL::DepthwiseConv2DOp>(
@ -380,7 +380,7 @@ class ConvertTFDepthwiseConv2dNative
/// RankedTensorType.
Value legalizeFilter(PatternRewriter &rewriter, Location loc,
Value filter) const {
auto filter_type = filter->getType().cast<RankedTensorType>();
auto filter_type = filter.getType().cast<RankedTensorType>();
auto filterShape = filter_type.getShape();
SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
filterShape[2] * filterShape[3]};
@ -432,11 +432,11 @@ struct ConvertTFStridedSlice : public RewritePattern {
// Insert a new reshape op.
Value original_input = strided_slice_op.input();
RankedTensorType original_input_type =
original_input->getType().cast<RankedTensorType>();
original_input.getType().cast<RankedTensorType>();
const ArrayRef<int64_t> &original_input_shape =
original_input_type.getShape();
RankedTensorType begin_type =
strided_slice_op.begin()->getType().cast<RankedTensorType>();
strided_slice_op.begin().getType().cast<RankedTensorType>();
const int dim_size = begin_type.getShape()[0];
SmallVector<int64_t, 4> new_shape;
int mask = 1;

View File

@ -19,17 +19,17 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

View File

@ -18,24 +18,24 @@ limitations under the License.
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
@ -83,7 +83,7 @@ LogicalResult DuplicateValueIfNeeded(Operation* op,
// We can only clone the constant op at this point.
// Since all ops have been legalized to tflite ops, so we only care about
// ConstOp or QConstOp or mlir constant op/
Operation* input_op = operand->getDefiningOp();
Operation* input_op = operand.getDefiningOp();
if (input_op == nullptr) return failure();
Attribute attr;

View File

@ -20,13 +20,13 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Identifier.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
// The cmd line flag to specify the whitelist of functions. Rest are trimmed

View File

@ -24,17 +24,17 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -83,7 +83,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
template <typename BatchMatMulOpType>
std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
RankedTensorType tensorType = value->getType().cast<RankedTensorType>();
RankedTensorType tensorType = value.getType().cast<RankedTensorType>();
Type element_type = tensorType.getElementType();
int rank = tensorType.getShape().size();
@ -127,7 +127,7 @@ std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
template <typename BatchMatMulOpType>
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
Value value, Location loc, PatternRewriter& rewriter) {
auto value_type = value->getType().cast<RankedTensorType>();
auto value_type = value.getType().cast<RankedTensorType>();
auto shape = value_type.getShape();
int dims = shape.size();
@ -197,17 +197,17 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
Value input_lhs = op.x();
Value input_rhs = op.y();
if (!input_lhs->getType().isa<RankedTensorType>()) {
if (!input_lhs.getType().isa<RankedTensorType>()) {
// LHS must be a ranked tensor type
return this->matchFailure();
}
if (!input_rhs->getType().isa<RankedTensorType>()) {
if (!input_rhs.getType().isa<RankedTensorType>()) {
// RHS must be a ranked tensor type
return this->matchFailure();
}
auto lhs_type = input_lhs->getType().cast<RankedTensorType>();
auto rhs_type = input_rhs->getType().cast<RankedTensorType>();
auto lhs_type = input_lhs.getType().cast<RankedTensorType>();
auto rhs_type = input_rhs.getType().cast<RankedTensorType>();
auto element_type = lhs_type.getElementType();
@ -233,7 +233,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
if (op.adj_x()) {
input_lhs = createTransposeOp(input_lhs, loc, rewriter);
lhs_type = input_lhs->getType().cast<RankedTensorType>();
lhs_type = input_lhs.getType().cast<RankedTensorType>();
lhs_shape = lhs_type.getShape();
}
@ -241,7 +241,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
if (op.adj_y()) {
input_rhs = createTransposeOp(input_rhs, loc, rewriter);
rhs_type = input_rhs->getType().cast<RankedTensorType>();
rhs_type = input_rhs.getType().cast<RankedTensorType>();
rhs_shape = rhs_type.getShape();
}

View File

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
namespace mlir {
namespace TFL {

View File

@ -19,7 +19,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
namespace mlir {
namespace TFL {

View File

@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/types.h"

Some files were not shown because too many files have changed in this diff Show More