commit be6e1ce49a

.bazelrc (4 changes)
@@ -381,9 +381,9 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
 build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
 
 build:rbe_win --config=rbe
-build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:toolchain"
+build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
 build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
-build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_026:cc-toolchain-x64_windows"
+build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:cc-toolchain-x64_windows"
 build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
 build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win_1803:rbe_windows_1803"
 build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win_1803:windows_jdk8"
@@ -1 +1 @@
-1.1.0
+1.2.1
README.md (15 changes)

@@ -29,20 +29,6 @@ to
 [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
 See all the [mailing lists](https://www.tensorflow.org/community/forums).
 
-## Feature Prioritization Survey
-
-The TensorFlow team is working on building/improving features, and understands
-that it is very important to prioritize these efforts based on what TF users
-need.
-
-The goal of this short, < 5 minute
-[survey](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad), is to help
-the TensorFlow team better understand what features to prioritize based on your
-feedback. Participation is of course optional.
-
-Take the survey
-[HERE](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad).
-
 ## Install
 
 See the [TensorFlow install guide](https://www.tensorflow.org/install) for the

@@ -164,4 +150,3 @@ Learn more about the
 ## License
 
 [Apache License 2.0](LICENSE)
-
configure.py (16 changes)

@@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
 _TF_WORKSPACE_ROOT = ''
 _TF_BAZELRC = ''
 _TF_CURRENT_BAZEL_VERSION = None
-_TF_MIN_BAZEL_VERSION = '1.0.0'
-_TF_MAX_BAZEL_VERSION = '1.1.0'
+_TF_MIN_BAZEL_VERSION = '1.2.1'
+_TF_MAX_BAZEL_VERSION = '1.2.1'
 
 NCCL_LIB_PATHS = [
     'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''

@@ -147,14 +147,16 @@ def write_action_env_to_bazelrc(var_name, var):
   write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
 
 
-def run_shell(cmd, allow_non_zero=False):
+def run_shell(cmd, allow_non_zero=False, stderr=None):
+  if stderr is None:
+    stderr = sys.stdout
   if allow_non_zero:
     try:
-      output = subprocess.check_output(cmd)
+      output = subprocess.check_output(cmd, stderr=stderr)
     except subprocess.CalledProcessError as e:
       output = e.output
   else:
-    output = subprocess.check_output(cmd)
+    output = subprocess.check_output(cmd, stderr=stderr)
   return output.decode('UTF-8').strip()
 
 
@@ -169,10 +171,12 @@ def get_python_path(environ_cp, python_bin_path):
   if environ_cp.get('PYTHONPATH'):
     python_paths = environ_cp.get('PYTHONPATH').split(':')
   try:
+    stderr = open(os.devnull, 'wb')
     library_paths = run_shell([
         python_bin_path, '-c',
         'import site; print("\\n".join(site.getsitepackages()))'
-    ]).split('\n')
+    ],
+                              stderr=stderr).split('\n')
   except subprocess.CalledProcessError:
     library_paths = [
         run_shell([
@@ -458,7 +458,7 @@ static void TF_Run_Helper(
           EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
       continue;
     }
-    c_outputs[i] = TF_TensorFromTensor(src, status);
+    c_outputs[i] = TF_TensorFromTensor(src, &status->status);
     if (!status->status.ok()) return;
   }
 }

@@ -1493,7 +1493,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
   Tensor t;
   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
   if (!status->status.ok()) return;
-  *value = TF_TensorFromTensor(t, status);
+  *value = TF_TensorFromTensor(t, &status->status);
 }
 
 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,

@@ -1504,7 +1504,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
   if (!status->status.ok()) return;
   const auto len = std::min(max_values, static_cast<int>(ts.size()));
   for (int i = 0; i < len; ++i) {
-    values[i] = TF_TensorFromTensor(ts[i], status);
+    values[i] = TF_TensorFromTensor(ts[i], &status->status);
   }
 }

@@ -2398,7 +2398,7 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
       graph->graph.versions().producer(), &evaluated, &result_tensor);
   if (evaluated) {
     DCHECK(status->status.ok());
-    *result = TF_TensorFromTensor(result_tensor, status);
+    *result = TF_TensorFromTensor(result_tensor, &status->status);
     if (!status->status.ok()) evaluated = false;
   }
   return evaluated;
@@ -634,7 +634,7 @@ TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
   std::unique_ptr<tensorflow::Tensor> tensor;
   reader->GetTensor(name, &tensor, status);
   if (!status->status.ok()) return nullptr;
-  return tensorflow::TF_TensorFromTensor(*tensor, status);
+  return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
 }
 
 void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
@@ -188,7 +188,7 @@ namespace tensorflow {
 
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
 
-TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
 
 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
                        TF_Buffer* out);
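Note on the signature change above: `TF_TensorFromTensor` now takes the C++ `tensorflow::Status*` directly instead of the C `TF_Status*` wrapper, which is why every call site in this commit passes `&status->status`. A minimal sketch of the adapted calling pattern, assuming (as in `c_api_internal.h`) that `TF_Status` wraps a `tensorflow::Status` member named `status`; the helper function name here is hypothetical, for illustration only:

    // Hypothetical call site showing the new convention.
    // `c_status->status` is the tensorflow::Status wrapped by TF_Status.
    TF_Tensor* ConvertTensorForCApi(const tensorflow::Tensor& t,
                                    TF_Status* c_status) {
      // Old: TF_Tensor* out = TF_TensorFromTensor(t, c_status);
      TF_Tensor* out = tensorflow::TF_TensorFromTensor(t, &c_status->status);
      if (!c_status->status.ok()) return nullptr;  // error already recorded
      return out;
    }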
@@ -51,7 +51,7 @@ limitations under the License.
 #include "tensorflow/core/util/equal_graph_def.h"
 
 namespace tensorflow {
-TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
 
 namespace {

@@ -227,7 +227,7 @@ TEST(CAPI, LibraryLoadFunctions) {
 
 void TestEncodeDecode(int line, const std::vector<string>& data) {
   const tensorflow::int64 n = data.size();
-  TF_Status* status = TF_NewStatus();
+  Status status;
   for (const std::vector<tensorflow::int64>& dims :
        std::vector<std::vector<tensorflow::int64>>{
            {n}, {1, n}, {n, 1}, {n / 2, 2}}) {

@@ -236,8 +236,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
     for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
       src.flat<tstring>()(i) = data[i];
     }
-    TF_Tensor* dst = TF_TensorFromTensor(src, status);
-    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TF_Tensor* dst = TF_TensorFromTensor(src, &status);
+    ASSERT_TRUE(status.ok()) << status.error_message();
 
     // Convert back to a C++ Tensor and ensure we get expected output.
     Tensor output;

@@ -249,7 +249,6 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
 
     TF_DeleteTensor(dst);
   }
-  TF_DeleteStatus(status);
 }
 
 TEST(CAPI, TensorEncodeDecodeStrings) {

@@ -1394,8 +1393,9 @@ TEST(CAPI, SavedModel) {
     TF_Operation* input_op =
         TF_GraphOperationByName(graph, input_op_name.c_str());
     ASSERT_TRUE(input_op != nullptr);
-    csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
-    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    Status status;
+    csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
+    ASSERT_TRUE(status.ok()) << status.error_message();
 
     const tensorflow::string output_op_name(
         tensorflow::ParseTensorName(output_name).first);

@@ -2522,12 +2522,11 @@ TEST(CAPI, TestTensorIsNotAligned) {
 
   // Take an unaligned slice.
   Tensor y = x.Slice(1, 13);
-  TF_Status* status = TF_NewStatus();
-  TF_Tensor* a = TF_TensorFromTensor(y, status);
+  Status status;
+  TF_Tensor* a = TF_TensorFromTensor(y, &status);
   if (EIGEN_MAX_ALIGN_BYTES > 0) {
     EXPECT_FALSE(TF_TensorIsAligned(a));
   }
-  TF_DeleteStatus(status);
   TF_DeleteTensor(a);
 }
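The test hunks above replace a heap-allocated `TF_Status` (which must be paired with `TF_DeleteStatus`) with a stack `tensorflow::Status` value, so the explicit delete calls disappear. A minimal sketch of why, using a stand-in `Status` type rather than TensorFlow's:

    #include <string>

    // Stand-in value type: cleans itself up, unlike TF_NewStatus(), which
    // must be matched with TF_DeleteStatus().
    class Status {
     public:
      Status() = default;  // default-constructed as OK
      bool ok() const { return msg_.empty(); }
      const std::string& error_message() const { return msg_; }

     private:
      std::string msg_;
    };

    void Caller() {
      Status status;  // stack-allocated; no explicit delete needed
      // ... pass &status to the function under test ...
      if (!status.ok()) {
        // report status.error_message()
      }
    }  // `status` is destroyed here automatically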
@@ -464,7 +464,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
         &new_remote_device_mgr));
     remote_device_mgr = new_remote_device_mgr.get();
   } else {
-    ctx->context->ClearCaches();
+    ctx->context->ClearCachesAndDefaultExecutor();
     // TODO(b/143914772): Potential memory leak if rendezvous has pending
     // tensors for removed / replaced workers.

@@ -638,7 +638,7 @@ tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
 
 void OpInferSingleTypeInputListAttrs(TFE_Op* op,
                                      const tensorflow::OpDef::ArgDef& input_def,
-                                     TFE_TensorHandle** inputs,
+                                     const tensorflow::DataType dtype,
                                      int num_inputs) {
   TFE_OpInferenceContext* ictx = op->inference_ctx.get();
   if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {

@@ -646,26 +646,20 @@ void OpInferSingleTypeInputListAttrs(TFE_Op* op,
     ictx->attrs.insert(input_def.number_attr());
   }
   if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
-    op->operation.MutableAttrs()->Set(input_def.type_attr(),
-                                      inputs[0]->handle->dtype);
+    op->operation.MutableAttrs()->Set(input_def.type_attr(), dtype);
     ictx->attrs.insert(input_def.type_attr());
   }
 }
 
-void OpInferMixedTypeInputListAttrs(TFE_Op* op,
-                                    const tensorflow::OpDef::ArgDef& input_def,
-                                    TFE_TensorHandle** inputs, int num_inputs) {
+void OpInferMixedTypeInputListAttrs(
+    TFE_Op* op, const tensorflow::OpDef::ArgDef& input_def,
+    const std::vector<tensorflow::DataType>& dtypes) {
   TFE_OpInferenceContext* ictx = op->inference_ctx.get();
   if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
-    std::unique_ptr<tensorflow::DataType[]> dtypes(
-        new tensorflow::DataType[num_inputs]);
-    for (int i = 0; i < num_inputs; ++i) {
-      dtypes[i] = inputs[i]->handle->dtype;
-    }
     op->operation.MutableAttrs()->Set(
         input_def.type_list_attr(),
-        tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
-                                                                num_inputs));
+        tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.data(),
+                                                                dtypes.size()));
     ictx->attrs.insert(input_def.type_list_attr());
   }
 }

@@ -675,10 +669,15 @@ tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
   TFE_OpInferenceContext* ictx = op->inference_ctx.get();
   const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
   if (!input_def.type_list_attr().empty()) {
-    OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
+    std::vector<tensorflow::DataType> dtypes(num_inputs);
+    for (int i = 0; i < num_inputs; ++i) {
+      dtypes[i] = inputs[i]->handle->dtype;
+    }
+    OpInferMixedTypeInputListAttrs(op, input_def, dtypes);
   } else if (!input_def.type_attr().empty() &&
              !input_def.number_attr().empty()) {
-    OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
+    OpInferSingleTypeInputListAttrs(op, input_def, inputs[0]->handle->dtype,
+                                    num_inputs);
   } else {
     return tensorflow::errors::InvalidArgument("Invalid input list definition");
   }

@@ -754,7 +753,9 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
   return list;
 }
 
-void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); }
+void TFE_ContextClearCaches(TFE_Context* ctx) {
+  ctx->context->ClearCachesAndThreadExecutors();
+}
 
 // Set server_def on the context, possibly updating it.
 TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,

@@ -990,7 +991,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
       h_cpu->Unref();
      return nullptr;
    }
-    TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
+    TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, &status->status);
    h_cpu->Unref();
    return retval;
  } else {

@@ -1006,7 +1007,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
     status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
     if (!status->status.ok()) return nullptr;
   }
-  return tensorflow::TF_TensorFromTensor(tensor, status);
+  return tensorflow::TF_TensorFromTensor(tensor, &status->status);
 }
 }
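The attr-inference refactor above changes the list-attr helpers to take plain dtypes instead of `TFE_TensorHandle**`, hoisting the handle-to-dtype extraction into the single caller. A minimal sketch of the shape of that decoupling, with illustrative stand-in types (not TensorFlow's):

    #include <vector>

    enum DataType { DT_FLOAT, DT_INT32 };
    struct Handle { DataType dtype; };  // stand-in for TFE_TensorHandle

    // After the refactor the helper depends only on dtypes...
    static void SetTypeListAttr(const std::vector<DataType>& dtypes) {
      // ... set the attr from dtypes.data() / dtypes.size() ...
    }

    // ...and the caller performs the extraction exactly once:
    static void InferListAttrs(Handle** inputs, int num_inputs) {
      std::vector<DataType> dtypes(num_inputs);
      for (int i = 0; i < num_inputs; ++i) dtypes[i] = inputs[i]->dtype;
      SetTypeListAttr(dtypes);
    }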
@@ -206,7 +206,7 @@ typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
 // error and nullptr is returned. This function can block till the operation
 // that produces `handle` has completed.
 TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
-    TFE_TensorHandle* handle, TF_Status* status);
+    TFE_TensorHandle* h, TF_Status* status);
 
 // Deletes `debug_info`.
 TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
@@ -50,15 +50,15 @@ std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
 extern "C" {
 
 TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
-    TFE_TensorHandle* handle, TF_Status* status) {
+    TFE_TensorHandle* h, TF_Status* status) {
   const tensorflow::Tensor* tensor;
-  status->status = handle->handle->Tensor(&tensor);
+  status->status = h->handle->Tensor(&tensor);
   if (TF_GetCode(status) != TF_OK) {
     return nullptr;
   }
 
 #ifdef TENSORFLOW_EAGER_USE_XLA
-  tensorflow::Device* device = handle->handle->device();
+  tensorflow::Device* device = h->handle->device();
 
   // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
   tensorflow::XlaDevice* xla_device =

@@ -72,7 +72,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
     return nullptr;
   }
   if (VLOG_IS_ON(3)) {
-    std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
+    std::vector<int64> shape_to_log = TensorShapeAsVector(h, status);
     if (!status->status.ok()) {
       // Ignore the status here as we are simply logging.
       status->status = tensorflow::Status::OK();

@@ -138,7 +138,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
 
   // If the tensor is not an XLA tensor, the device shape is
   // the same as regular tensor shape.
-  std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
+  std::vector<int64> dev_dims = TensorShapeAsVector(h, status);
   if (TF_GetCode(status) != TF_OK) {
     return nullptr;
   }
@@ -27,11 +27,18 @@ limitations under the License.
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/util/command_line_flags.h"
 
-// TODO(b/143949264): Testing is not yet supported on Windows. Will implement
-// testing on Windows when implementing modular filesystems on Windows.
 #if defined(PLATFORM_WINDOWS)
-#error Windows is not yet supported. Need mkdir().
-#endif
+// Make mkdir resolve to _mkdir to create the test temporary directory.
+#include <direct.h>
+#define mkdir(name, mode) _mkdir(name)
+
+// Windows defines the following macros to convert foo to fooA or fooW,
+// depending on the type of the string argument. We don't use these macros, so
+// undefine them here.
+#undef LoadLibrary
+#undef CopyFile
+#undef DeleteFile
+#endif  // defined(PLATFORM_WINDOWS)
 
 // The tests defined here test the compliance of filesystems with the API
 // defined by `filesystem_interface.h`.

@@ -86,9 +93,6 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
   }
 
   void SetUp() override {
-    // TODO(b/143949264): Testing is not yet supported on Windows. Will
-    // implement testing on Windows when implementing modular filesystems on
-    // Windows.
     if (mkdir(root_dir_.c_str(), 0755) != 0) {
       int error_code = errno;
       GTEST_SKIP() << "Cannot create working directory: "

@@ -1668,30 +1672,40 @@ static std::vector<std::string>* SchemeVector() {
   return schemes;
 }
 
-static std::vector<std::string> GetSchemes() {
-  std::vector<std::string>* user_schemes = SchemeVector();
-  std::vector<std::string> all_schemes;
+// `INSTANTIATE_TEST_SUITE_P` is called once for every `TEST_P`. However, we
+// only want to analyze the user provided schemes and those that are registered
+// only once. Hence, this function keeping another static pointer to a vector
+// which contains only the schemes under test.
+//
+// Without this additional step, when there are schemes available but the user
+// only requests schemes that don't exist, first instantiation of the test would
+// filter out all the user provided schemes (as they are not registered) but
+// subsequent instantiations would return all registered schemes (since the
+// vector with the user provided schemes is cleared).
+static std::vector<std::string>* GetSchemesFromUserOrEnv() {
+  std::vector<std::string>* all_schemes = new std::vector<std::string>;
   tensorflow::Status status =
-      tensorflow::Env::Default()->GetRegisteredFileSystemSchemes(&all_schemes);
+      tensorflow::Env::Default()->GetRegisteredFileSystemSchemes(all_schemes);
 
   if (status.ok()) {
+    std::vector<std::string>* user_schemes = SchemeVector();
     if (!user_schemes->empty()) {
-      auto is_registered_scheme = [&all_schemes](const auto& scheme) {
-        return std::find(all_schemes.begin(), all_schemes.end(), scheme) ==
-               all_schemes.end();
+      auto is_requested_scheme = [user_schemes](const auto& scheme) {
+        return std::find(user_schemes->begin(), user_schemes->end(), scheme) ==
+               user_schemes->end();
       };
-      auto end = std::remove_if(user_schemes->begin(), user_schemes->end(),
-                                is_registered_scheme);
-      user_schemes->erase(end, user_schemes->end());
-      return *user_schemes;
+      auto end = std::remove_if(all_schemes->begin(), all_schemes->end(),
+                                is_requested_scheme);
+      all_schemes->erase(end, all_schemes->end());
     }
+  }
 
-    // Next, try all schemes available
-    if (!all_schemes.empty()) return all_schemes;
-  }
+  return all_schemes;
+}
 
-  // Fallback: no filesystems present, hence no tests
-  return std::vector<std::string>();
+static std::vector<std::string> GetSchemes() {
+  static std::vector<std::string>* schemes = GetSchemesFromUserOrEnv();
+  return *schemes;
 }
 
 INSTANTIATE_TEST_SUITE_P(ModularFileSystem, ModularFileSystemTest,
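The `GetSchemes` rewrite above leans on a function-local static: the filtering of registered schemes against user-requested ones runs exactly once, and every later call (one per `INSTANTIATE_TEST_SUITE_P` expansion) reuses the cached vector. A minimal sketch of that pattern, with illustrative names that are not from TensorFlow:

    #include <string>
    #include <vector>

    // Expensive or destructive computation that must only run once.
    static std::vector<std::string>* ComputeSchemesOnce() {
      return new std::vector<std::string>{"posix", "ram"};  // example data
    }

    static std::vector<std::string> GetCachedSchemes() {
      // The static local is initialized exactly once, no matter how many
      // times this function is called afterwards.
      static std::vector<std::string>* schemes = ComputeSchemesOnce();
      return *schemes;
    }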
@@ -1,35 +1,47 @@
 # Experimental posix filesystem plugin.
 load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
 
 package(
     default_visibility = ["//visibility:private"],
     licenses = ["notice"],  # Apache 2.0
 )
 
-# Although this target results in a shared object that will be loaded at
-# runtime, this target must be a `cc_library` instead of a `cc_binary`. Making
-# it a `cc_binary` requires `linkshared = True`. In turn, this brings in several
-# TensorFlow symbols under `tensorflow::` namespace, for which we have no ABI
-# guarantees. Hence, in order to maintain ABI compatibility, this is marked as a
-# `cc_library` for now and we will revisit in the future.
-# TODO(mihaimaruseac): Determine if `cc_binary` makes more sense (when all
-# filesystems are converted and BUILD files are refactored to be modular).
-# TODO(b/144585140): The helpers should be separated into a different BUILD target
-# but doing that would result in symbols not being visible when loading plugin.
-# Revisit this once POSIX filesystem completely lands. See also the other TODO.
-# This also has the unfortunate effect that both versions of copy_file get
-# compiled, regardless of which one actually gets used!
+# Filesystem implementation for POSIX environments: Linux, MacOS, Android, etc.
+tf_cc_shared_object(
+    name = "libposix_filesystem.so",
+    framework_so = [],
+    linkstatic = False,
+    visibility = ["//visibility:public"],
+    deps = [":posix_filesystem_impl"],
+)
+
+# The real implementation of the filesystem.
 cc_library(
-    name = "posix_filesystem",
-    srcs = [
-        "posix_filesystem.cc",
-        "posix_filesystem_helper.cc",
-        "posix_filesystem_helper.h",
-        "copy_file.h",
-    ] + select({
-        "//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
-        "//conditions:default": ["copy_file_portable.cc"],
-    }),
+    name = "posix_filesystem_impl",
+    srcs = ["posix_filesystem.cc"],
     deps = [
+        ":posix_filesystem_helper",
+        "//tensorflow/c:tf_status",
         "//tensorflow/c/experimental/filesystem:filesystem_interface",
     ],
 )
+
+# Library implementing helper functionality, so that the above only contains
+# the API implementation for modular filesystems.
+cc_library(
+    name = "posix_filesystem_helper",
+    srcs = ["posix_filesystem_helper.cc"],
+    hdrs = ["posix_filesystem_helper.h"],
+    deps = [":copy_file"],
+)
+
+# On Linux, we can copy files faster using `sendfile`. But not elsewhere.
+# Hence, this private library to select which implementation to use.
+cc_library(
+    name = "copy_file",
+    srcs = select({
+        "//tensorflow:linux_x86_64": ["copy_file_linux.cc"],
+        "//conditions:default": ["copy_file_portable.cc"],
+    }),
+    hdrs = ["copy_file.h"],
+)
@@ -181,7 +181,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
     return;
   }
   const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
-  TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
+  TF_Tensor* result =
+      ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
   if (TF_GetCode(status) == TF_OK) {
     *tensor = result;
   }
@@ -170,6 +170,11 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
 }
 
 // --------------------------------------------------------------------------
+void StringEncode(const char* src, size_t src_len, char* dst) {
+  dst = tensorflow::core::EncodeVarint64(dst, src_len);
+  memcpy(dst, src, src_len);
+}
+
 size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
                        size_t dst_len, TF_Status* status) {
   const size_t sz = TF_StringEncodedSize(src_len);

@@ -185,8 +190,7 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
                                  src_len, "-byte string"));
     return 0;
   }
-  dst = tensorflow::core::EncodeVarint64(dst, src_len);
-  memcpy(dst, src, src_len);
+  StringEncode(src, src_len, dst);
   return sz;
 }

@@ -245,13 +249,11 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype,
 namespace tensorflow {
 
 // Non-static for testing.
-TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
-                               TF_Status* status) {
-  TF_SetStatus(status, TF_OK, "");
+TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
+  *status = tensorflow::Status::OK();
   if (!src.IsInitialized()) {
-    Set_TF_Status_from_Status(
-        status, FailedPrecondition(
-                    "attempt to use a tensor with an uninitialized value"));
+    *status = FailedPrecondition(
+        "attempt to use a tensor with an uninitialized value");
     return nullptr;
   }
   if (src.NumElements() == 0) {

@@ -259,14 +261,13 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
   }
   if (src.dtype() == tensorflow::DT_RESOURCE) {
     if (src.shape().dims() != 0) {
-      Set_TF_Status_from_Status(
-          status, InvalidArgument(
-                      "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
-                      src.shape().DebugString(),
-                      "). Please file a bug at "
-                      "https://github.com/tensorflow/tensorflow/issues/new, "
-                      "ideally with a "
-                      "short code snippet that reproduces this error."));
+      *status = InvalidArgument(
+          "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
+          src.shape().DebugString(),
+          "). Please file a bug at "
+          "https://github.com/tensorflow/tensorflow/issues/new, "
+          "ideally with a "
+          "short code snippet that reproduces this error.");
       return nullptr;
     }
     const string str =

@@ -305,23 +306,15 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
     *offsets = (dst - data_start);
     offsets++;
     const string& s = srcarray(i);
-    size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
-    if (TF_GetCode(status) != TF_OK) {
-      Set_TF_Status_from_Status(
-          status,
-          InvalidArgument("invalid string tensor encoding (string #", i, " of ",
-                          srcarray.size(), "): ", TF_Message(status)));
-      delete[] base;
-      return nullptr;
-    }
+    const size_t consumed = TF_StringEncodedSize(s.size());
+    StringEncode(s.data(), s.size(), dst);
     dst += consumed;
     dst_len -= consumed;
   }
   if (dst != base + size) {
-    Set_TF_Status_from_Status(
-        status, InvalidArgument(
-                    "invalid string tensor encoding (decoded ", (dst - base),
-                    " bytes, but the tensor is encoded in ", size, " bytes"));
+    *status = InvalidArgument(
+        "invalid string tensor encoding (decoded ", (dst - base),
+        " bytes, but the tensor is encoded in ", size, " bytes");
     delete[] base;
     return nullptr;
   }
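The `StringEncode` helper factored out above writes a TF_STRING element as a varint64 length prefix followed by the raw bytes, which lets call sites size the output with `TF_StringEncodedSize` instead of re-validating through `TF_StringEncode`. A sketch of the layout, using a simplified varint encoder as a stand-in for `tensorflow::core::EncodeVarint64`:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Simplified stand-in for tensorflow::core::EncodeVarint64: emits 7 bits
    // per byte, setting the high bit on every byte except the last.
    static char* EncodeVarint64(char* dst, uint64_t v) {
      while (v >= 0x80) {
        *dst++ = static_cast<char>((v & 0x7f) | 0x80);
        v >>= 7;
      }
      *dst++ = static_cast<char>(v);
      return dst;
    }

    // Mirrors the StringEncode helper: length prefix, then payload bytes.
    static void StringEncodeSketch(const char* src, size_t src_len, char* dst) {
      dst = EncodeVarint64(dst, src_len);
      std::memcpy(dst, src, src_len);
    }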
@@ -259,6 +259,9 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
   RunTest(x, x_init_value, y, y_shape);
 }
 
+// TODO(rocm):
+// Re-enable this test once 3D pooling is supported on ROCm platform
+#ifndef TENSORFLOW_USE_ROCM
 TEST_F(NNGradTest, MaxPool3DGradHelper) {
   TensorShape x_shape({1, 3, 3, 3, 1});
   TensorShape y_shape({1, 1, 1, 1, 1});

@@ -271,6 +274,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
   SetRandomValuesForMaxPooling<float>(&x_init_value);
   RunTest(x, x_init_value, y, y_shape);
 }
+#endif
 
 TEST_F(NNGradTest, AvgPoolGradHelper) {
   TensorShape x_shape({1, 2, 2, 1});

@@ -283,6 +287,9 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
   RunTest(x, x_shape, y, y_shape);
 }
 
+// TODO(rocm):
+// Re-enable this test once 3D pooling is supported on ROCm platform
+#ifndef TENSORFLOW_USE_ROCM
 TEST_F(NNGradTest, AvgPool3DGradHelper) {
   TensorShape x_shape({1, 3, 3, 3, 1});
   TensorShape y_shape({1, 1, 1, 1, 1});

@@ -293,6 +300,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
   auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
   RunTest(x, x_shape, y, y_shape);
 }
+#endif
 
 TEST_F(NNGradTest, LRN) {
   TensorShape x_shape({1, 1, 2, 1});
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support", # fixdeps: keep
|
||||
"@llvm//:x86_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:support", # fixdeps: keep
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
],
|
||||
)
|
||||
|
||||
@ -104,11 +104,11 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:aarch64_code_gen", # fixdeps: keep
|
||||
"@llvm//:arm_code_gen", # fixdeps: keep
|
||||
"@llvm//:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm//:target",
|
||||
"@llvm//:x86_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
],
|
||||
)
|
||||
|
||||
@ -205,9 +205,9 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm//:core",
|
||||
"@llvm//:support",
|
||||
"@llvm//:target",
|
||||
"@llvm-project//llvm:core",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:target",
|
||||
],
|
||||
)
|
||||
|
||||
|
@@ -407,6 +407,7 @@ def target_llvm_triple():
         "//tensorflow:android_arm64": "aarch64-none-android",
         "//tensorflow:android_x86": "i686-none-android",
         "//tensorflow:ios": "arm64-none-ios",
+        "//tensorflow:ios_x86_64": "x86_64-apple-ios",
         "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
         "//tensorflow:macos": "x86_64-none-darwin",
         "//conditions:default": "x86_64-pc-linux",
@@ -4,12 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilati
 load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
 
 package(
-    default_visibility = [
-        ":internal",
-        # BEGIN-GOOGLE-INTERNAL
-        "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
-        # END-GOOGLE-INTERNAL
-    ],
+    default_visibility = [":internal"],
     licenses = ["notice"],  # Apache 2.0
 )

@@ -500,6 +495,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
     ],
 )
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/graph_node_util.h"
 #include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/lib/hash/hash.h"
 
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph_node_util.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
@@ -2130,6 +2130,53 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
   return Status::OK();
 }
 
+Status CopyOutsideCompilationConstNodes(
+    Graph* g, const string& outside_compilation_attr_name) {
+  for (Node* n : g->op_nodes()) {
+    if (!n->IsConstant() ||
+        !HasNodeAttr(n->def(), outside_compilation_attr_name)) {
+      continue;
+    }
+
+    std::vector<const Edge*> out_edges(n->out_edges().begin(),
+                                       n->out_edges().end());
+    bool has_non_oc_output = false;
+    for (const Edge* e : out_edges) {
+      if (!e->IsControlEdge() &&
+          !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
+        has_non_oc_output = true;
+        break;
+      }
+    }
+    if (!has_non_oc_output) {
+      continue;
+    }
+
+    NodeDef copy_def = n->def();
+    copy_def.set_name(g->NewName(n->name()));
+    copy_def.mutable_attr()->erase(outside_compilation_attr_name);
+    Status s;
+    Node* copy_node = g->AddNode(copy_def, &s);
+    TF_RETURN_IF_ERROR(s);
+    for (const Edge* e : n->in_edges()) {
+      if (e->IsControlEdge()) {
+        g->AddControlEdge(e->src(), copy_node);
+      }
+    }
+    for (const Edge* e : out_edges) {
+      if (!e->IsControlEdge() &&
+          !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
+        Node* dst = e->dst();
+        int dst_input = e->dst_input();
+        g->RemoveEdge(e);
+        g->AddEdge(copy_node, 0, dst, dst_input);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
 }  // namespace
 
 Status RewriteOutsideCompilationSubgraphFn::operator()(

@@ -2279,6 +2326,10 @@ Status ExtractOutsideCompilationForFunction(
   std::vector<string> outside_compilation_host_graphs;
   std::vector<string> shape_inference_graphs_to_rewrite;
   if (*has_outside_compilation) {
+    // Copy outside compilation Const nodes with non outside compilation users.
+    TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
+        fbody->graph, outside_compilation_attr_name));
+
     // Find dependencies between outside compilation clusters.
     TF_ASSIGN_OR_RETURN(auto cluster_deps,
                         OutsideCompilationClusterDependencies(
@@ -16,6 +16,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/node_matchers.h"
 
+#include <utility>
+
 #include "absl/algorithm/container.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"

@@ -24,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/graph/graph_node_util.h"
 
 namespace tensorflow {
 namespace testing {
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/memory_types.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/graph_node_util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/public/version.h"
@@ -17,7 +17,10 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
 #include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/util/dump_graph.h"

@@ -39,7 +42,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
   return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
 }
 
-Status PropagateShapes(const Graph& graph,
+Status PropagateShapes(Graph* graph,
                        const std::map<int, InferredShape>& arg_shapes,
                        const std::vector<BackEdgeHelper::BackEdge>& back_edges,
                        ShapeRefiner* shape_refiner) {

@@ -54,7 +57,7 @@ Status PropagateShapes(Graph* graph,
   // shapes.
   // TODO(phawkins): handle cyclic graphs.
   std::vector<Node*> order;
-  GetReversePostOrder(graph, &order);
+  GetReversePostOrder(*graph, &order);
 
   for (Node* n : order) {
     // Ignore the status returned by the shape_refiner. We want the best effort

@@ -99,6 +102,67 @@ Status PropagateShapes(Graph* graph,
       }
     }
 
+    // Sometimes we have VariableShape nodes in while loop (after Enter nodes).
+    // They won't be constant-folded because TensorFlow constant folding does
+    // not handle Enter nodes (and thus does not handle any nodes after Enter
+    // nodes). We try to replace such VariableShape nodes with Const nodes here.
+    if (n->type_string() == "VariableShape") {
+      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
+      auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
+      if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
+        shape_inference::ShapeHandle handle =
+            handle_shapes_and_types->at(0).shape;
+        TensorShapeProto shape_proto;
+        context->ShapeHandleToProto(handle, &shape_proto);
+        if (!shape_proto.unknown_rank()) {
+          NodeDef const_def;
+          const_def.set_op("Const");
+          Node* var_node;
+          TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
+          const_def.set_name(
+              graph->NewName(absl::StrCat("var_shape_", var_node->name())));
+          DataType dtype = n->output_type(0);
+          AddNodeAttr("dtype", dtype, &const_def);
+          TensorProto value;
+          value.set_dtype(dtype);
+          value.mutable_tensor_shape()->add_dim()->set_size(
+              shape_proto.dim_size());
+          for (const auto& dim : shape_proto.dim()) {
+            if (dtype == DT_INT32) {
+              value.add_int_val(dim.size());
+            } else {
+              value.add_int64_val(dim.size());
+            }
+          }
+          AddNodeAttr("value", value, &const_def);
+          for (auto const& attr : n->attrs()) {
+            if (*attr.first.begin() == '_') {
+              AddNodeAttr(attr.first, attr.second, &const_def);
+            }
+          }
+
+          Status s;
+          Node* const_node = graph->AddNode(const_def, &s);
+          TF_RETURN_IF_ERROR(s);
+
+          graph->AddControlEdge(var_node, const_node);
+          std::vector<const Edge*> out_edges(n->out_edges().begin(),
+                                             n->out_edges().end());
+          for (const Edge* e : out_edges) {
+            if (e->IsControlEdge()) {
+              graph->AddControlEdge(const_node, e->dst());
+              graph->RemoveEdge(e);
+            } else {
+              Node* dst = e->dst();
+              int dst_input = e->dst_input();
+              graph->RemoveEdge(e);
+              graph->AddEdge(const_node, 0, dst, dst_input);
+            }
+          }
+        }
+      }
+    }
+
     // Merge node causes a loop so we remove NextIteration->Merge edge before
     // performing shape inference. But removing those edges also prevents us
     // from inferring output shape for Merge node (we need shapes for all its

@@ -196,7 +260,7 @@ Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
   // the shape inference is complete.
   BackEdgeHelper back_edge;
   TF_RETURN_IF_ERROR(back_edge.Remove(graph));
-  TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes,
+  TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
                                      back_edge.RemovedEdges(), &shape_refiner));
   TF_RETURN_IF_ERROR(back_edge.Replace());
@@ -191,7 +191,7 @@ class XlaAssignVariableOp : public OpKernel {
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
                           data::IteratorGetNextAsOptionalOp);               \
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE),       \
-                          data::IteratorGetNextSyncOp);                     \
+                          data::IteratorGetNextOp);                         \
   REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")                    \
                               .Device(DEVICE)                               \
                               .HostMemory("string_handle"),                 \
@@ -6,7 +6,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
 package(
     default_visibility = [
         "//tensorflow/compiler/tf2xla:__subpackages__",
-        "@local_config_mlir//:friends",
+        "@llvm-project//mlir:friends",
     ],
     licenses = ["notice"],  # Apache 2.0
 )

@@ -30,8 +30,8 @@ cc_library(
     hdrs = ["op_or_arg_name_mapper.h"],
     deps = [
         "@com_google_absl//absl/strings",
-        "@llvm//:support",
-        "@local_config_mlir//:IR",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
     ],
 )

@@ -43,11 +43,14 @@ cc_library(
         ":passes",
         "//tensorflow/core:lib",
         "//tensorflow/core/platform:logging",
-        "@llvm//:support",
-        "@local_config_mlir//:MlirOptLib",
-        "@local_config_mlir//:Pass",
-        "@local_config_mlir//:Support",
-        "@local_config_mlir//test:TestTransforms",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:AffineDialectRegistration",
+        "@llvm-project//mlir:LoopDialectRegistration",
+        "@llvm-project//mlir:MlirOptLib",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:QuantOpsDialectRegistration",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir/test:TestTransforms",
     ],
 )

@@ -80,9 +83,10 @@ cc_library(
         "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
         "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
         "//tensorflow/compiler/mlir/xla:xla_lower",
-        "@local_config_mlir//:AffineDialectRegistration",
-        "@local_config_mlir//:QuantOps",
-        "@local_config_mlir//:QuantOpsDialectRegistration",
+        "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
+        "//tensorflow/compiler/mlir/xla:xla_test_passes",
+        "@llvm-project//mlir:AffineOps",
+        "@llvm-project//mlir:QuantOps",
     ],
 )

@@ -92,7 +96,7 @@ cc_library(
     hdrs = ["init_mlir.h"],
     deps = [
         "//tensorflow/core:lib",
-        "@llvm//:support",
+        "@llvm-project//llvm:support",
     ],
 )

@@ -122,11 +126,11 @@ tf_cc_binary(
         "//tensorflow/core:tensorflow",
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/strings",
-        "@llvm//:support",
-        "@local_config_mlir//:IR",
-        "@local_config_mlir//:Support",
-        "@local_config_mlir//:TranslateClParser",
-        "@local_config_mlir//:Translation",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TranslateClParser",
+        "@llvm-project//mlir:Translation",
     ],
 )
@@ -1,11 +1,11 @@
 # MLIR dialects and utilities for TensorFlow, TensorFlow Lite and XLA.
 
 This module contains the MLIR
-([Multi-Level Intermediate Representation](https://github.com/tensorflow/mlir))
+([Multi-Level Intermediate Representation](https://mlir.llvm.org))
 dialects and utilities for
 
 1. TensorFlow
 2. XLA
 3. TF Lite
 
-See [MLIR repo](https://github.com/tensorflow/mlir) for complete documentation.
+See [MLIR's website](https://mlir.llvm.org) for complete documentation.
@@ -10,7 +10,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths")
 
 # Default values used by the test runner.
 _default_test_file_exts = ["mlir", ".pbtxt", ".td"]
-_default_driver = "@local_config_mlir//:run_lit.sh"
+_default_driver = "@llvm-project//mlir:run_lit.sh"
 _default_size = "small"
 _default_tags = ["no_rocm"]

@@ -50,16 +50,16 @@ def _run_lit_test(name, data, size, tags, driver, features):
 
     native.py_test(
         name = name,
-        srcs = ["@llvm//:lit"],
+        srcs = ["@llvm-project//llvm:lit"],
         tags = tags,
         args = [
            "tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
        ] + features,
        data = data + [
            "//tensorflow/compiler/mlir:litfiles",
-            "@llvm//:FileCheck",
-            "@llvm//:count",
-            "@llvm//:not",
+            "@llvm-project//llvm:FileCheck",
+            "@llvm-project//llvm:count",
+            "@llvm-project//llvm:not",
        ],
        size = size,
        main = "lit.py",
@ -1,6 +1,6 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
|
||||
load(
|
||||
"@local_config_mlir//:tblgen.bzl",
|
||||
"//third_party/mlir:tblgen.bzl",
|
||||
"gentbl",
|
||||
)
|
||||
|
||||
@ -15,7 +15,7 @@ package(
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = ["@local_config_mlir//:subpackages"],
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//learning/brain/experimental/mlir/...",
|
||||
"//learning/brain/google/xla/...",
|
||||
@ -28,7 +28,7 @@ filegroup(
|
||||
srcs = [
|
||||
"ir/tfl_ops.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@local_config_mlir//:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
@ -48,7 +48,7 @@ gentbl(
|
||||
"g3doc/tfl_ops.md",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "ir/tfl_ops.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
@ -63,11 +63,11 @@ gentbl(
|
||||
"transforms/generated_prepare_tf.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/prepare_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_optimize_td_files",
|
||||
],
|
||||
@ -81,11 +81,11 @@ gentbl(
|
||||
"transforms/generated_lower_static_tensor_list.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/tensorlist_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
|
||||
],
|
||||
)
|
||||
@ -98,11 +98,11 @@ gentbl(
|
||||
"transforms/generated_legalize_tf.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/legalize_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
|
||||
],
|
||||
)
|
||||
@ -115,11 +115,11 @@ gentbl(
|
||||
"transforms/generated_optimize.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/optimize_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
|
||||
],
|
||||
)
|
||||
@ -132,11 +132,11 @@ gentbl(
|
||||
"transforms/generated_quantize.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/quantize_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
@ -148,11 +148,11 @@ gentbl(
|
||||
"transforms/generated_post_quantize.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/post_quantize_patterns.td",
|
||||
td_srcs = [
|
||||
":tensorflow_lite_ops_td_files",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
@ -165,9 +165,9 @@ cc_library(
|
||||
"utils/validators.h",
|
||||
],
|
||||
deps = [
|
||||
"@local_config_mlir//:Dialect",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
)
|
||||
|
||||
@ -185,21 +185,21 @@ cc_library(
|
||||
"transforms/passes.h",
|
||||
"utils/attribute_utils.h",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
|
||||
"@local_config_mlir//:include/mlir/Transforms/InliningUtils.h",
|
||||
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_lite_ops_inc_gen",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:Dialect",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -216,10 +216,10 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -233,9 +233,9 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
)
|
||||
|
||||
@ -248,10 +248,10 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -292,14 +292,14 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:Transforms",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -317,12 +317,12 @@ cc_library(
|
||||
":tensorflow_lite",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -330,6 +330,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "tensorflow_lite_quantize",
|
||||
srcs = [
|
||||
"transforms/default_quant_params.cc",
|
||||
"transforms/generated_post_quantize.inc",
|
||||
"transforms/generated_quantize.inc",
|
||||
"transforms/load_quantization_recipe.cc",
|
||||
@ -348,13 +349,13 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
@ -376,7 +377,7 @@ genrule(
        "utils/generated_op_quant_spec_getters.inc",
    ],
    cmd = ("$(location //tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen) " +
           "-I external/local_config_mlir/include " +
           "-I external/llvm-project/mlir/include " +
           "-I external/org_tensorflow " +
           "$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"),
    tools = ["//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen"],
@ -390,7 +391,7 @@ cc_library(
    ],
    deps = [
        ":tensorflow_lite",
        "@local_config_mlir//:IR",
        "@llvm-project//mlir:IR",
    ],
    alwayslink = 1,
)
@ -401,9 +402,9 @@ tf_native_cc_binary(
        "operator_converter_gen.cc",
    ],
    deps = [
        "@llvm//:support",
        "@llvm//:tablegen",
        "@local_config_mlir//:TableGen",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:tablegen",
        "@llvm-project//mlir:TableGen",
    ],
)

@ -436,12 +437,17 @@ cc_library(
    deps = [
        ":tensorflow_lite",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:status",
        "//tensorflow/lite/kernels/internal:kernel_utils",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@flatbuffers",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:TransformUtils",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

@ -464,7 +470,7 @@ cc_library(
    ],
    deps = [
        "//tensorflow/lite/core/api",
        "@local_config_mlir//:IR",
        "@llvm-project//mlir:IR",
    ],
)

@ -501,6 +507,7 @@ cc_library(
        "//tensorflow/lite:schema_fbs_version",
        "//tensorflow/lite:string_util",
        "//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
        "//tensorflow/lite/kernels/internal:kernel_utils",
        "//tensorflow/lite/schema:schema_fbs",
        "//tensorflow/lite/tools/versioning:op_version",
        "@com_google_absl//absl/base",
@ -509,14 +516,14 @@ cc_library(
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@flatbuffers",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:QuantOpsDialectRegistration",
        "@local_config_mlir//:StandardDialectRegistration",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
        "@local_config_mlir//:Translation",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:QuantOpsDialectRegistration",
        "@llvm-project//mlir:StandardDialectRegistration",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Translation",
    ],
    alwayslink = 1,
)
@ -525,7 +532,7 @@ tf_cc_binary(
    name = "flatbuffer_translate",
    deps = [
        ":flatbuffer_translate_lib",
        "@local_config_mlir//:MlirTranslateMain",
        "@llvm-project//mlir:MlirTranslateMain",
    ],
)

@ -538,7 +545,7 @@ cc_library(
        "tf_tfl_translate_cl.h",
    ],
    deps = [
        "@llvm//:support",
        "@llvm-project//llvm:support",
    ],
    alwayslink = 1,
)
@ -550,7 +557,7 @@ cc_library(
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "@llvm//:support",
        "@llvm-project//llvm:support",
    ],
)

@ -578,9 +585,9 @@ tf_cc_binary(
        "//tensorflow/lite/schema:schema_fbs",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/strings",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Support",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)

@ -596,10 +603,10 @@ tf_cc_binary(
        "//tensorflow/lite/kernels:builtin_ops",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Parser",
        "@local_config_mlir//:Support",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Support",
    ],
)

@ -622,12 +629,12 @@ cc_library(
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
        "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:QuantOpsDialectRegistration",
        "@local_config_mlir//:Transforms",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:QuantOpsDialectRegistration",
        "@llvm-project//mlir:Transforms",
    ],
)

@ -654,15 +661,15 @@ cc_library(
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/lite/tools/optimize:quantize_weights",
        "//tensorflow/stream_executor/lib",
        "@llvm//:support",
        "@local_config_mlir//:Analysis",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Parser",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:QuantOps",
        "@local_config_mlir//:QuantOpsDialectRegistration",
        "@local_config_mlir//:Support",
        "@local_config_mlir//:Transforms",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:QuantOpsDialectRegistration",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
    ],
)

@ -18,7 +18,7 @@ limitations under the License.

#include <cstdarg>

#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "tensorflow/lite/core/api/error_reporter.h"

namespace tflite {

@ -17,6 +17,7 @@ limitations under the License.

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
@ -43,24 +44,24 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Diagnostics.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Support/Functional.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Translation.h"  // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Translation.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
@ -103,12 +104,26 @@ using llvm::cl::opt;
// Commandline flag used to control the flatbuffer import.
bool use_external_constant;

// Commandline flag to enable graph pruning.
bool experimental_prune_unreachable_nodes_unconditionally;

// NOLINTNEXTLINE
static opt<bool, true> use_external_constant_flag(
    "use-external-constant",
    llvm::cl::desc("Use external constant during flatbuffer import"),
    llvm::cl::location(use_external_constant), llvm::cl::init(false));

// TODO(b/147111261): After the importer supports generic custom ops, we should
// change the flag to a more lightweight flag, e.g.
// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE prune
// the operations.
// NOLINTNEXTLINE
static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
    "experimental-prune-unreachable-nodes-unconditionally",
    llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
    llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
    llvm::cl::init(false));
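
// A minimal invocation sketch for these two flags (assuming a built
// flatbuffer_translate binary; the exact binary path depends on the build
// setup):
//
//   flatbuffer_translate --tflite-flatbuffer-to-mlir \
//     --experimental-prune-unreachable-nodes-unconditionally \
//     model.tflite -o model.mlir
//
// Both flags default to false, so existing invocations keep their behavior.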

namespace {
bool IsScalar(const TensorT& tensor) {
  // TODO(b/138222071) We can't distinguish scalars and unranked tensors
@ -217,7 +232,7 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
  // min/max stats is just for comments, so ignore it.
  if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
  // If the result isn't float (and so isn't quantizable), the min/max is ignored.
  if (!res->getType()
  if (!res.getType()
           .cast<mlir::ShapedType>()
           .getElementType()
           .isa<mlir::FloatType>()) {
@ -255,10 +270,23 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
}

StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
  // TODO(krzysd) Support custom ops
  // TODO(b/143872630): Support custom ops
  if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
    return errors::Unimplemented("unsupported custom operation: ",
                                 opcode.custom_code);
    // Add some custom ops supported on GPU.
    const absl::string_view custom_name = opcode.custom_code;
    if (custom_name == "MaxPoolingWithArgmax2D") {
      return std::string("tfl.max_pooling_with_argmax_2d");
    }
    if (custom_name == "Convolution2DTransposeBias") {
      return std::string("tfl.convolution_2d_transpose_bias");
    }
    if (custom_name == "MaxUnpooling2D") {
      return std::string("tfl.max_unpooling_2d");
    }
    // Use an unsupported op name instead of throwing an error here in case the
    // op is pruned during the import.
    return std::string(
        llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str());
  }
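  // Illustration: a flatbuffer carrying an unrecognized custom op whose
  // custom_code is "MyOp" now imports as "tfl.UNSUPPORTED_custom_MyOp"
  // instead of failing outright, so unconditional pruning can still drop it
  // when it does not feed an output. ("MyOp" is a hypothetical name.)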
  if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
    return std::string("tf.If");
@ -495,6 +523,13 @@ bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
  }
}

// Returns true if this is a custom op.
bool IsCustomOp(const std::string& op_name) {
  return op_name == "tfl.max_pooling_with_argmax_2d" ||
         op_name == "tfl.max_unpooling_2d" ||
         op_name == "tfl.convolution_2d_transpose_bias";
}
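
// Note: the three op names above mirror the GPU custom ops recognized in
// OpNameForOpCode; the two lists must stay in sync so that ConvertOp below
// routes these ops through CustomOptionsToAttributes instead of
// BuiltinOptionsToAttributes.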

// TODO(krzysd) Handle function calls
StatusOr<Operation*> ConvertOp(
    const tflite::OperatorT& op, const std::vector<Value>& vals_map,
@ -557,7 +592,15 @@ StatusOr<Operation*> ConvertOp(
  }

  llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
  if (IsCustomOp(op_name)) {
    auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
                                                  builder, loc, &attrs);
    if (!status.ok()) {
      return emitError(loc, status.ToString()), status;
    }
  } else {
    mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
  }
  op_state.addAttributes(attrs);

  // Handle the conversion from subgraph index to functions for If and While
@ -619,6 +662,49 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
      name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}

// Given a list of output indices, traverses the subgraph and returns the set of
// ops that are ancestors of the output tensors.
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
    const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
  // Create a map from tensor index to defining op.
  absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
  for (const auto& op : subgraph.operators) {
    for (int32_t output : op->outputs) {
      defining_op[output] = op.get();
    }
  }

  std::vector<const tflite::OperatorT*> queue;
  for (int32_t output : output_indices) {
    if (auto& op = defining_op[output]) {
      queue.push_back(op);
    } else {
      return errors::InvalidArgument("Output tensor doesn't have defining op");
    }
  }

  // Traverse the graph towards inputs.
  absl::flat_hash_set<const tflite::OperatorT*> visited;
  while (!queue.empty()) {
    const tflite::OperatorT* op = queue.back();
    queue.pop_back();
    if (!visited.insert(op).second) {
      // The node has already been visited.
      continue;
    }

    for (int32_t input : op->inputs) {
      // Input tensor may not have a defining op in case it is a subgraph input
      // or a constant tensor.
      if (auto& op = defining_op[input]) {
        queue.push_back(op);
      }
    }
  }

  return visited;
}
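
// Worked example (illustrative): for operators %0 = A(), %1 = B(%0),
// %2 = C(%0) with output_indices covering only %1, the returned set is
// {A, B}; C is then skipped by the pruning check in ConvertSubgraph below.
// Each operator enters the queue at most once, so the traversal is
// O(#operators + #edges).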

// Build a FuncOp from a tflite SubGraph
// The op_names are a mapping from indexes into the TFLite operators array to
// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
@ -635,7 +721,8 @@ StatusOr<FuncOp> ConvertSubgraph(
    const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
    Location base_loc, Builder builder,
    const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
    bool use_external_constant) {
    bool use_external_constant,
    bool experimental_prune_unreachable_nodes_unconditionally) {
  llvm::SmallVector<mlir::Type, 2> ret_types;
  llvm::SmallVector<mlir::Type, 4> input_types;

@ -731,8 +818,19 @@ StatusOr<FuncOp> ConvertSubgraph(
    func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
  }

  absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
  if (experimental_prune_unreachable_nodes_unconditionally) {
    TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
                        PruneSubgraph(subgraph, func_outputs));
  }

  // Construct MLIR operators from TFLite operators
  for (auto& op : subgraph.operators) {
    if (experimental_prune_unreachable_nodes_unconditionally &&
        !pruned_subgraph_ops.contains(op)) {
      continue;
    }

    for (auto input_num : op->inputs) {
      // The operators in a graph are topologically sorted
      // and so if no previous operation has produced a tensor
@ -837,7 +935,8 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
OwningModuleRef tflite::FlatBufferToMlir(
    absl::string_view buffer, MLIRContext* context, Location base_loc,
    const std::vector<std::string>& ordered_output_arrays,
    bool use_external_constant) {
    bool use_external_constant,
    bool experimental_prune_unreachable_nodes_unconditionally) {
  auto model_ptr =
      FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
  if (nullptr == model_ptr) {
@ -892,7 +991,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
        // TODO(b/131175224,b/132239787) Support multiple entry points
        builder, ordered_output_arrays,
        /*is_entry_point=*/e.index() == 0,
        /*use_external_constant=*/use_external_constant);
        /*use_external_constant=*/use_external_constant,
        experimental_prune_unreachable_nodes_unconditionally);
    if (!func_or_error.ok()) {
      return emitError(base_loc, "could not translate function ")
                 << subgraph->name,
@ -905,9 +1005,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
  return OwningModuleRef(module);
}

static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
                                                 MLIRContext* context,
                                                 bool use_external_constant) {
static OwningModuleRef FlatBufferFileToMlirTrans(
    llvm::SourceMgr* source_mgr, MLIRContext* context,
    bool use_external_constant,
    bool experimental_prune_unreachable_nodes_unconditionally) {
  const llvm::MemoryBuffer* input =
      source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
  std::string error;
@ -924,12 +1025,14 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,

  return tflite::FlatBufferToMlir(
      absl::string_view(input->getBufferStart(), input->getBufferSize()),
      context, loc, outputs, use_external_constant);
      context, loc, outputs, use_external_constant,
      experimental_prune_unreachable_nodes_unconditionally);
}

static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
    "tflite-flatbuffer-to-mlir",
    [](llvm::SourceMgr& source_mgr, MLIRContext* context) {
      return FlatBufferFileToMlirTrans(&source_mgr, context,
                                       use_external_constant);
      return FlatBufferFileToMlirTrans(
          &source_mgr, context, use_external_constant,
          experimental_prune_unreachable_nodes_unconditionally);
    });

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_

#include "absl/strings/string_view.h"
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project

namespace tflite {
// Converts a TFLite flatbuffer stored in `buffer` to an MLIR module
@ -31,11 +31,14 @@ namespace tflite {
// on failure, and more specific errors will be emitted via the context.
// If `use_external_constant` is true, it will create `tfl.external_const`
// instead of `tfl.const`.
// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes that
// are not ancestors of the output nodes will be pruned.
mlir::OwningModuleRef FlatBufferToMlir(
    absl::string_view buffer, mlir::MLIRContext* context,
    mlir::Location base_loc,
    const std::vector<std::string>& ordered_output_arrays,
    bool use_external_constant = false);
    bool use_external_constant = false,
    bool experimental_prune_unreachable_nodes_unconditionally = false);
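
// A minimal caller sketch (assuming `buffer` already holds the flatbuffer
// bytes and `context`/`base_loc` are set up; names are illustrative):
//
//   mlir::OwningModuleRef module = tflite::FlatBufferToMlir(
//       buffer, &context, base_loc, /*ordered_output_arrays=*/{},
//       /*use_external_constant=*/false,
//       /*experimental_prune_unreachable_nodes_unconditionally=*/true);
//   if (!module) { /* diagnostics were emitted via the context */ }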
}  // namespace tflite

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_

@ -17,15 +17,45 @@ limitations under the License.

#include <vector>

#include "absl/strings/str_cat.h"
#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace {

using ::tensorflow::Status;
using ::tensorflow::errors::InvalidArgument;
using ::xla::StatusOr;

StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
                                          mlir::Builder builder,
                                          mlir::Location loc) {
  auto padding = tflite::Padding::Padding_VALID;
  if (pad_params == TfLitePadding::kTfLitePaddingSame) {
    padding = tflite::Padding_SAME;
  } else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
    padding = tflite::Padding_VALID;
  } else {
    return InvalidArgument(
        absl::StrCat("Invalid padding type", std::to_string(pad_params)));
  }

  const char* option_name = tflite::EnumNamePadding(padding);
  return builder.getStringAttr(option_name);
}
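
// For example, kTfLitePaddingSame yields the string attribute "SAME" and
// kTfLitePaddingValid yields "VALID" (via tflite::EnumNamePadding); any other
// enum value produces an InvalidArgument status.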

}  // namespace

// TODO(jpienaar): This is a placeholder. This should be done in a more
// efficient way as part of the translation of the module.
static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
@ -212,5 +242,44 @@ static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
  return builder.getStringAttr(option_name);
}

Status mlir::CustomOptionsToAttributes(
    const std::string& op_name, const std::vector<uint8_t>& custom_options,
    mlir::Builder builder, mlir::Location loc,
    llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
  if (op_name == "tfl.max_pooling_with_argmax_2d" ||
      op_name == "tfl.max_unpooling_2d") {
    auto* pool_params =
        reinterpret_cast<const TfLitePoolParams*>(custom_options.data());
    TF_ASSIGN_OR_RETURN(auto padding_attribute,
                        GetPaddingAttr(pool_params->padding, builder, loc));
    attributes->emplace_back(
        builder.getNamedAttr("padding", padding_attribute));
    attributes->emplace_back(builder.getNamedAttr(
        "stride_h", builder.getI32IntegerAttr(pool_params->stride_height)));
    attributes->emplace_back(builder.getNamedAttr(
        "stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
    attributes->emplace_back(builder.getNamedAttr(
        "filter_h", builder.getI32IntegerAttr(pool_params->filter_height)));
    attributes->emplace_back(builder.getNamedAttr(
        "filter_w", builder.getI32IntegerAttr(pool_params->filter_width)));
    return Status::OK();

  } else if (op_name == "tfl.convolution_2d_transpose_bias") {
    auto* conv_params = reinterpret_cast<const TfLiteTransposeConvParams*>(
        custom_options.data());
    TF_ASSIGN_OR_RETURN(auto padding_attribute,
                        GetPaddingAttr(conv_params->padding, builder, loc));
    attributes->emplace_back(
        builder.getNamedAttr("padding", padding_attribute));
    attributes->emplace_back(builder.getNamedAttr(
        "stride_h", builder.getI32IntegerAttr(conv_params->stride_height)));
    attributes->emplace_back(builder.getNamedAttr(
        "stride_w", builder.getI32IntegerAttr(conv_params->stride_width)));
    return Status::OK();
  }

  return InvalidArgument(absl::StrCat("invalid custom op type: ", op_name));
}
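
// Sketch of the result for a pooling op (values illustrative): custom_options
// bytes that reinterpret as TfLitePoolParams{padding=kTfLitePaddingSame,
// stride_h=2, stride_w=2, filter_h=2, filter_w=2} populate the out parameter
// with:
//   padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32,
//   filter_h = 2 : i32, filter_w = 2 : i32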

// Pull in FlatBuffer writers for TFLite generated using TableGen
#include "tensorflow/compiler/mlir/lite/operator_converters.inc"

@ -26,9 +26,10 @@ limitations under the License.
#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace mlir {
@ -45,7 +46,7 @@ llvm::Optional<flatbuffers::Offset<tflite::Operator>> CreateFlatBufferOperator(
    const std::vector<int32_t> &operands, const std::vector<int32_t> &results,
    flatbuffers::FlatBufferBuilder *fbb);

// Populate the array of mlir::NamedAttributes corresponding to the given
// Populates the array of mlir::NamedAttributes corresponding to the given
// tflite::FlatbufferOptionsUnion.
// We use an out parameter per LLVM convention
void BuiltinOptionsToAttributes(
@ -53,6 +54,15 @@ void BuiltinOptionsToAttributes(
    // NOLINTNEXTLINE
    llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes);

// Populates the array of mlir::NamedAttributes corresponding to the given
// custom_options.
// We use an out parameter per LLVM convention
tensorflow::Status CustomOptionsToAttributes(
    const std::string &op_name, const std::vector<uint8_t> &custom_options,
    mlir::Builder builder,
    // NOLINTNEXTLINE
    Location loc, llvm::SmallVectorImpl<mlir::NamedAttribute> *attributes);

}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_

@ -41,19 +41,19 @@ limitations under the License.
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "mlir/Translation.h"  // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "mlir/Translation.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
@ -71,6 +71,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
@ -218,6 +219,13 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
      auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
      return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
    }
    case mlir::TF::TensorFlowTypes::RESOURCE: {
      // Treat tf.resource values as integer values in flatbuffer.
      // TODO(b/146131919): Maybe need to have a detailed design for supporting
      // other resource types beyond hash table resources and resource
      // variables.
      return tflite::TensorType_INT32;
    }
    default:
      // TFLite export fills FLOAT32 for unknown data types. Returning an error
      // for now for safety and this could be revisited when required.
@ -233,17 +241,17 @@ static bool IsConst(Operation* op) {
template <typename T>
static bool HasValidTFLiteType(Value value, T& error_handler) {
  // None type is allowed to represent unspecified operands.
  if (value->getType().isa<NoneType>()) return true;
  if (value.getType().isa<NoneType>()) return true;

  auto type = value->getType().dyn_cast<TensorType>();
  auto type = value.getType().dyn_cast<TensorType>();
  if (!type) {
    if (auto op = value->getDefiningOp()) {
    if (auto op = value.getDefiningOp()) {
      error_handler.emitError()
          << '\'' << op << "' should produce value of tensor type instead of "
          << value->getType();
          << value.getType();
      return false;
    }
    error_handler.emitError("expected tensor type, got ") << value->getType();
    error_handler.emitError("expected tensor type, got ") << value.getType();
    return false;
  }
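
// The value->getType() to value.getType() rewrites in this file (and the rest
// of this commit) all stem from the same upstream change: mlir::Value became
// a value type passed by copy, so members are reached with '.' instead of
// '->'. Illustrative before/after:
//   old: if (value->getType().isa<NoneType>()) return true;
//   new: if (value.getType().isa<NoneType>()) return true;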

@ -282,7 +290,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {

    for (auto arg : bb.getArguments()) {
      if (!HasValidTFLiteType(arg, fn))
        return fn.emitError("invalid TFLite type: ") << arg->getType(), false;
        return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
    }

    // Verify that all operations except the terminator have exactly one
@ -292,7 +300,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {

      for (auto result : inst.getResults()) {
        if (!HasValidTFLiteType(result, inst))
          return fn.emitError("invalid TFLite type: ") << result->getType(),
          return fn.emitError("invalid TFLite type: ") << result.getType(),
                 false;
      }
    }
@ -317,6 +325,48 @@ static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
  return std::move(status_or_node_def.ValueOrDie());
}

// Converts an MLIR padding StringRef to TfLitePadding.
// Returns llvm::None if conversion fails.
static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
                                                llvm::StringRef padding) {
  const tflite::Padding padding_attr =
      std::move(llvm::StringSwitch<tflite::Padding>(padding)
                    .Case("SAME", tflite::Padding_SAME)
                    .Case("VALID", tflite::Padding_VALID));
  if (padding_attr == tflite::Padding_SAME) {
    return kTfLitePaddingSame;
  }
  if (padding_attr == tflite::Padding_VALID) {
    return kTfLitePaddingValid;
  }

  return inst->emitOpError() << "Invalid padding attribute: " << padding,
         llvm::None;
}

// Extracts TfLitePoolParams from a TFL custom op.
// The template parameter TFLOp should be a TFL custom op containing attributes
// generated from TfLitePoolParams.
// Returns llvm::None if conversion fails.
template <typename TFLOp>
static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
                                                      TFLOp op) {
  TfLitePoolParams pool_params;
  pool_params.stride_height = op.stride_h().getSExtValue();
  pool_params.stride_width = op.stride_w().getSExtValue();
  pool_params.filter_height = op.filter_h().getSExtValue();
  pool_params.filter_width = op.filter_w().getSExtValue();
  const auto padding = GetTflitePadding(inst, op.padding());
  if (padding) {
    pool_params.padding = *padding;
    pool_params.activation = kTfLiteActNone;
    pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
    return pool_params;
  }

  return llvm::None;
}

namespace {

// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
@ -375,9 +425,31 @@ class Translator {
      mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Builds custom operators.
  // Templated on a) data type of custom_option to be stored into flatbuffer,
  // and b) TFL custom op type.
  template <typename CustomOptionType, typename TFLOp>
  BufferOffset<tflite::Operator> BuildCustomOperator(
      const CustomOptionType& custom_option, const std::string& opcode_name,
      TFLOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
      mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);
  Optional<BufferOffset<tflite::Operator>>
  BuildConvolution2DTransposeBiasOperator(
      Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
      const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);
  Optional<BufferOffset<tflite::Operator>> BuildMaxPoolingWithArgMax2DOperator(
      Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
      const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);
  Optional<BufferOffset<tflite::Operator>> BuildMaxUnpooling2DOperator(
      Operation* inst, mlir::TFL::MaxUnpooling2DOp op,
      const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
@ -504,7 +576,7 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(

Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
    Value value, const std::string& name, unsigned buffer_idx) {
  auto type = value->getType().cast<TensorType>();
  auto type = value.getType().cast<TensorType>();

  // TFLite requires tensor shape only for the inputs and constants.
  // However, we output all known shapes for better round-tripping
@ -516,7 +588,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(

    if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
      return mlir::emitError(
          value->getLoc(),
          value.getLoc(),
          "result shape dimensions out of 32 bit int type range");

    return mlir::success();
@ -528,7 +600,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  } else if (auto* inst = value->getDefiningOp()) {
  } else if (auto* inst = value.getDefiningOp()) {
    if (IsConst(inst)) {
      // Const op can have a result of dynamic shaped type (e.g. due to constant
      // folding), but we can still derive the shape of a constant tensor for
@ -571,7 +643,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
  // marked as a stateful. If so, set the tensor's is_variable as true
  // This is v1 ref variable semantics in the TFLite runtime.
  bool is_variable = false;
  for (auto& use : value->getUses()) {
  for (auto& use : value.getUses()) {
    is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
    if (is_variable) {
      break;
@ -615,19 +687,72 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
                                  builtin_options);
}

template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
    const CustomOptionType& custom_option, const std::string& opcode_name,
    TFLOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
  std::vector<uint8_t> custom_option_vector(sizeof(CustomOptionType));
  memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType));
  auto opcode_index =
      GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM);
  return tflite::CreateOperator(
      builder_, opcode_index, builder_.CreateVector(operands),
      builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
      /*builtin_options=*/0,
      builder_.CreateVector<uint8_t>(custom_option_vector),
      tflite::CustomOptionsFormat_FLEXBUFFERS);
}
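
// Note the round trip with the importer: BuildCustomOperator memcpy's the raw
// option struct (e.g. TfLitePoolParams) into the operator's custom_options,
// and CustomOptionsToAttributes in flatbuffer_operator.cc reinterpret_casts
// the same bytes back on import, so CustomOptionType must be trivially
// copyable.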

BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
    mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
  float tolerance = op.tolerance().convertToFloat();
  std::vector<uint8_t> custom_options(sizeof(float));
  memcpy(custom_options.data(), &tolerance, sizeof(float));
  auto opcode_index =
      GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
  return tflite::CreateOperator(
      builder_, opcode_index, builder_.CreateVector(operands),
      builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
      /*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_options),
      tflite::CustomOptionsFormat_FLEXBUFFERS);
  return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results);
}

Optional<BufferOffset<tflite::Operator>>
Translator::BuildConvolution2DTransposeBiasOperator(
    Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
    const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
  TfLiteTransposeConvParams conv_params;
  conv_params.stride_height = op.stride_h().getSExtValue();
  conv_params.stride_width = op.stride_w().getSExtValue();
  const auto padding = GetTflitePadding(inst, op.padding());
  if (padding) {
    conv_params.padding = *padding;
    return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op,
                               operands, results);
  }

  return llvm::None;
}

Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxPoolingWithArgMax2DOperator(
    Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
    const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
  const auto pool_params = GetTflitePoolParams(inst, op);
  if (pool_params) {
    return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op,
                               operands, results);
  }

  return llvm::None;
}

Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxUnpooling2DOperator(Operation* inst,
                                        mlir::TFL::MaxUnpooling2DOp op,
                                        const std::vector<int32_t>& operands,
                                        const std::vector<int32_t>& results) {
  const auto pool_params = GetTflitePoolParams(inst, op);
  if (pool_params) {
    return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands,
                               results);
  }

  return llvm::None;
}

Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
@ -769,6 +894,20 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
    if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
      return BuildNumericVerifyOperator(verify_op, operands, results);
    }
    if (auto conv_transpose_bias_op =
            dyn_cast<mlir::TFL::Convolution2DTransposeBiasOp>(inst)) {
      return BuildConvolution2DTransposeBiasOperator(
          inst, conv_transpose_bias_op, operands, results);
    }
    if (auto max_pooling_with_arg_max_op =
            dyn_cast<mlir::TFL::MaxPoolingWithArgMax2DOp>(inst)) {
      return BuildMaxPoolingWithArgMax2DOperator(
          inst, max_pooling_with_arg_max_op, operands, results);
    }
    if (auto max_unpooling_op = dyn_cast<mlir::TFL::MaxUnpooling2DOp>(inst)) {
      return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
                                         results);
    }
    inst->emitOpError("is not a supported TFLite op");
    return llvm::None;
  }
@ -923,7 +1062,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
  // on failure.
  auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
    // NoneType represents optional and may be skipped here.
    if (value->getType().isa<NoneType>()) {
    if (value.getType().isa<NoneType>()) {
      return true;
    }

@ -936,7 +1075,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
    // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor.
    // This does not seem to affect runtime behavior for RNN/LSTM, but would be
    // good for reducing memory footprint.
    if (auto* inst = value->getDefiningOp()) {
    if (auto* inst = value.getDefiningOp()) {
      auto buffer_or = BuildBuffer(inst);
      if (!buffer_or) return false;
      buffers_.push_back(*buffer_or);
@ -976,7 +1115,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
    std::vector<int32_t> operands;
    operands.reserve(inst.getNumOperands());
    for (auto operand : inst.getOperands()) {
      if (operand->getType().isa<NoneType>())
      if (operand.getType().isa<NoneType>())
        operands.push_back(kTfLiteOptionalTensor);
      else
        operands.push_back(tensor_index_map.lookup(operand));

@ -18,7 +18,7 @@ limitations under the License.

#include <string>

#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"

namespace tflite {
@ -25,17 +25,17 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -304,11 +304,11 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
|
||||
void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
|
||||
Value rhs) {
|
||||
auto result_type =
|
||||
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
|
||||
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
|
||||
if (!result_type)
|
||||
emitError(result.location)
|
||||
<< "non-broadcastable operands: " << lhs->getType() << " and "
|
||||
<< rhs->getType();
|
||||
<< "non-broadcastable operands: " << lhs.getType() << " and "
|
||||
<< rhs.getType();
|
||||
result.addOperands({lhs, rhs});
|
||||
// Comparison binary ops always return i1 tensor.
|
||||
if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
|
||||
@ -324,12 +324,12 @@ void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
|
||||
Value lhs, Value rhs,
|
||||
StringAttr fused_activation_function) {
|
||||
auto result_type =
|
||||
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
|
||||
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
|
||||
|
||||
if (!result_type)
|
||||
emitError(result.location)
|
||||
<< "non-broadcastable operands: " << lhs->getType() << " and "
|
||||
<< rhs->getType();
|
||||
<< "non-broadcastable operands: " << lhs.getType() << " and "
|
||||
<< rhs.getType();
|
||||
|
||||
result.addOperands({lhs, rhs});
|
||||
result.addAttribute("fused_activation_function", fused_activation_function);
|
||||
@ -358,7 +358,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
|
||||
namespace {
|
||||
|
||||
int64_t GetConcatenationOpAxis(ConcatenationOp op) {
|
||||
auto output_type = op.output()->getType().cast<RankedTensorType>();
|
||||
auto output_type = op.output().getType().cast<RankedTensorType>();
|
||||
int64_t axis = op.axis().getSExtValue();
|
||||
if (axis < 0) axis += output_type.getRank();
|
||||
return axis;
|
||||
@ -452,7 +452,7 @@ LogicalResult VerifyConcatenationOpTypes(Operation *op,
|
||||
}
|
||||
|
||||
LogicalResult Verify(ConcatenationOp op) {
|
||||
auto output_type = op.output()->getType().dyn_cast<RankedTensorType>();
|
||||
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
// If the output type is unranked, there is nothing else to be verified.
|
||||
if (!output_type) return success();
|
||||
@ -463,7 +463,7 @@ LogicalResult Verify(ConcatenationOp op) {
|
||||
|
||||
SmallVector<TensorType, 4> operand_types;
|
||||
for (Value operand : op.values())
|
||||
operand_types.push_back(operand->getType().cast<TensorType>());
|
||||
operand_types.push_back(operand.getType().cast<TensorType>());
|
||||
|
||||
return VerifyConcatenationOpTypes(op.getOperation(), output_type,
|
||||
operand_types, axis);
|
||||
@ -520,7 +520,7 @@ DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
|
||||
|
||||
OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (fused_activation_function() == "NONE") {
|
||||
if (auto output_type = output()->getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
|
||||
const int64_t axis = GetConcatenationOpAxis(*this);
|
||||
if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
|
||||
return ConstFoldConcatenateOpDense(operands, output_type, axis);
|
||||
@ -530,7 +530,7 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Remove all empty values.
|
||||
SmallVector<Value, 4> non_empty_values;
|
||||
for (Value value : this->values()) {
|
||||
const auto shaped_type = value->getType().cast<ShapedType>();
|
||||
const auto shaped_type = value.getType().cast<ShapedType>();
|
||||
if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
|
||||
continue;
|
||||
}
|
||||
@ -559,8 +559,8 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult Verify(FullyConnectedOp op) {
|
||||
ShapedType input_type = op.input()->getType().cast<ShapedType>();
|
||||
ShapedType filter_type = op.filter()->getType().cast<ShapedType>();
|
||||
ShapedType input_type = op.input().getType().cast<ShapedType>();
|
||||
ShapedType filter_type = op.filter().getType().cast<ShapedType>();
|
||||
if (filter_type.hasRank() && filter_type.getRank() != 2) {
|
||||
return op.emitOpError("expect 2d filter, got ") << filter_type;
|
||||
}
|
||||
@ -582,7 +582,7 @@ LogicalResult Verify(FullyConnectedOp op) {
|
||||
// format.
|
||||
if (op.weights_format() == "DEFAULT") {
|
||||
ShapedType output_type =
|
||||
(*op.output().begin())->getType().cast<ShapedType>();
|
||||
(*op.output().begin()).getType().cast<ShapedType>();
|
||||
if (!output_type.hasStaticShape()) {
|
||||
return mlir::success();
|
||||
}
|
||||
@ -610,8 +610,8 @@ LogicalResult Verify(FullyConnectedOp op) {
|
||||
|
||||
static void BuildGatherOp(Builder *builder, OperationState &result,
|
||||
Value params, Value indices, IntegerAttr axis) {
|
||||
auto params_type = params->getType().cast<TensorType>();
|
||||
auto indices_type = indices->getType().cast<TensorType>();
|
||||
auto params_type = params.getType().cast<TensorType>();
|
||||
auto indices_type = indices.getType().cast<TensorType>();
|
||||
|
||||
// If params/indices is unranked, then output is unranked.
|
||||
if (!params_type.hasRank() || !indices_type.hasRank())
|
||||
@ -705,7 +705,7 @@ static LogicalResult Verify(PackOp op) {
|
||||
return op.emitOpError("input count should match 'values_count' attribute");
|
||||
|
||||
Value operand0 = op.getOperand(0);
|
||||
auto input_type = operand0->getType().cast<ShapedType>();
|
||||
auto input_type = operand0.getType().cast<ShapedType>();
|
||||
|
||||
// Check axis bounds.
|
||||
if (input_type.hasRank()) {
|
||||
@ -718,7 +718,7 @@ static LogicalResult Verify(PackOp op) {
|
||||
// Make sure all inputs have the same shape and element type.
|
||||
// TODO(rahulsp): Simplify once b/135032064 is fixed.
|
||||
for (Value operand : op.getOperands()) {
|
||||
auto other_type = operand->getType().cast<ShapedType>();
|
||||
auto other_type = operand.getType().cast<ShapedType>();
|
||||
if (input_type != other_type)
|
||||
return op.emitOpError("operands should be of the same type. got ")
|
||||
<< input_type << ", " << other_type;
|
||||
@ -732,9 +732,9 @@ static LogicalResult Verify(PackOp op) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(PReluOp op) {
|
||||
auto input_type = op.input()->getType().cast<ShapedType>();
|
||||
auto alpha_type = op.alpha()->getType().cast<ShapedType>();
|
||||
auto output_type = op.output()->getType().cast<ShapedType>();
|
||||
auto input_type = op.input().getType().cast<ShapedType>();
|
||||
auto alpha_type = op.alpha().getType().cast<ShapedType>();
|
||||
auto output_type = op.output().getType().cast<ShapedType>();
|
||||
|
||||
if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
|
||||
if (input_type.getRank() != alpha_type.getRank() + 1) {
|
||||
@ -783,13 +783,13 @@ struct RemoveAdjacentReshape : public RewritePattern {
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
auto thisOp = cast<ReshapeOp>(op);
|
||||
auto prevOp = thisOp.getOperand(0)->getDefiningOp();
|
||||
auto prevOp = thisOp.getOperand(0).getDefiningOp();
|
||||
return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
auto thisOp = cast<ReshapeOp>(op);
|
||||
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0)->getDefiningOp());
|
||||
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
|
||||
|
||||
// Replace
|
||||
// %1 = "tfl.reshape"(%0, %shape0)
|
||||
@ -807,7 +807,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
|
||||
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Remove identity reshape with both static result and input shape.
|
||||
auto result_type = getType().cast<ShapedType>();
|
||||
auto input_type = getOperand(0)->getType().cast<ShapedType>();
|
||||
auto input_type = getOperand(0).getType().cast<ShapedType>();
|
||||
if (result_type.hasStaticShape() && result_type == input_type) {
|
||||
return getOperand(0);
|
||||
}
|
||||
@ -865,7 +865,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TFL::PackOp pack_op = cast<TFL::PackOp>(op);
|
||||
Operation *first_input = pack_op.getOperand(0)->getDefiningOp();
|
||||
Operation *first_input = pack_op.getOperand(0).getDefiningOp();
|
||||
if (!first_input) return matchFailure();
|
||||
auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
|
||||
if (!input_unpack_op) return matchFailure();
|
||||
@ -905,9 +905,9 @@ void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(SliceOp op) {
|
||||
auto input_type = op.input()->getType().cast<ShapedType>();
|
||||
auto begin_type = op.begin()->getType().cast<ShapedType>();
|
||||
auto size_type = op.size()->getType().cast<ShapedType>();
|
||||
auto input_type = op.input().getType().cast<ShapedType>();
|
||||
auto begin_type = op.begin().getType().cast<ShapedType>();
|
||||
auto size_type = op.size().getType().cast<ShapedType>();
|
||||
if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
|
||||
size_type.hasStaticShape()) {
|
||||
if (input_type.getRank() != begin_type.getNumElements()) {
|
||||
@ -995,7 +995,7 @@ static void BuildTopKOp(Builder *builder, OperationState &result, Value input,
|
||||
// TODO(jpienaar): This should use a helper function.
|
||||
const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
|
||||
|
||||
auto val_type = input->getType().cast<TensorType>();
|
||||
auto val_type = input.getType().cast<TensorType>();
|
||||
// If value is unranked, then so is results.
if (!val_type.hasRank())
return TFL::TopKV2Op::build(
@ -1035,7 +1035,7 @@ struct DropFakeQuant : public RewritePattern {
// If all the users of this op have valid "minmax" attributes, it is matched
// and can be removed.
auto fakeQuantOp = cast<FakeQuantOp>(op);
for (auto *operand : fakeQuantOp.getResult()->getUsers())
for (auto *operand : fakeQuantOp.getResult().getUsers())
if (!HasValidMinMaxAttribute(operand)) return matchFailure();

return matchSuccess();
@ -1102,7 +1102,7 @@ static LogicalResult VerifySplitOpOutputTypes(
for (int64_t i = 0; i < num_splits; ++i) {
auto expected_output_type = get_expected_output_type(i);
Value output = op->getResult(i);
auto output_type = output->getType().dyn_cast<RankedTensorType>();
auto output_type = output.getType().dyn_cast<RankedTensorType>();
if (!output_type || output_type != expected_output_type)
return op->emitOpError()
<< "output #" << i << " should be " << expected_output_type;
@ -1121,7 +1121,7 @@ static LogicalResult Verify(SplitOp op) {
if (!split_dim_opt) return success();

// If 'input' is not a ranked tensor, there are no other checks.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();

int64_t split_dim = split_dim_opt.getValue();
@ -1157,7 +1157,7 @@ static LogicalResult Verify(SplitVOp op) {
if (!split_dim_opt) return success();

// If 'input' is not a ranked tensor, there are no other checks.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();

int64_t split_dim = split_dim_opt.getValue();
@ -1177,8 +1177,7 @@ static LogicalResult Verify(SplitVOp op) {
return success();

if (size_splits_attr.getNumElements() != num_splits) {
auto size_splits_type =
op.size_splits()->getType().cast<RankedTensorType>();
auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
RankedTensorType expected_size_splits_type =
RankedTensorType::get({num_splits}, size_splits_type.getElementType());
return op.emitOpError("'size_splits' should be ")
@ -1414,7 +1413,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
}

// Also fold if `input` has a known rank.
auto input_type = input()->getType().cast<ShapedType>();
auto input_type = input().getType().cast<ShapedType>();
// Do not fold if rank is zero because the TFLite converter doesn't
// distinguish between unranked input and scalar input due to b/138865275.
// TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
@ -1445,18 +1444,18 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
static void BuildSelectV2Op(Builder *builder, OperationState &result,
Value cond, Value x, Value y) {
auto operand_type =
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
OpTrait::util::getBroadcastedType(x.getType(), y.getType());

if (!operand_type)
emitError(result.location) << "non-broadcastable operands: " << x->getType()
<< " and " << y->getType();
emitError(result.location) << "non-broadcastable operands: " << x.getType()
<< " and " << y.getType();

bool has_static_cond_shape = false;
bool has_static_operand_shape = false;
ArrayRef<int64_t> cond_shape;
ArrayRef<int64_t> operand_shape;

if (auto shaped_type = cond->getType().dyn_cast<ShapedType>()) {
if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
if (shaped_type.hasStaticShape()) {
has_static_cond_shape = true;
cond_shape = shaped_type.getShape();
@ -1474,12 +1473,12 @@ static void BuildSelectV2Op(Builder *builder, OperationState &result,
!OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
broadcastedShape)) {
emitError(result.location) << "non-broadcastable operands: " << operand_type
<< " and " << cond->getType();
<< " and " << cond.getType();
}

result.addOperands({cond, x, y});

auto elementType = x->getType().dyn_cast<ShapedType>().getElementType();
auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
if (has_static_cond_shape && has_static_operand_shape) {
result.types.push_back(
RankedTensorType::get(broadcastedShape, elementType));
@ -1571,9 +1570,8 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TransposeConvOp op) {
ShapedType output_type = op.output()->getType().cast<ShapedType>();
ShapedType output_shape_type =
op.output_shape()->getType().cast<ShapedType>();
ShapedType output_type = op.output().getType().cast<ShapedType>();
ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
if (output_type.getRank() != output_shape_type.getDimSize(0)) {
return op.emitOpError(llvm::formatv(
@ -1679,9 +1677,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
}

static LogicalResult Verify(TransposeOp op) {
auto input_type = op.x()->getType().cast<ShapedType>();
auto perm_type = op.perm()->getType().cast<ShapedType>();
auto output_type = op.y()->getType().cast<ShapedType>();
auto input_type = op.x().getType().cast<ShapedType>();
auto perm_type = op.perm().getType().cast<ShapedType>();
auto output_type = op.y().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
if (perm_type.getNumElements() != input_type.getRank()) {
return op.emitOpError(
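The hunks above, and most of those below, apply one mechanical API migration: MLIR's `Value` type switched from pointer-like handles to value semantics when the project moved into the llvm-project monorepo, so accessor calls change from `->` to `.`. Note that `Operation *` is still a raw pointer, which is why `op->getResult(i)` keeps its arrow while the returned `Value` is now dereferenced with a dot. A minimal sketch of the before/after shape (the helper name is invented for illustration, not part of this commit):

#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"

// Hypothetical helper, for illustration only.
mlir::Type ElementTypeOf(mlir::Value v) {
  // Before: Value was used through pointer-like syntax:
  //   v->getType().cast<mlir::ShapedType>().getElementType();
  // After the switch to value semantics, member access uses '.':
  return v.getType().cast<mlir::ShapedType>().getElementType();
}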
@ -18,15 +18,15 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_

#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/Traits.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Dialect.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

@ -135,7 +135,7 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
//===----------------------------------------------------------------------===//

class TFL_OperandIsUnrankedPred<int n> :
CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;
CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

// TODO: Some of these could be generalized and/or moved to a more general
// location.
@ -144,38 +144,38 @@ class TFL_OperandHasRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() == " # m>]>>;
").getType().cast<ShapedType>().getRank() == " # m>]>>;

// Returns true if the n-th operand is ranked and has rank dim.
class TFL_OperandHasKnownRank<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() == "
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() == "
# dim>]>;

// True if operand n is ranked and has a rank > dim.
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
# dim>]>;

class TFL_OperandDimEquals<int n, int dim, int size> : And<[
TFL_OperandIsRankedAndHasDimPred<n, dim>,
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
".getShape()[" # dim # " ] == " # size>]>;

// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
Or<[CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() >= " # m>]>>;
").getType().cast<ShapedType>().getRank() >= " # m>]>>;

class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
CPred<"$_op.getOperand(" # x #
")->getType().cast<ShapedType>().getRank() == "
").getType().cast<ShapedType>().getRank() == "
"$_op.getOperand(" # y #
")->getType().cast<ShapedType>().getShape()[0]">>;
").getType().cast<ShapedType>().getShape()[0]">>;

class TFL_Operand0DOr1ElementTensor<int x> :
PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element",
@ -195,7 +195,7 @@ class TFL_OperandHasRankLessThan<int n, int m> :
PredOpTrait<"operand " # n # " is maximum " # m # "-D",
Or<[TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
")->getType().cast<ShapedType>().getRank() <= " # m>]>>;
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
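These TableGen predicate classes splice operand-type checks straight into the generated C++ verifiers, which is why the quoted code fragments themselves need the `->` to `.` update. As a rough sketch (hand-written here, not actual tablegen output), a predicate such as TFL_OperandHasRank<0, 2> amounts to a check like:

#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"

// Sketch of the verifier check TFL_OperandHasRank<0, 2> roughly expands to.
bool OperandZeroIsTwoD(mlir::Operation *op) {
  mlir::Type type = op->getOperand(0).getType();  // Value access uses '.'
  if (type.isa<mlir::UnrankedTensorType>()) return true;  // unranked passes
  return type.cast<mlir::ShapedType>().getRank() == 2;
}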

// This is a quantization-aware version of TCresVTEtIsSameAsOp
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
@ -227,7 +227,7 @@ def TFL_BroadcastableBinaryBuilder : OpBuilder<
"Builder *builder, OperationState &result, Value lhs, Value rhs",
[{
auto resultType =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
if (!resultType)
mlir::emitError(result.location, "non-broadcastable operands");
result.addOperands({lhs, rhs});
@ -427,6 +427,33 @@ def TFL_TransposeConvOp:
let verifier = [{ return Verify(*this); }];
}

def TFL_Convolution2DTransposeBiasOp :
Op<TFL_Dialect, "convolution_2d_transpose_bias", [NoSideEffect]> {
let summary = "Transpose convolution with bias operator";

let description = [{
Performs a transpose convolution on the inputs, with the option of adding
a bias.
Note this is a custom op that is not supported in the standard runtime.

Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the filter weight tensor
`inputs[2]`: optional: the bias tensor
}];

let arguments = (
ins AnyTensor:$input,
AnyTensor:$filter,
TFL_TensorOfOrNone<[AnyType]>:$bias,
TFL_PaddingAttr:$padding,
I32Attr:$stride_h,
I32Attr:$stride_w
);

let results = (outs AnyTensor:$output);
}
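As an illustration of the definition above, a hypothetical builder-side sketch of creating this custom op; the helper name, padding value, and strides are invented for the example, and the attribute kinds are assumed from the arguments list:

#include "mlir/IR/Builders.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// Hypothetical helper; not part of this commit.
mlir::Operation *BuildTransposeConvBias(mlir::OpBuilder &builder,
                                        mlir::Location loc,
                                        mlir::Type result_type,
                                        mlir::Value input, mlir::Value filter,
                                        mlir::Value bias) {
  // Operand/attribute order mirrors the TableGen `arguments` list:
  // input, filter, bias, padding, stride_h, stride_w.
  return builder.create<mlir::TFL::Convolution2DTransposeBiasOp>(
      loc, result_type, input, filter, bias,
      builder.getStringAttr("SAME"), builder.getI32IntegerAttr(2),
      builder.getI32IntegerAttr(2));
}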

def TFL_AveragePool2DOp:
TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Average_pool_2d operator";
@ -471,7 +498,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
let hasOptions = 1;

DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType().
return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
@ -500,7 +527,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
let hasOptions = 1;

DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType().
return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
@ -1181,7 +1208,8 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
let builders = [TFL_BroadcastableBinaryBuilder];
}

def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
def TFL_GreaterOp : TFL_Op<"greater", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
let summary = "Greater operator";

let description = [{
@ -1194,6 +1222,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {

let results = (outs AnyTensor:$output);

let builders = [TFL_ComparisonBinaryBuilder];

let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];

let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
@ -1260,7 +1290,8 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy
let hasOptions = 0b1;
}

def TFL_LessOp : TFL_Op<"less", [NoSideEffect, NoQuantizableResult]> {
def TFL_LessOp : TFL_Op<"less", [
Broadcastable, NoSideEffect, NoQuantizableResult]> {
let summary = "Less operator";

let description = [{
@ -1427,6 +1458,63 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
let customOption = "Pool2DOptions";
}

def TFL_MaxPoolingWithArgMax2DOp :
Op<TFL_Dialect, "max_pooling_with_argmax_2d", [NoSideEffect]> {
let summary = "Max Pool 2D with argmax op";

let description = [{
Performs max pooling on the input and outputs both max values and indices.
Each index is a flattened index in a sub-array of "filter_w" x "filter_h" size.
Note this is a custom op that is not supported in the standard runtime.

Inputs:
`inputs[0]`: required: the input activation tensor
}];

let arguments = (
ins AnyTensor:$input,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_w,
I32Attr:$filter_h
);

let results = (outs
AnyTensor:$value,
AnyTensor:$indices
);
}

def TFL_MaxUnpooling2DOp :
Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect]> {
let summary = "Max Unpool 2D";

let description = [{
Performs a max unpooling operation.
To some extent this is the reverse operation of max pooling:
the elements in the input activation tensor are stored into the positions
specified by the input indices.
Note this is a custom op that is not supported in the standard runtime.

Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the input indices
}];

let arguments = (
ins AnyTensor:$input,
AnyTensor:$indices,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_w,
I32Attr:$filter_h
);

let results = (outs AnyTensor:$outputs);
}
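Since these two custom ops are intended as (approximate) inverses, the index contract they share is worth spelling out. A small sketch, under the assumption that each saved index is flattened row-major within its own filter_w x filter_h window (an assumption for illustration, not a statement about the runtime kernels):

// Assumed convention: indices are row-major within one pooling window.
inline int FlattenedWindowIndex(int row_in_window, int col_in_window,
                                int filter_w) {
  return row_in_window * filter_w + col_in_window;
}

Max unpooling then scatters each value back to the slot named by its saved index, which is why the indices tensor is a required second input.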

def TFL_MaximumOp : TFL_Op<"maximum", [
Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale,
TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
@ -1996,7 +2084,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
let results = (outs AnyTensor:$output);

DerivedTypeAttr out_type = DerivedTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType();
return getResult().getType().cast<TensorType>().getElementType();
}]>;

let hasOptions = 1;
@ -2039,7 +2127,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",

Args:
tensor: A Tensor. Must be one of the following types:
int16, int32, int64, float32 Up to 8-D.
uint8, int16, int32, int64, float32, bool Up to 8-D.

axis: A Tensor. Must be one of the following types: int32, int64.
with only 1 element which is the axis index.
@ -2048,12 +2136,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",

let arguments = (
ins
TensorOf<[F32, I16, I32, I64]>:$input,
TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input,
TensorOf<[I32, I64]>:$axis
);

let results = (outs
TensorOf<[F32, I16, I32, I64, I8]>:$output
TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output
);
}

@ -2083,7 +2171,7 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
"Value condition, Value x, Value y",
[{
auto resultType = x->getType();
auto resultType = x.getType();
result.addOperands({condition, x, y});
result.types.push_back(resultType);
}]>];
@ -2733,7 +2821,7 @@ in the unique output `y`. In other words:
);

DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
return getResult(1)->getType().cast<TensorType>().getElementType().
return getResult(1).getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
@ -19,7 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_

#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:llvm-project

namespace mlir {
namespace OpTrait {

@ -30,10 +30,10 @@ limitations under the License.
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Parser.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Parser.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
#include "tensorflow/core/platform/init_main.h"

@ -27,7 +27,7 @@ limitations under the License.
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h" // TF:local_config_mlir
#include "mlir/TableGen/Attribute.h" // TF:llvm-project

using llvm::DefInit;
using llvm::dyn_cast;

@ -28,10 +28,10 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:Support",
"@local_config_mlir//:ViewOpGraph",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)

@ -19,11 +19,11 @@ limitations under the License.
#include <utility>

#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/Transforms/ViewOpGraph.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/FileUtilities.h" // TF:llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
@ -107,9 +107,6 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
if (toco_flags.output_format()) {
LOG(WARNING) << "Ignored output_format.";
}
if (toco_flags.default_ranges_min() || toco_flags.default_ranges_max()) {
LOG(WARNING) << "Ignored default_ranges_stats.";
}
if (toco_flags.drop_control_dependency()) {
LOG(WARNING) << "Ignored drop_control_dependency.";
}
@ -242,6 +239,13 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));

// Other flags.
if (toco_flags.has_default_ranges_min()) {
quant_specs.default_ranges.first = toco_flags.default_ranges_min();
}
if (toco_flags.has_default_ranges_max()) {
quant_specs.default_ranges.second = toco_flags.default_ranges_max();
}

bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();
@ -13,7 +13,7 @@ package(

package_group(
name = "friends",
includes = ["@local_config_mlir//:subpackages"],
includes = ["//third_party/mlir:subpackages"],
packages = ["//tensorflow/compiler/mlir/..."],
)

@ -26,8 +26,8 @@ filegroup(
name = "quantization_td_files",
srcs = [
"quantization.td",
"@local_config_mlir//:OpBaseTdFiles",
"@local_config_mlir//:QuantizationOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:QuantizationOpsTdFiles",
],
)

@ -53,13 +53,13 @@ cc_library(
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
@ -75,11 +75,11 @@ cc_library(
],
deps = [
"@com_google_absl//absl/memory",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(fengliuai): remove this dependence.
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/core:lib_proto_parsing",
@ -97,7 +97,7 @@ cc_library(
deps = [
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)

@ -107,8 +107,8 @@ tf_native_cc_binary(
"tools/op_quant_spec_getters_gen.cc",
],
deps = [
"@llvm//:support",
"@llvm//:tablegen",
"@local_config_mlir//:TableGen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//mlir:TableGen",
],
)
@ -23,18 +23,18 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/AffineExpr.h" // TF:local_config_mlir
#include "mlir/IR/AffineMap.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
#include "mlir/IR/AffineMap.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
@ -78,8 +78,8 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
bool IsQuantizableResult(Operation *op, int index) {
if (index < 0 || index >= op->getNumResults()) return false;
Value res = op->getResult(index);
return res->getType().isa<ShapedType>() &&
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
return res.getType().isa<ShapedType>() &&
res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
}

// A method to retrieve the name for the given op.
@ -123,7 +123,7 @@ void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
IntegerAttr axis) {
auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
layer_stats, axis_stats, axis);
res->replaceAllUsesWith(stats_op);
res.replaceAllUsesWith(stats_op);
stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
}
@ -9,7 +9,7 @@ package(

package_group(
name = "friends",
includes = ["@local_config_mlir//:subpackages"],
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//tensorflow/lite/...",
@ -36,9 +36,9 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
)

@ -53,6 +53,6 @@ tf_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@llvm-project//llvm:support",
],
)

@ -17,11 +17,11 @@ limitations under the License.

#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
@ -23,6 +23,7 @@ limitations under the License.
#include <vector>

#include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "tensorflow/core/framework/types.pb.h"

@ -64,6 +65,10 @@ struct QuantizationSpecs {
// quantization aware training or calibration, for the remaining tensors.
std::vector<std::pair<double, double>> input_ranges;

// The default ranges can be used when a tensor doesn't have quantization
// parameters and couldn't be quantized. Used only for latency tests.
std::pair<llvm::Optional<double>, llvm::Optional<double>> default_ranges;

// A serialized "QuantizationInfo" object to specify value ranges for some of
// the tensors with known names.
std::string serialized_quant_stats = "";
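Tying the hunks together: the default_ranges pair added here is populated from the default_ranges_min/default_ranges_max toco flags in the ConvertGraphDefToTFLiteFlatBuffer change earlier in this commit. A minimal sketch, with ApplyDefaultRange being a hypothetical helper (not from the patch), of how such defaults might be consulted for a tensor that carries no range of its own:

#include <utility>
#include "llvm/ADT/Optional.h"

// Hypothetical helper: fall back to flag-provided defaults only where the
// tensor itself has no range information (the latency-testing use case).
std::pair<double, double> ApplyDefaultRange(
    llvm::Optional<double> tensor_min, llvm::Optional<double> tensor_max,
    const std::pair<llvm::Optional<double>, llvm::Optional<double>>
        &default_ranges) {
  double min = tensor_min ? *tensor_min : default_ranges.first.getValueOr(0.0);
  double max = tensor_max ? *tensor_max : default_ranges.second.getValueOr(0.0);
  return {min, max};
}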
@ -23,17 +23,17 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
@ -146,14 +146,14 @@ class QuantizationDriver {

// Adds all the users of index-th result of op to the work list.
void AddUserToList(Operation *op, int index) {
for (auto *user : op->getResult(index)->getUsers()) {
for (auto *user : op->getResult(index).getUsers()) {
work_list_.push_back(user);
}
}

// Adds the defining op of index-th operand of op to the work list.
void AddOperandToList(Operation *op, int index) {
if (auto *inst = op->getOperand(index)->getDefiningOp()) {
if (auto *inst = op->getOperand(index).getDefiningOp()) {
work_list_.push_back(inst);
}
}
@ -248,7 +248,7 @@ class QuantizationDriver {
return;
}
QuantParams params =
quant::QuantizedType::getQuantizedElementType(in->getType());
quant::QuantizedType::getQuantizedElementType(in.getType());
bool immutable = !EmptyParams(params);
int next_state_index = states_.size();
states_.push_back({params, immutable});
@ -338,7 +338,7 @@ bool QuantizationDriver::IsQuantized(Operation *op) {
int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
bool as_result) {
QuantParams params =
quant::QuantizedType::getQuantizedElementType(val->getType());
quant::QuantizedType::getQuantizedElementType(val.getType());
bool immutable = !EmptyParams(params);
int next_state_index = states_.size();
states_.push_back({params, immutable});
@ -447,13 +447,13 @@ void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
}

void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
builder_.setInsertionPointToStart(arg->getOwner());
builder_.setInsertionPointToStart(arg.getOwner());
QuantizeValue(arg, params, builder_.getUnknownLoc());
}

void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
Location loc) {
Type expressed_type = value->getType();
Type expressed_type = value.getType();
Type new_type = params.castFromExpressedType(expressed_type);
// This value isn't an expressed type (float), skip.
if (!new_type) return;
@ -465,7 +465,7 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
quantize.output());
// `original_result` has a use to `quantize`, so this will replace that use
// by the result of `dequantize`. Remember to reset that use afterwards
value->replaceAllUsesWith(dequantize);
value.replaceAllUsesWith(dequantize);
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}
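The QuantizeValue body above uses a two-step rewiring idiom that recurs throughout this patch. Shown in isolation (a sketch, not library code): first redirect every use of the value to the new dequantize result, then undo that redirection for the quantize op itself, so the final chain is value -> quantize -> dequantize -> original users:

#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

// Sketch of the rewiring idiom; names are illustrative.
void RewireThroughQuantDequant(mlir::Value value, mlir::Operation *quantize,
                               mlir::Value dequantized) {
  value.replaceAllUsesWith(dequantized);  // all users now consume dequantize
  quantize->replaceUsesOfWith(dequantized, value);  // except quantize itself
}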
@ -475,7 +475,7 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
builder_.setInsertionPointAfter(op);
Value value = op->getResult(index);
if (state->pos == RequantizeState::ON_OUTPUT) {
Operation *user = value->getUses().begin().getUser();
Operation *user = value.getUses().begin().getUser();
if (llvm::isa<TFL::QuantizeOp>(user)) {
// The requantize op is inserted between `quantize` and `dequantize` ops.
value = user->getResult(0);
@ -488,12 +488,12 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
void QuantizationDriver::RequantizeArg(BlockArgument arg,
RequantizeState *state) {
Value value = arg;
builder_.setInsertionPointToStart(arg->getOwner());
if (value->hasOneUse()) {
auto user = value->use_begin().getUser();
builder_.setInsertionPointToStart(arg.getOwner());
if (value.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
builder_.setInsertionPoint(arg->getOwner(), ++Block::iterator(user));
builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
}
}
RequantizeValue(value, state, builder_.getUnknownLoc());
@ -503,13 +503,13 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
Location loc) {
Type new_type;
if (state->pos == RequantizeState::ON_INPUT) {
Type expressed_type = value->getType();
Type expressed_type = value.getType();
// The value needs to be requantized. A Quantize op will be created to use
// it as the operand and replace its uses.
new_type = state->params.castFromExpressedType(expressed_type);
} else {
Type expressed_type =
quant::QuantizedType::castToExpressedType(value->getType());
quant::QuantizedType::castToExpressedType(value.getType());
if (!expressed_type) return;

// The value needs to be requantized. A Quantize op will be created to use
@ -522,7 +522,7 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
TypeAttr type_attr = TypeAttr::get(new_type);
auto requantize_op =
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
value->replaceAllUsesWith(requantize_op);
value.replaceAllUsesWith(requantize_op);
requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}

@ -603,7 +603,7 @@ void QuantizationDriver::PreprocessConstantOps() {
Value value = cst.getResult();
SmallVector<std::pair<Operation *, int>, 4> bias_users;
bool used_as_weight = false;
for (auto &use : value->getUses()) {
for (auto &use : value.getUses()) {
auto spec = GetQuantSpec(use.getOwner());
auto biases = spec->biases_params;
Operation *user = use.getOwner();
@ -649,8 +649,8 @@ void QuantizationDriver::SetupAllStates() {
args_.push_back(arg);
Value value = arg;
// If the argument is quantized, it should only have one user.
if (arg->hasOneUse()) {
auto user = value->use_begin().getUser();
if (arg.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
}
@ -666,7 +666,7 @@ void QuantizationDriver::SetupAllStates() {

for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
auto operand = op->getOperand(i);
if (auto *inst = operand->getDefiningOp()) {
if (auto *inst = operand.getDefiningOp()) {
// If the operand comes from a tfl.dequantize op, we use the quantized
// input of this tfl.dequantize op to set the state.
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
@ -677,12 +677,12 @@ void QuantizationDriver::SetupAllStates() {
}

for (int res = 0, e = op->getNumResults(); res != e; ++res) {
auto result = op->getResult(res);
Value result = op->getResult(res);
// If the result has been quantized, it should only be used by a
// tfl.quantize op. For this case, we use the quantized result to
// create the state and mark it immutable.
if (result->hasOneUse()) {
auto user = result->use_begin().getUser();
if (result.hasOneUse()) {
auto user = result.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
result = q.output();
}
@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_

#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project

namespace mlir {
namespace quant {

@ -18,8 +18,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_

#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project

namespace mlir {
namespace OpTrait {
@ -70,7 +70,7 @@ class FixedResultUniformScale {
QuantizedType GetResultQuantizedType(int index) {
auto op = this->getOperation();
auto result_type =
op->getResult(index)->getType().template cast<TensorType>();
op->getResult(index).getType().template cast<TensorType>();
Builder builder(op->getContext());
IntegerType storage_type = builder.getIntegerType(BitWidth);
const double scale = static_cast<double>(ScaleMantissa) *
@ -21,15 +21,15 @@ limitations under the License.

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantizeUtils.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"

namespace mlir {
@ -367,7 +367,7 @@ ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
static bool PreferResultScale(Operation* op) {
int float_operands = 0;
for (auto operand : op->getOperands()) {
if (auto operand_type = operand->getType().dyn_cast<ShapedType>()) {
if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
if (operand_type.getElementType().isa<FloatType>()) {
if (float_operands++ > 1) return true;
}
@ -400,22 +400,22 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
quant::StatisticsOp stats_op = all_stats_ops.back();
all_stats_ops.pop_back();

if (auto def = stats_op.arg()->getDefiningOp()) {
if (auto def = stats_op.arg().getDefiningOp()) {
if (IsStatsRedundant(def, op_quant_spec_getter)) {
redundant_stats_ops.insert(stats_op);
}
}

for (auto user : stats_op.getResult()->getUsers()) {
for (auto user : stats_op.getResult().getUsers()) {
// We don't propagate this parameter down if it has multiple operands.
// We want to use the result parameter scales instead.

if (user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
!PreferResultScale(user)) {
for (Value res : user->getResults()) {
if (res->hasOneUse()) {
if (res.hasOneUse()) {
if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
*res->getUsers().begin())) {
*res.getUsers().begin())) {
// quantization parameters can be propagated to next_stats
redundant_stats_ops.insert(next_stats);
// add next_stats to the work list so propagation can
@ -440,12 +440,12 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
quant::StatisticsOp stats_op = all_stats_ops.back();
all_stats_ops.pop_back();

if (auto def = stats_op.arg()->getDefiningOp()) {
if (auto def = stats_op.arg().getDefiningOp()) {
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
PreferResultScale(def)) {
for (auto input : def->getOperands()) {
if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
input->getDefiningOp())) {
input.getDefiningOp())) {
redundant_stats_ops.insert(next_stats);
all_stats_ops.push_back(next_stats);
}
@ -458,7 +458,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
for (auto it : redundant_stats_ops) {
if (!llvm::isa<quant::StatisticsOp>(it)) return true;
auto stats_op = llvm::cast<quant::StatisticsOp>(it);
stats_op.getResult()->replaceAllUsesWith(stats_op.arg());
stats_op.getResult().replaceAllUsesWith(stats_op.arg());
stats_op.erase();
}
@ -23,18 +23,18 @@ limitations under the License.

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

namespace mlir {
@ -116,7 +116,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
TypeAttr::get(result_type));
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
op.getResult()->replaceAllUsesWith(dq);
op.getResult().replaceAllUsesWith(dq);
q.getOperation()->replaceUsesOfWith(dq, op.arg());
op.erase();

@ -162,7 +162,7 @@ struct QuantizationPattern : public RewritePattern {
return matchFailure();
}
Value quantized_value = op->getResult(0);
for (Operation* quantized_op : quantized_value->getUsers()) {
for (Operation* quantized_op : quantized_value.getUsers()) {
// If it is a requantize op, we shouldn't rewrite this op.
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
return matchFailure();
@ -179,14 +179,14 @@ struct QuantizationPattern : public RewritePattern {
SmallVector<Value, 4> inputs;
inputs.reserve(quantized_op->getNumOperands());
for (auto operand : quantized_op->getOperands()) {
Type operand_type = operand->getType();
Type operand_type = operand.getType();
if (operand_type.isa<NoneType>()) {
inputs.push_back(operand);
continue;
}

auto ele_type = operand->getType().cast<TensorType>().getElementType();
if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
auto ele_type = operand.getType().cast<TensorType>().getElementType();
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
inputs.push_back(op_inst.input());
} else if (ele_type.isa<IntegerType>()) {
// If the operand is an integer tensor, then it doesn't require the
@ -207,7 +207,7 @@ struct QuantizationPattern : public RewritePattern {
for (auto enumerated_result :
llvm::enumerate(quantized_op->getResults())) {
Value result = enumerated_result.value();
Type result_type = result->getType();
Type result_type = result.getType();
// Add this to the test coverage once we create test ops with none type
// results.
if (result_type.isa<NoneType>()) {
@ -216,20 +216,20 @@ struct QuantizationPattern : public RewritePattern {
continue;
}
Type result_ele_type =
result->getType().cast<TensorType>().getElementType();
result.getType().cast<TensorType>().getElementType();
// If the user is the Quantize op, it must be the only user.
if (result->hasOneUse() && llvm::isa<Q>(*result->user_begin())) {
auto user = llvm::cast<Q>(*result->user_begin());
if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
auto user = llvm::cast<Q>(*result.user_begin());
outputs_replaced.insert({user.output(), enumerated_result.index()});
output_types.push_back(user.getType());
} else if (result_ele_type.template isa<IntegerType>()) {
// If the result is an integer tensor, then it doesn't require the
// D op in the pattern.
outputs_replaced.insert({result, enumerated_result.index()});
output_types.push_back(result->getType());
output_types.push_back(result.getType());
} else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
outputs_replaced.insert({result, enumerated_result.index()});
output_types.push_back(result->getType());
output_types.push_back(result.getType());
} else {
return matchFailure();
}
@ -241,7 +241,7 @@ struct QuantizationPattern : public RewritePattern {
output_types, quantized_op->getAttrs());
Operation* new_op = rewriter.createOperation(new_state);
for (auto output : outputs_replaced) {
output.getFirst()->replaceAllUsesWith(
output.getFirst().replaceAllUsesWith(
new_op->getResult(output.getSecond()));
}

@ -252,7 +252,7 @@ struct QuantizationPattern : public RewritePattern {
// For constant operands, the floating-point constant is duplicated in
// case it is quantized.
for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
auto def = new_op->getOperand(i)->getDefiningOp();
auto def = new_op->getOperand(i).getDefiningOp();
if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
DenseFPElementsAttr attr;
if (!matchPattern(q.input(), m_Constant(&attr))) {
@ -265,7 +265,7 @@ struct QuantizationPattern : public RewritePattern {

for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
if (!quantized_op->getResult(i)
->getType()
.getType()
.cast<ShapedType>()
.getElementType()
.isa<FloatType>()) {
@ -283,13 +283,13 @@ struct QuantizationPattern : public RewritePattern {
// Find the Dequantize users of the new op results, and
// replace the usage. Then all the floating-point ops are connected.
// N.B. the return op will use this floating-point result.
for (auto user : new_op->getResult(i)->getUsers()) {
for (auto user : new_op->getResult(i).getUsers()) {
// Skip the Requantize op, and we know it has a single user.
if (llvm::isa<Q>(user)) {
user = *user->getResult(0)->getUsers().begin();
user = *user->getResult(0).getUsers().begin();
}
if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
dequantize.getResult()->replaceAllUsesWith(
dequantize.getResult().replaceAllUsesWith(
quantized_op->getResult(i));
}
}
@ -316,7 +316,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {

PatternMatchResult matchAndRewrite(Q op,
PatternRewriter& rewriter) const override {
Type output_type = op.output()->getType();
Type output_type = op.output().getType();
auto qtype = QType::getQuantizedElementType(output_type);
if (!qtype || qtype.isSigned()) return this->matchFailure();
@ -4,7 +4,7 @@ package(licenses = ["notice"])

glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)

@ -14,6 +14,6 @@ filegroup(
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm//:FileCheck",
"@llvm-project//llvm:FileCheck",
],
)

@ -20,7 +20,7 @@ limitations under the License.
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Operator.h" // TF:local_config_mlir
#include "mlir/TableGen/Operator.h" // TF:llvm-project

using llvm::LessRecord;
using llvm::raw_ostream;

@ -4,7 +4,7 @@ package(licenses = ["notice"])

glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)

@ -14,6 +14,6 @@ filegroup(
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm//:FileCheck",
"@llvm-project//llvm:FileCheck",
],
)

@ -7,7 +7,7 @@ glob_lit_tests(
":debug_info_files",
":test_utilities",
],
driver = "@local_config_mlir//:run_lit.sh",
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = [
"pbtxt",
# TODO(fengliuai): reenable these tests after the fused loc is
@ -33,8 +33,8 @@ filegroup(
":saved_model_error",
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm//:FileCheck",
"@llvm//:not",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)
@ -0,0 +1,89 @@
// RUN: tf-opt %s --tfl-default-quant --tfl-quantize | FileCheck %s

// CHECK-LABEL: hardcode_all
func @hardcode_all(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>

// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// Quantized tfl.add
// CHECK: %[[add:.*]] = "tfl.add"(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}

// CHECK-LABEL: hardcode_input
func @hardcode_input(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
%1 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
%4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %4 : tensor<2x2xf32>

// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>}
// CHECK: %[[add:.*]] = "tfl.add"(%[[q1]], %[[q0]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}

// CHECK-LABEL: hardcode_input_deq
func @hardcode_input_deq(%arg0: tensor<2x2x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%1 = "tfl.dequantize"(%arg0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0>>) -> tensor<2x2xf32>
%4 = "tfl.add"(%1, %arg1) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %4 : tensor<2x2xf32>

// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
// CHECK: %[[add:.*]] = "tfl.add"(%arg0, %[[q]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}

// CHECK-LABEL: hardcode_output
func @hardcode_output(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x2xf32> {
%0 = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>
%1 = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>}: (tensor<2x1xf32>) -> tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>
%2 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x2xf32>
%3 = "tfl.dequantize"(%1) : (tensor<2x1x!quant.uniform<u8:f32, 1.0:128>>) -> tensor<2x1xf32>
%4 = "tfl.add"(%2, %3) {fused_activation_function="NONE"}: (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %4 : tensor<2x2xf32>

// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>}
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<2x1x!quant.uniform<u8:f32, 1.000000e+00:128>>}
// CHECK: %[[add:.*]] = "tfl.add"(%[[q0]], %[[q1]]) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.uniform<u8:f32, 1.000000e+00:128>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[add]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}

// CHECK-LABEL: test_conv_2d_add
func @test_conv_2d_add(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>, %arg2: tensor<32x!quant.uniform<i32:f32, 1.0>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>> {
%0 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x224x224x3xf32>
%1 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
%2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<i32:f32, 1.0>>) -> tensor<32xf32>
%3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%4 = "tfl.pseudo_qconst"() {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>, value = dense<1> : tensor<1x112x112x32xi8>} : () -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
%5 = "tfl.dequantize"(%4) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>) -> tensor<1x112x112x32xf32>
%6 = "tfl.add"(%3, %5) {fused_activation_function="NONE"}: (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
%7 = "tfl.quantize"(%6) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>
return %7 : tensor<1x112x112x32x!quant.uniform<u8:f32, 1.0>>

// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %arg1, %arg2)
// CHECK-SAME: -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"()
// CHECK: %[[add:.*]] = "tfl.add"(%[[conv]], %[[cst]])
// CHECK-SAME: -> tensor<1x112x112x32x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: return %[[add]]
}

// CHECK-LABEL: test_conv_2d_activation_and_bias
func @test_conv_2d_activation_and_bias(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>, %arg2: tensor<32xf32>) -> tensor<1x112x112x32xf32> {
%0 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.0>>) -> tensor<32x3x3x3xf32>
%1 = "tfl.conv_2d"(%arg0, %0, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
return %1 : tensor<1x112x112x32xf32>
|
||||
|
||||
// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg2) {qtype = tensor<32x!quant.uniform<i32:f32, 0.0078431372549019607>>}
|
||||
// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 0.0078431372549019607:128>>}
|
||||
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%[[q1]], %arg1, %[[q0]])
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[conv]]) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
|
||||
// CHECK: return %[[dq]]
|
||||
}
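
A note on the constants these checks pin down: the scale 0.0078431372549019607 and zero point 128 follow directly from the pass's default range of [-1, 1] on unsigned 8-bit storage. A minimal, self-contained sketch of that affine-quantization arithmetic (illustrative only, not code from the pass):

#include <cmath>
#include <cstdio>

int main() {
  // Defaults registered for the tfl-default-quant test pass.
  const double kMin = -1.0, kMax = 1.0;
  const int kBits = 8;
  // The scale spreads [min, max] over the 255 uint8 steps; the zero point is
  // the storage value that represents 0.0f exactly.
  const double scale = (kMax - kMin) / ((1 << kBits) - 1);               // 2/255
  const int zero_point = static_cast<int>(std::lround(-kMin / scale));   // 128
  std::printf("!quant.uniform<u8:f32, %.19f:%d>\n", scale, zero_point);
  return 0;
}
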
@ -7,7 +7,7 @@ glob_lit_tests(
        ":quant_stats_files",
        ":test_utilities",
    ],
    driver = "@local_config_mlir//:run_lit.sh",
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = [
        "pbtxt",
    ],
@ -20,7 +20,7 @@ filegroup(
    data = [
        "//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
        "//tensorflow/compiler/mlir/lite:tf_tfl_translate",
        "@llvm//:FileCheck",
        "@llvm-project//llvm:FileCheck",
    ],
)

@ -8,7 +8,7 @@ glob_lit_tests(
        ":extra_files",
        ":test_utilities",
    ],
    driver = "@local_config_mlir//:run_lit.sh",
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = [
        "mlir",
        "cc",
@ -24,7 +24,7 @@ filegroup(
        ":importer_test_min_max",
        "//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
        "//tensorflow/compiler/mlir/lite:flatbuffer_translate",
        "@llvm//:FileCheck",
        "@llvm-project//llvm:FileCheck",
    ],
)

@ -51,7 +51,7 @@ tf_native_cc_binary(
        "//tensorflow/lite:framework",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
        "@llvm//:support",
        "@llvm-project//llvm:support",
    ],
)

@ -67,6 +67,6 @@ tf_native_cc_binary(
        "//tensorflow/lite:framework",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
        "@llvm//:support",
        "@llvm-project//llvm:support",
    ],
)

@ -11,6 +11,8 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
  %3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
// CHECK: %[[EXP:.*]] = "tfl.exp"
  %4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
// tfl.neg should not be pruned
// CHECK: %[[NEG:.*]] = "tfl.neg"
  %5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
  return %5 : tensor<4xf32>

@ -0,0 +1,19 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -output-arrays=mul,exp,div --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// Confirm graph pruning.

func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
  %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
  %1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
// CHECK: %[[MUL:.*]] = tfl.mul
  %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
// CHECK: %[[DIV:.*]] = tfl.div
  %3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
// CHECK: %[[EXP:.*]] = "tfl.exp"
  %4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
// tfl.neg should be pruned
// CHECK-NOT: "tfl.neg"
  %5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
  return %5 : tensor<4xf32>
}
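
The pruning exercised here is backward reachability from the requested output arrays: mul, exp, and div stay, and the dangling tfl.neg disappears. A rough sketch of the idea on a generic string-keyed op graph (the names and representation are made up for illustration; the converter works on the flatbuffer graph itself):

#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// Walks backward from the requested outputs; any op never visited (like
// "neg" in the test above) is unreachable and can be pruned.
std::set<std::string> ReachableOps(
    const std::map<std::string, std::vector<std::string>>& inputs_of,
    const std::vector<std::string>& output_arrays) {
  std::set<std::string> kept;
  std::queue<std::string> work;
  for (const std::string& out : output_arrays) work.push(out);
  while (!work.empty()) {
    std::string op = work.front();
    work.pop();
    if (!kept.insert(op).second) continue;  // already visited
    auto it = inputs_of.find(op);
    if (it == inputs_of.end()) continue;    // graph input
    for (const std::string& in : it->second) work.push(in);
  }
  return kept;
}
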
@ -4,7 +4,7 @@ licenses(["notice"])

glob_lit_tests(
    data = [":test_utilities"],
    driver = "@local_config_mlir//:run_lit.sh",
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = ["mlir"],
)

@ -15,7 +15,7 @@ filegroup(
    data = [
        "//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
        "//tensorflow/compiler/mlir/lite:flatbuffer_translate",
        "@llvm//:FileCheck",
        "@llvm//:not",
        "@llvm-project//llvm:FileCheck",
        "@llvm-project//llvm:not",
    ],
)

@ -0,0 +1,76 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s

func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {

// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "Convolution2DTransposeBias"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 32, 4, 4, 128 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 42, 128 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "arg2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 64, 84, 32 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tfl.convolution_2d_transpose_bias",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}

// MLIR-LABEL: func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>)
// MLIR-SAME: -> tensor<1x64x84x32xf32>
// MLIR: %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2)
// MLIR-SAME: {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32}
// MLIR-SAME: (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
// MLIR-NEXT: return %0 : tensor<1x64x84x32xf32>

  %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
  return %0 : tensor<1x64x84x32xf32>
}
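
The custom_options blob checked above is a fixed-width byte array; with the op carrying padding = "SAME", stride_h = 1, and stride_w = 2, the twelve bytes read naturally as three little-endian 32-bit integers {1, 2, 1}. A sketch of that decoding; the mapping of fields to attributes is an inference from this test, not a documented layout:

#include <cstdint>
#include <cstring>
#include <vector>

// Splits a custom_options blob into consecutive little-endian int32 fields.
// The memcpy assumes a little-endian host, matching the dump above.
std::vector<int32_t> DecodeInt32Options(const uint8_t* data, size_t len) {
  std::vector<int32_t> fields(len / sizeof(int32_t));
  std::memcpy(fields.data(), data, fields.size() * sizeof(int32_t));
  return fields;  // {1, 2, 1} for the bytes in the test above
}
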
@ -0,0 +1,39 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s

// CHECK: {
// CHECK: version: 3,
// CHECK: operator_codes: [ {
// CHECK: builtin_code: CUSTOM,
// CHECK: custom_code: "HashTableV2"
// CHECK: } ],
// CHECK: subgraphs: [ {
// CHECK: tensors: [ {
// CHECK: shape: [ ],
// CHECK: type: INT32,
// CHECK: buffer: 1,
// CHECK: name: "tf.HashTableV2",
// CHECK: quantization: {
// CHECK-EMPTY
// CHECK: }
// CHECK: } ],
// CHECK: inputs: [ ],
// CHECK: outputs: [ 0 ],
// CHECK: operators: [ {
// CHECK: inputs: [ ],
// CHECK: outputs: [ 0 ],
// CHECK: custom_options:
// CHECK: name: "main"
// CHECK: } ],
// CHECK: description: "MLIR Converted.",
// CHECK: buffers: [ {
// CHECK-EMPTY
// CHECK: }, {
// CHECK-EMPTY
// CHECK: } ]
// CHECK: }

func @main() -> tensor<*x!tf.resource> {
  %0 = "tf.HashTableV2"() {container = "" , shared_name= "table", use_node_name_sharing = false, key_dtype = i32, value_dtype = i32 } : () -> tensor<*x!tf.resource>
  return %0 : tensor<*x!tf.resource>
}

@ -0,0 +1,65 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s

func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {

// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "MaxPoolingWithArgmax2D"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 64, 64, 32 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d:1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1, 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1, 2 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}

// MLIR-LABEL: func @main(%arg0: tensor<1x64x64x32xf32>)
// MLIR-SAME: -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
// MLIR: %value, %indices = "tfl.max_pooling_with_argmax_2d"(%arg0)
// MLIR-SAME: {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32}
// MLIR-SAME: (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
// MLIR-NEXT: return %value, %indices : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>

  %0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
  return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
}
@ -0,0 +1,65 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s

func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {

// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "MaxUnpooling2D"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.max_unpooling_2d",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT:}

// MLIR-LABEL: func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>)
// MLIR-SAME: -> tensor<1x8x8x128xf32>
// MLIR: %0 = "tfl.max_unpooling_2d"(%arg0, %arg1)
// MLIR-SAME: {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32}
// MLIR-SAME: (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32>
// MLIR-NEXT: return %0 : tensor<1x8x8x128xf32>

  %0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
  return %0 : tensor<1x8x8x128xf32>
}
@ -518,6 +518,20 @@ func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform<i9:f32

// -----

func @testMaxPoolingWithArgMax2D(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
  %0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
  return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
}

// -----

func @testMaxUnpooling2D(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
  %0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
  return %0 : tensor<1x8x8x128xf32>
}

// -----

// CHECK-LABEL: testLogistic
func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
^bb0(%arg0: tensor<1x2x3x4x5xbf16>):
@ -1942,6 +1956,13 @@ func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %ar

// -----

func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
  %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
  return %0 : tensor<1x64x84x32xf32>
}

// -----

func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> {
  // expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}}
  %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32>
@ -1,4 +1,7 @@
// Run optimize pass only and check the results.
// RUN: tf-opt %s -tfl-optimize | FileCheck %s
// Run optimize pass and then canonicalize pass, and make sure some folding is applied.
// RUN: tf-opt %s -tfl-optimize -canonicalize | FileCheck --check-prefix=FOLD %s

// CHECK-LABEL: fusedConv2dRelu
func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
@ -302,6 +305,58 @@ func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf
// CHECK: return %[[fc]]
}

// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst
// FOLD-LABEL: @FuseFullyConnectedReshapeAddConst
func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
  %cst = constant dense<3.0> : tensor<40x40xf32>
  %cst2 = constant dense<2.0> : tensor<40xf32>
  %shape1 = constant dense<[1, 40, 40]> : tensor<3xi32>
  %shape2 = constant dense<[40, 40]> : tensor<2xi32>

  %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
  %1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>
  %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>
  %3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>

  return %3 : tensor<40x40xf32>

// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%[[fc]]
// CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
// CHECK: return %[[rs2]]

// FOLD: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
// FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// FOLD: return %[[fc]]
}

// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastable
func @NotReorderReshapeAddIfNotBroadcastable(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
  %cst = constant dense<2.0> : tensor<40xf32>
  %shape = constant dense<[40, 40]> : tensor<2xi32>
  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x10x4xf32>, tensor<2xi32>) -> tensor<40x40xf32>
  %2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
  return %2 : tensor<40x40xf32>

// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
// CHECK: return %[[rs2]]
}

// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
  %cst = constant dense<2.0> : tensor<1x40xf32>
  %shape = constant dense<[40, 40]> : tensor<2xi32>
  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32>
  %2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<1x40xf32>) -> tensor<40x40xf32>
  return %2 : tensor<40x40xf32>

// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
// CHECK: return %[[rs2]]
}
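
Taken together, the two negative tests pin down when the optimizer may push the add ahead of the reshape: the constant has to be a plain 1-D vector that lines up with the trailing dimension of the reshape's input. A rough predicate capturing just what these tests check; the real pattern may impose further conditions:

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"

// True when add(reshape(x), c) can be rewritten as reshape(add(x, c)):
// c must be 1-D (rules out NotReorderReshapeAddIfNotTailingDim) and must
// broadcast over x's trailing dim (rules out NotReorderReshapeAddIfNotBroadcastable).
bool CanReorderReshapeAdd(llvm::ArrayRef<int64_t> input_shape,
                          llvm::ArrayRef<int64_t> const_shape) {
  if (const_shape.size() != 1 || input_shape.empty()) return false;
  return input_shape.back() == const_shape[0];
}
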
// CHECK-LABEL: @FuseFullyConnectedRelu
func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
  %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>

@ -15,11 +15,11 @@ limitations under the License.

#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"

#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
#include "mlir/Transforms/Passes.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "mlir/Transforms/Passes.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -43,6 +43,16 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
      quant_specs.inference_type != quant_specs.inference_input_type;
  pass_manager->addPass(
      mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));

  if (quant_specs.default_ranges.first.hasValue() ||
      quant_specs.default_ranges.second.hasValue()) {
    pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
        quant_specs.default_ranges.first.getValueOr(0.0),
        quant_specs.default_ranges.second.getValueOr(0.0)));
    pass_manager->addPass(mlir::TFL::CreateQuantizePass());
    pass_manager->addPass(
        mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
  }
}

void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
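
The new branch only fires when a caller has filled in default_ranges, the optional min/max pair probed with hasValue()/getValueOr() above. A hypothetical driver snippet, sketched under the assumption that AddQuantizationPasses is reachable from the caller's translation unit:

#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"

// Requesting a default range of [-1, 1] makes AddQuantizationPasses schedule
// CreateDefaultQuantParamsPass followed by the quantize/post-quantize passes.
void BuildQuantPipeline(mlir::MLIRContext* context) {
  mlir::TFL::QuantizationSpecs quant_specs;
  quant_specs.default_ranges = {-1.0, 1.0};
  mlir::PassManager pass_manager(context);
  tensorflow::AddQuantizationPasses(quant_specs, &pass_manager);
}
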
@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_

#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"

namespace tensorflow {

@ -20,11 +20,11 @@ limitations under the License.
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Diagnostics.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h"  // TF:local_config_mlir
#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Support/FileUtilities.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
@ -103,7 +103,7 @@ static int PrintFunctionResultMapping(const std::string &result,
  i = 0;
  for (auto output : *subgraph->outputs()) {
    print_buffer(*subgraph, i, output, [&](int i) {
      return terminator ? terminator->getOperand(i)->getLoc() : unknown_loc;
      return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
    });
  }
}
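
This hunk is representative of most of the C++ churn in the commit: alongside the local_config_mlir to llvm-project include move, upstream MLIR turned Value from a pointer-like handle into a lightweight value type, so accessor calls switch from -> to dot. A toy rendition of that API shape, with stand-in types rather than the real MLIR classes:

// Toy stand-ins: the new-style handle is copied by value but still points at
// shared payload, so v.getType() replaces the old v->getType().
struct Type {};
struct ValueImpl { Type type; };
class Value {
 public:
  explicit Value(ValueImpl* impl) : impl_(impl) {}
  Type getType() const { return impl_->type; }  // was reached via operator->
 private:
  ValueImpl* impl_;
};

Type TypeOf(Value v) { return v.getType(); }
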
@ -15,12 +15,12 @@ limitations under the License.

#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/Parser.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h"  // TF:local_config_mlir
#include "mlir/Transforms/Passes.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Parser.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/FileUtilities.h"  // TF:llvm-project
#include "mlir/Transforms/Passes.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_

#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/stream_executor/lib/statusor.h"

234
tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
Normal file
@ -0,0 +1,234 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

//===----------------------------------------------------------------------===//
// The pass to add default quantization parameters for the activations which
// don't have quantization information. These default parameters are usually
// not from real measurement, so this pass is only for testing purposes.

namespace mlir {
namespace TFL {
// Includes an auto-generated function, which can retrieve the quantization
// specification for a TFL operation. The signature of the function is
// std::unique_ptr<OpQuantSpec> TFL::GetOpQuantSpec(Operation *)
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"

namespace {
class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
 public:
  explicit DefaultQuantParamsPass(double default_min, double default_max)
      : default_min_(default_min), default_max_(default_max) {}

  void runOnFunction() override;

 private:
  // Whether the value is used as a bias input of another op. Here we assume
  // the bias is used immediately by the user. This assumption is always
  // correct after constant folding.
  bool UsedAsBias(Value value) {
    for (auto &use : value.getUses()) {
      auto biases = TFL::GetOpQuantSpec(use.getOwner())->biases_params;
      if (biases.find(use.getOperandNumber()) != biases.end()) return true;
    }
    return false;
  }

  // Uses `quant_params` to quantize `value` and inserts a pair of
  // tfl.quantize and tfl.dequantize ops for this `value`.
  void QuantizeValue(OpBuilder builder, Value value,
                     TFL::QuantParams quant_params);

  // If the value hasn't been quantized, the function adds it to `values`.
  void AddToWorkListIfUnquantized(Value value, std::vector<Value> *values);

  // Converts the default min/max to the default quantization parameters.
  TFL::QuantParams GetDefaultQuantParams(Builder builder);

  // Gets the quantization parameters for the bias of an operation by using the
  // quantization parameters from the non-bias operands.
  TFL::QuantParams GetQuantParamsForBias(Operation *op, int bias,
                                         const std::vector<int> &non_biases,
                                         TFL::AccumulatorScaleFunc func);

  double default_min_;
  double default_max_;
  TFL::QuantParams default_quant_params_;
};
}  // namespace

void DefaultQuantParamsPass::runOnFunction() {
  FuncOp func = getFunction();
  OpBuilder builder(func);

  std::vector<Value> activation_values;
  std::vector<Value> bias_values;

  // First of all, collect all the values (block arguments and op results)
  // which are required to be quantized.
  for (auto arg : func.getBody().begin()->getArguments()) {
    if (UsedAsBias(arg)) {
      AddToWorkListIfUnquantized(arg, &bias_values);
    } else {
      AddToWorkListIfUnquantized(arg, &activation_values);
    }
  }

  func.walk([&](Operation *op) {
    if (op->isKnownTerminator() ||
        op->hasTrait<OpTrait::quant::NoQuantizableResult>())
      return;

    for (auto res : op->getResults()) {
      if (UsedAsBias(res)) {
        AddToWorkListIfUnquantized(res, &bias_values);
      } else {
        AddToWorkListIfUnquantized(res, &activation_values);
      }
    }
  });

  // Apply the default quantization parameters for these activation values.
  TFL::QuantParams default_params = GetDefaultQuantParams(builder);
  for (Value value : activation_values) {
    QuantizeValue(builder, value, default_params);
  }

  // Since all the non-bias operands have quantization parameters now, we
  // should be able to propagate them to the bias operand.
  for (Value bias : bias_values) {
    Operation *op = *bias.user_begin();
    auto spec = TFL::GetOpQuantSpec(op);
    for (auto &it : spec->biases_params) {
      TFL::QuantParams bias_params = GetQuantParamsForBias(
          op, it.first, it.second.first, it.second.second);
      if (!bias_params) continue;
      QuantizeValue(builder, bias, bias_params);
    }
  }
}

void DefaultQuantParamsPass::AddToWorkListIfUnquantized(
    Value value, std::vector<Value> *values) {
  // If the result isn't a float type, this result is an integer tensor and
  // doesn't require quantization.
  auto tensor_type = value.getType().dyn_cast<TensorType>();
  if (!tensor_type) {
    // There are none type values.
    return;
  }
  if (!tensor_type.getElementType().isF32()) return;

  // If the result is consumed by a quantize op, it has been quantized.
  if (value.hasOneUse() &&
      llvm::isa<TFL::QuantizeOp>(*value.getUsers().begin()))
    return;

  // Add this result to the list to apply the default value.
  values->push_back(value);
}

void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value,
                                           TFL::QuantParams quant_params) {
  Type expressed_type = value.getType();
  Type new_type = quant_params.castFromExpressedType(expressed_type);
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  Block &block = value.getParentRegion()->front();
  Operation *op = value.getDefiningOp();
  if (op) {
    builder.setInsertionPoint(&block, ++Block::iterator(op));
  } else {
    builder.setInsertionPointToStart(&block);
  }
  TypeAttr type_attr = TypeAttr::get(new_type);
  auto quantize = builder.create<TFL::QuantizeOp>(value.getLoc(), new_type,
                                                  value, type_attr);
  auto dequantize = builder.create<TFL::DequantizeOp>(
      value.getLoc(), expressed_type, quantize.output());
  value.replaceAllUsesWith(dequantize);

  // `quantize` is using `dequantize` now, so we should set its operand to
  // `value`.
  quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}

TFL::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
    Operation *op, int bias, const std::vector<int> &non_biases,
    TFL::AccumulatorScaleFunc func) {
  std::vector<quant::QuantizedType> non_bias_types;
  non_bias_types.reserve(non_biases.size());
  for (int non_bias : non_biases) {
    Operation *non_bias_define = op->getOperand(non_bias).getDefiningOp();
    if (auto dequant = llvm::dyn_cast<TFL::DequantizeOp>(non_bias_define)) {
      auto non_bias_type = dequant.input().getType().cast<TensorType>();
      auto non_bias_ele_type =
          non_bias_type.getElementType().cast<quant::QuantizedType>();
      non_bias_types.push_back(non_bias_ele_type);
    } else {
      // The non-bias hasn't been quantized, let's skip this bias.
      break;
    }
  }
  // The non-bias hasn't been quantized, let's skip this bias.
  if (non_bias_types.size() != non_biases.size()) return {};

  return func(non_bias_types);
}

TFL::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
    Builder builder) {
  if (!default_quant_params_) {
    default_quant_params_ = quant::fakeQuantAttrsToType(
        builder.getUnknownLoc(),
        /*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
        builder.getF32Type());
  }
  return default_quant_params_;
}

// Creates an instance of the default quant parameters pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
    double default_min, double default_max) {
  return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
}

// Registers this pass with default values, only for tests.
static PassRegistration<DefaultQuantParamsPass> pass(
    "tfl-default-quant",
    "Apply quantization with default quantization parameter", [] {
      return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
                                          /*default_max=*/1.0);
    });

} // namespace TFL
} // namespace mlir
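
For reference, the factory above can also be dropped straight into a pipeline instead of going through the tfl-default-quant registration; a minimal sketch, assuming the TFL headers are on the include path:

#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

// Mirrors the test registration: default quant params covering [-1, 1].
void AddDefaultQuantPipeline(mlir::OpPassManager* pm) {
  pm->addPass(mlir::TFL::CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
                                                      /*default_max=*/1.0));
}
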
@ -21,26 +21,26 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/LoopAnalysis.h"  // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Block.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
#include "mlir/Support/Functional.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Block.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -205,7 +205,7 @@ struct OphintCompositeOp {
      Operation* current_identity_op = operand.ops.begin()->second;
      Value input = current_identity_op->getOperand(0);
      RankedTensorType input_type =
          input->getType().cast<RankedTensorType>();
          input.getType().cast<RankedTensorType>();
      // The Reshape will be {1, (original_shape)}
      SmallVector<int64_t, 4> reshape_op_shape;
      reshape_op_shape.push_back(1);
@ -242,13 +242,13 @@ struct OphintCompositeOp {
      }
      // Find the first op that consumes the last value of the aggregated
      // inputs.
      Operation* first_use = *(packed_input_consumers.back()->user_begin());
      Operation* first_use = *(packed_input_consumers.back().user_begin());
      // The pack reshape will be {N, (original_shape)}
      SmallVector<int64_t, 4> pack_shape;
      pack_shape.push_back(pack_input_operands.size());
      RankedTensorType type = operand.ops.at(0)
                                  ->getResult(0)
                                  ->getType()
                                  .getType()
                                  .cast<RankedTensorType>();
      for (const auto& dim : type.getShape()) {
        pack_shape.push_back(dim);
@ -290,7 +290,7 @@ struct OphintCompositeOp {
      const int output_numer = operand.ops.size();
      Value first_output = operand.ops.at(0)->getOperand(0);
      RankedTensorType first_output_type =
          first_output->getType().cast<RankedTensorType>();
          first_output.getType().cast<RankedTensorType>();
      // The aggregated output shape will be {N, original_shape}.
      SmallVector<int64_t, 4> shape;
      shape.push_back(output_numer);
@ -302,10 +302,10 @@ struct OphintCompositeOp {
    } else if (operand.aggregation == kStrategyLast) {
      Value last_output =
          operand.ops.at(operand.ops.size() - 1)->getOperand(0);
      aggregated_output_types[kv.first] = last_output->getType();
      aggregated_output_types[kv.first] = last_output.getType();
    } else {
      Value first_output = operand.ops.at(0)->getOperand(0);
      aggregated_output_types[kv.first] = first_output->getType();
      aggregated_output_types[kv.first] = first_output.getType();
    }
  }
  return aggregated_output_types;
@ -329,7 +329,7 @@ struct OphintCompositeOp {
    Operation* first_output = operand.ops.at(0);
    Location insert_loc = first_output->getLoc();
    SmallVector<Type, 4> unpack_output_types(
        output_number, first_output->getOperand(0)->getType());
        output_number, first_output->getOperand(0).getType());

    builder->setInsertionPoint(first_output);
    Operation* unpack_op = builder->create<TFL::UnpackOp>(
@ -404,7 +404,7 @@ void PreprocessTopoSortGraph(
    // should only count as one.
    llvm::DenseSet<Operation*> input_ops;
    for (int i = 0; i < op.getNumOperands(); ++i) {
      Operation* input_op = op.getOperand(i)->getDefiningOp();
      Operation* input_op = op.getOperand(i).getDefiningOp();
      if (input_op) input_ops.insert(input_op);
    }
    if (input_ops.empty()) {
@ -515,7 +515,7 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
  SmallVector<int, 4> input_indexes;
  for (const auto& kv : inputs) {
    Value input = kv.second;
    input_types.push_back(input->getType());
    input_types.push_back(input.getType());
    input_values.push_back(input);
    input_indexes.push_back(kv.first);
  }
@ -589,7 +589,7 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
  std::queue<Operation*> ops_queue;
  for (auto& input_op : input_ops) {
    for (Value value : input_op->getOperands()) {
      Operation* op = value->getDefiningOp();
      Operation* op = value.getDefiningOp();
      if (op != nullptr) ops_queue.push(op);
    }
  }
@ -599,7 +599,7 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
    ops_queue.pop();
    reachable_ops.insert(current_op);
    for (Value value : current_op->getOperands()) {
      Operation* upstream_op = value->getDefiningOp();
      Operation* upstream_op = value.getDefiningOp();
      // Not visited, put it into the queue.
      if (upstream_op != nullptr &&
          !llvm::is_contained(reachable_ops, upstream_op)) {
@ -642,7 +642,7 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
      aggregated_inputs, aggregated_output_types, builder, module_op);

  for (const auto& kv : aggregated_inputs) {
    Operation* op = kv.second->getDefiningOp();
    Operation* op = kv.second.getDefiningOp();
    if (op == nullptr) return failure();
    op->moveBefore(fused_op);
  }
@ -15,23 +15,23 @@ limitations under the License.

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Block.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Block.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
@ -103,7 +103,7 @@ LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op,
  Value hidden_state = call_op.getOperand(4);

  // Build Output.
  auto output_type = call_op.getResult(0)->getType();
  auto output_type = call_op.getResult(0).getType();

  // Currently, ophinted RNN only supports time_major = True.
  const bool time_major = true;
@ -170,11 +170,11 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
    for (int i = 0; i < call_op.getNumResults() - 1; ++i) {
      // This one should not be used.
      Value unused_output = call_op.getResult(i);
      if (!unused_output->use_empty()) return failure();
      if (!unused_output.use_empty()) return failure();
    }
  }
  output_types.push_back(
      call_op.getResult(call_op.getNumResults() - 1)->getType());
      call_op.getResult(call_op.getNumResults() - 1).getType());

  // Prepare attributes.
  SmallVector<NamedAttribute, 4> attributes;
@ -207,10 +207,10 @@ LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name,
        composite_func_op, call_op, builder, &fused_op);
    if (failed(build_fused_op_result)) return build_fused_op_result;
    Value call_output = call_op.getResult(call_op.getNumResults() - 1);
    if (call_output->getType() != fused_op->getResult(0)->getType()) {
    if (call_output.getType() != fused_op->getResult(0).getType()) {
      return failure();
    }
    call_output->replaceAllUsesWith(fused_op->getResult(0));
    call_output.replaceAllUsesWith(fused_op->getResult(0));
  } else {  // If we support more fused ops, we should add the conversion here.
    return failure();
  }
@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
// Use the tensor type information from $0 and convert min $1, max $2 and
// numBits $3 and narrowRange $4 to a QuantizedType.
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
    "GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
    "GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;

// Converts an integer attribute $0 to 32-bit with builder.
def convertIntAttrTo32Bit : NativeCodeCall<
@ -50,7 +50,7 @@ def ExtractSingleElementAsInteger : NativeCodeCall<
    "ExtractSingleElementAsInteger($_self.cast<ElementsAttr>())">;

// Checks whether the given operation has static shapes and same shapes of all inputs.
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0->getDefiningOp())">;
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;

@ -28,15 +28,15 @@ limitations under the License.
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
@ -72,8 +72,8 @@ bool HasSameStaticShapes(Operation* op) {
|
||||
int index = 0;
|
||||
ArrayRef<int64_t> shape;
|
||||
for (Value value : values) {
|
||||
auto shaped_type = value->getType().dyn_cast<ShapedType>();
|
||||
if (!shaped_type && !shaped_type.hasStaticShape()) {
|
||||
auto shaped_type = value.getType().dyn_cast<ShapedType>();
|
||||
if (!shaped_type || !shaped_type.hasStaticShape()) {
|
||||
return false;
|
||||
}
|
||||
if (index == 0) {
@ -122,7 +122,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
  auto tf_concat_op = cast<TF::ConcatOp>(op);

  auto values = tf_concat_op.values();
  auto output_type = tf_concat_op.output()->getType();
  auto output_type = tf_concat_op.output().getType();
  // Extract axis attribute from constant concat_dims tensor
  ElementsAttr axis;
  if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
@ -141,7 +141,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
  auto tf_concat_op = cast<TF::ConcatV2Op>(op);

  auto values = tf_concat_op.values();
  auto output_type = tf_concat_op.output()->getType();
  auto output_type = tf_concat_op.output().getType();
  // Extract axis attribute from constant axis tensor
  ElementsAttr axis;
  if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis)))
@ -167,7 +167,7 @@ PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
  if (tf_matmul_op.transpose_a()) return matchFailure();
  if (!tf_matmul_op.transpose_b()) return matchFailure();

  Type output_type = tf_matmul_op.getResult()->getType();
  Type output_type = tf_matmul_op.getResult().getType();
  // TODO(jpienaar): Follow up post shuffle discussion.
  auto no_input = rewriter.create<ConstantOp>(
      op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
@ -184,7 +184,7 @@ PatternMatchResult ConvertTFPackOp::matchAndRewrite(
  auto tf_pack_op = cast<TF::PackOp>(op);

  SmallVector<Value, 4> values(tf_pack_op.values());
  auto output_type = tf_pack_op.output()->getType();
  auto output_type = tf_pack_op.output().getType();
  auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
  // Axis can be negative.
  auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue());
@ -201,7 +201,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
  auto input = tf_reshape_op.tensor();
  auto shape = tf_reshape_op.shape();

  ShapedType shape_type = shape->getType().cast<ShapedType>();
  ShapedType shape_type = shape.getType().cast<ShapedType>();
  // The tfl reshape's #2 operand needs to be i32 tensor type, so we have to cast.
  if (!shape_type.getElementType().isInteger(32)) {
    auto new_shape = shape_type.getShape();
@ -213,7 +213,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
                rewriter.getBoolAttr(false))
            .y();
  }
  rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output()->getType(),
  rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
                                         input, shape);
  return matchSuccess();
}
@ -222,7 +222,7 @@ PatternMatchResult ConvertTFSplitOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_split_op = cast<TF::SplitOp>(op);

  auto output_types = functional::map([](Value v) { return v->getType(); },
  auto output_types = functional::map([](Value v) { return v.getType(); },
                                      tf_split_op.output());
  // Number of splits cannot be negative.
  auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());
@ -237,7 +237,7 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_splitv_op = cast<TF::SplitVOp>(op);

  auto output_types = functional::map([](Value v) { return v->getType(); },
  auto output_types = functional::map([](Value v) { return v.getType(); },
                                      tf_splitv_op.output());
  // Number of splits cannot be negative.
  auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());
@ -254,7 +254,7 @@ Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
  DenseIntElementsAttr dense_elem_attr;
  SmallVector<int32_t, 8> padded_val;

  auto ranked_attr_type = attribute->getType().dyn_cast<RankedTensorType>();
  auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
  if (!ranked_attr_type ||
      !matchPattern(attribute, m_Constant(&dense_elem_attr))) {
    // If the input attribute is neither ranked type nor constant, we
@ -280,14 +280,14 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
  auto ranked_input_type =
      tf_strided_slice_op.input()->getType().dyn_cast<RankedTensorType>();
      tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
  if (!ranked_input_type) {
    // If input is not a ranked tensor, we can't deduce the padding dimensions
    // from it, so we just do a plain conversion here.
    rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
        op, tf_strided_slice_op.output()->getType(),
        tf_strided_slice_op.input(), tf_strided_slice_op.begin(),
        tf_strided_slice_op.end(), tf_strided_slice_op.strides(),
        op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
        tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
        tf_strided_slice_op.strides(),
        rewriter.getI32IntegerAttr(
            tf_strided_slice_op.begin_mask().getSExtValue()),
        rewriter.getI32IntegerAttr(
@ -318,7 +318,7 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
  Value padded_strides = PadStridedSliceAttributeArray(
      op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
  rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
      op, tf_strided_slice_op.output()->getType(), tf_strided_slice_op.input(),
      op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
      padded_begin, padded_end, padded_strides,
      rewriter.getI32IntegerAttr(begin_mask),
      rewriter.getI32IntegerAttr(end_mask),
@ -336,7 +336,7 @@ PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
  auto tf_unpack_op = cast<TF::UnpackOp>(op);

  auto input = tf_unpack_op.value();
  auto output_types = functional::map([](Value v) { return v->getType(); },
  auto output_types = functional::map([](Value v) { return v.getType(); },
                                      tf_unpack_op.output());
  auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
  // Axis can be negative.
@ -360,7 +360,7 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
  if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;

  auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
  auto output_type = tf_matrix_diag_v2_or_v3_op.output()->getType();
  auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType();

  // Extract k constant tensor and check value = 0.
  ElementsAttr k;
@ -500,7 +500,7 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(

  auto status_or_const_op = CreateConstOpWithSingleValue(
      &rewriter, op->getLoc(),
      tf_reciprocal_op.x()->getType().cast<ShapedType>(), 1);
      tf_reciprocal_op.x().getType().cast<ShapedType>(), 1);
  if (!status_or_const_op.ok()) {
    return matchFailure();
  }
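
The legalizations above repeat one idiom under the new API: collect each result's type through a lambda over value-semantic `Value`s, as in `functional::map([](Value v) { return v.getType(); }, op.output())`. A self-contained analogue with std::transform (illustrative types, not MLIR):

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

// Illustrative stand-in for mlir::Value after the migration.
struct Value {
  std::string type;
  std::string getType() const { return type; }
};

// Analogue of functional::map([](Value v) { return v.getType(); }, outputs).
std::vector<std::string> collectResultTypes(const std::vector<Value>& outputs) {
  std::vector<std::string> types;
  types.reserve(outputs.size());
  std::transform(outputs.begin(), outputs.end(), std::back_inserter(types),
                 [](const Value& v) { return v.getType(); });
  return types;
}

int main() {
  std::vector<Value> outputs = {{"tensor<4xf32>"}, {"tensor<4xf32>"}};
  return collectResultTypes(outputs).size() == 2 ? 0 : 1;
}
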

@ -19,11 +19,11 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -71,7 +71,7 @@ struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {

void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
  Type expressed_type =
      lstm.input()->getType().cast<ShapedType>().getElementType();
      lstm.input().getType().cast<ShapedType>().getElementType();
  Type int8_storage_type = builder->getIntegerType(8);
  Type int16_storage_type = builder->getIntegerType(16);
  auto flag = quant::QuantizationFlags::FlagValue::Signed;
@ -88,8 +88,8 @@ void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
  auto any_int16 = quant::AnyQuantizedType::get(
      flag, int16_storage_type, expressed_type, int16_min, int16_max);

  int8 = any_int8.castFromExpressedType(lstm.input()->getType());
  int16 = any_int16.castFromExpressedType(lstm.input()->getType());
  int8 = any_int8.castFromExpressedType(lstm.input().getType());
  int16 = any_int16.castFromExpressedType(lstm.input().getType());
}

Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in,

@ -29,28 +29,28 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h"  // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Block.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
#include "mlir/Support/Functional.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "mlir/Transforms/DialectConversion.h"  // TF:local_config_mlir
#include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Block.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -196,13 +196,13 @@ struct ConvertTensorListSetItem : public ConversionPattern {
    // Calculate `index` + 1, which is used to generate the start position for
    // the second slice op.
    auto suffix_start =
        rewriter.create<TF::AddOp>(loc, index->getType(), index,
        rewriter.create<TF::AddOp>(loc, index.getType(), index,
                                   CreateI32SplatConst(loc, &rewriter, {}, 1));

    auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
        loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
    // Create two slice ops.
    Type element_type = input->getType().cast<TensorType>().getElementType();
    Type element_type = input.getType().cast<TensorType>().getElementType();
    UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
    Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
    TF::SliceOp slice1 =
@ -225,7 +225,7 @@ struct ConvertTensorListSetItem : public ConversionPattern {

    // Concatenate three parts together to generate the final result.
    rewriter.replaceOpWithNewOp<TF::ConcatOp>(
        op, input->getType(), scalar_zero,
        op, input.getType(), scalar_zero,
        ArrayRef<Value>({slice1, expanded_item, slice2}));
    return matchSuccess();
  }
@ -264,7 +264,7 @@ struct ConvertTensorListInitOp : public ConversionPattern {
    }

    Value element_shape = operands[0];
    Type shape_dtype = getElementTypeOrSelf(element_shape->getType());
    Type shape_dtype = getElementTypeOrSelf(element_shape.getType());

    DenseIntElementsAttr dense_elem_attr;
    if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
@ -297,11 +297,10 @@ struct ConvertTensorListInitOp : public ConversionPattern {
        new_element_shape_values.push_back(dim_value);
      }

      auto attr =
          DenseIntElementsAttr::get(element_shape->getType().cast<ShapedType>(),
                                    new_element_shape_values);
      auto attr = DenseIntElementsAttr::get(
          element_shape.getType().cast<ShapedType>(), new_element_shape_values);
      auto new_element_shape = rewriter.create<ConstantOp>(
          op.getLoc(), element_shape->getType(), attr);
          op.getLoc(), element_shape.getType(), attr);
      element_shape = new_element_shape;
    }

@ -355,7 +354,7 @@ struct ConvertTensorListReserve
  Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
                       PatternRewriter *rewriter) const override {
    Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
    Type shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
    Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
    Value num_elements = operands[1];
    return rewriter->create<TF::ExpandDimsOp>(
        op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
@ -392,14 +391,14 @@ struct ConvertTensorListPushBack : public ConversionPattern {
    // Expand the shape of the item so that it will have the same rank as the
    // input tensor and is compatible with the Concat op.
    Type expanded_item_type =
        PrependLeadingDimIfRanked(1, item->getType(), &rewriter);
        PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
    Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
    auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
        op->getLoc(), expanded_item_type, item, scalar_zero);

    Type elem_type = getElementTypeOrSelf(item);
    auto handle_dtype =
        getElementTypeOrSelf(push_back_op.output_handle()->getType())
        getElementTypeOrSelf(push_back_op.output_handle().getType())
            .cast<TF::VariantType>();
    Type result_type =
        GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
@ -446,7 +445,7 @@ struct ConvertTensorListResize : public ConversionPattern {
    // Infer result type of this op based on TF's shape inference result.
    Type elem_type = getElementTypeOrSelf(input_handle);
    auto handle_dtype =
        getElementTypeOrSelf(resize_op.output_handle()->getType())
        getElementTypeOrSelf(resize_op.output_handle().getType())
            .cast<TF::VariantType>();
    Type result_type =
        GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
@ -463,8 +462,8 @@ struct ConvertTensorListResize : public ConversionPattern {
    auto input_shape = rewriter.create<TF::ShapeOp>(
        loc, RankedTensorType::get({-1}, shape_dtype), input_handle);

    Type branch_args_type[] = {input_handle->getType(), input_shape.getType(),
                               size_diff.getType(), size->getType()};
    Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
                               size_diff.getType(), size.getType()};
    Type branch_result_type[] = {result_type};
    auto func_type = FunctionType::get(branch_args_type, branch_result_type,
                                       rewriter.getContext());
@ -524,7 +523,7 @@ struct ConvertTensorListResize : public ConversionPattern {
        loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
        slice_size);
    auto extended_part = rewriter->create<TF::TensorListReserveOp>(
        loc, resize_op.output_handle()->getType(), elem_shape, size_diff);
        loc, resize_op.output_handle().getType(), elem_shape, size_diff);
    // `ConcatOp` expects non-variant-typed input. Insert a
    // `TensorListStackOp` here to convert type from variant to non-variant.
    // Note that we are using the same `result_type` for both the
@ -627,7 +626,7 @@ struct ConvertTensorListStack : public ConversionPattern {
    // trivial Reshape op (that doesn't actually change the input's shape) and
    // also populate the shape info to the op result. The shape of the
    // tensorlist is inferred from `num_elements` and `element_shape`.
    auto ranked_type = element_shape->getType().dyn_cast<RankedTensorType>();
    auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
    DenseIntElementsAttr dense_elem_attr;
    if ((ranked_type && ranked_type.getRank() == 0) ||
        !matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
@ -659,7 +658,7 @@ struct ConvertIdentity : public ConversionPattern {
      ConversionPatternRewriter &rewriter) const override {
    auto op = llvm::cast<TF::IdentityOp>(operation);
    Value input = operands[0];
    rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input->getType(), operands,
    rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
                                                op.getAttrs());
    return matchSuccess();
  }
@ -687,7 +686,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
    Type arg_type = func_type.getInput(i);
    if (getElementTypeOrSelf(arg_type).isa<TF::VariantType>()) {
      arg_type = UnrankedTensorType::get(
          getElementTypeOrSelf(op.getOperand(i)->getType()));
          getElementTypeOrSelf(op.getOperand(i).getType()));
    }
    updated_argument_types.push_back(arg_type);
  }
@ -703,7 +702,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
      // from the corresponding input operand. This is correct because while
      // body's inputs and results have the same type.
      result_type = UnrankedTensorType::get(
          getElementTypeOrSelf(op.getOperand(i)->getType()));
          getElementTypeOrSelf(op.getOperand(i).getType()));
    }
    updated_result_types.push_back(result_type);
  }
@ -717,7 +716,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
    // Change the argument type for the first block.
    Block &body_first_bb = func.front();
    for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
      body_first_bb.getArgument(i)->setType(updated_argument_types[i]);
      body_first_bb.getArgument(i).setType(updated_argument_types[i]);
    }
  }
  return success();
@ -735,12 +734,12 @@ struct ConvertWhile : public ConversionPattern {
    llvm::SmallVector<Type, 8> result_types;
    result_types.reserve(op.getNumOperands());
    for (int i = 0, e = operands.size(); i != e; ++i) {
      Type result_ty = op.getResult(i)->getType();
      Type result_ty = op.getResult(i).getType();

      // If we notice the result type is a DT_VARIANT, we change the
      // corresponding result type to unranked tensor type.
      if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
        Type element_ty = getElementTypeOrSelf(operands[i]->getType());
        Type element_ty = getElementTypeOrSelf(operands[i].getType());
        result_ty = UnrankedTensorType::get(element_ty);
      }
      result_types.push_back(result_ty);

@ -16,6 +16,7 @@ limitations under the License.
// This transformation pass takes operations in TensorFlowLite dialect and
// optimizes them to resulting operations in TensorFlowLite dialect.

#include <algorithm>
#include <climits>
#include <cstdint>
#include <functional>
@ -30,14 +31,14 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Support/Functional.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
@ -51,15 +52,15 @@ namespace TFL {
namespace {

bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
  if (sq_op->getType().cast<ShapedType>().getRank() - 1 ==
  if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
          *axis.getValues<int>().begin() ||
      *axis.getValues<int>().begin() == -1) {
    return true;
  }
  if (sq_op->getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
  if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
    return false;
  }
  auto shape = sq_op->getType().cast<ShapedType>();
  auto shape = sq_op.getType().cast<ShapedType>();
  SmallVector<int, 4> elems{axis.getValues<int>().begin(),
                            axis.getValues<int>().end()};
  for (int i = 0; i < shape.getRank(); ++i) {
@ -80,6 +81,18 @@ bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
  return OpTrait::util::getBroadcastedType(a, b) != Type();
}

// Returns whether `type1` dimensions are the same as the ending dimensions
// of `type2`. This is more restricted than broadcastable.
bool IsTailOfShape(Type type1, Type type2) {
  auto tail_type = type1.dyn_cast<ShapedType>();
  auto full_type = type2.dyn_cast<ShapedType>();
  if (!tail_type || !full_type || tail_type.getRank() > full_type.getRank())
    return false;
  auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend();
  auto i2 = full_type.getShape().rbegin();
  return std::equal(i1, e1, i2);
}
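
IsTailOfShape compares the two shapes from the back, so `[40, 40]` is a tail of `[1, 40, 40]` but `[1, 40]` is not. A self-contained restatement over plain shape vectors:

#include <algorithm>
#include <cstdint>
#include <vector>

// Same logic as IsTailOfShape above, on plain vectors: type1's dimensions
// must equal the ending dimensions of type2, checked back to front.
bool IsTailOfShape(const std::vector<int64_t>& tail,
                   const std::vector<int64_t>& full) {
  if (tail.size() > full.size()) return false;
  return std::equal(tail.rbegin(), tail.rend(), full.rbegin());
}

int main() {
  bool ok = IsTailOfShape({40, 40}, {1, 40, 40}) &&  // ending dims match
            !IsTailOfShape({1, 40}, {1, 40, 40});    // ending dims are [40, 40]
  return ok ? 0 : 1;
}
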

bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
                                bool is_depthwise) {
  // Make sure the val tensor has shape where all dimensions are 1 except
@ -143,7 +156,7 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
// Returns shape of a ranked tensor.
// Precondition: output_val is a ranked tensor.
DenseElementsAttr GetShape(Value output_val) {
  auto output_type = output_val->getType().cast<RankedTensorType>();
  auto output_type = output_val.getType().cast<RankedTensorType>();
  auto shape_vector = output_type.getShape();
  std::vector<int32_t> shape(shape_vector.size());
  for (int i = 0; i < shape_vector.size(); ++i) {
@ -152,7 +165,7 @@ DenseElementsAttr GetShape(Value output_val) {
  return mlir::DenseElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int>(shape.size())},
          mlir::IntegerType::get(32, output_val->getContext())),
          mlir::IntegerType::get(32, output_val.getContext())),
      llvm::makeArrayRef(shape));
}

@ -173,13 +186,13 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {

    // Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs()->getDefiningOp());
        dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
    if (!fc_op) return matchFailure();

    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr bias_value;
    const bool is_none_bias = bias->getType().isa<NoneType>();
    const bool is_none_bias = bias.getType().isa<NoneType>();
    if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
      return matchFailure();
    if (fc_op.fused_activation_function() != "NONE") return matchFailure();
@ -213,7 +226,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {

  PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op,
                                     PatternRewriter &rewriter) const override {
    Operation *input = relu_op.getOperand()->getDefiningOp();
    Operation *input = relu_op.getOperand().getDefiningOp();
    if (!isa_and_nonnull<FullyConnectedOp>(input)) return matchFailure();
    auto fully_connected_op = cast<FullyConnectedOp>(input);
    if (fully_connected_op.fused_activation_function() != "NONE")
@ -247,13 +260,13 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {

    // Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
    if (!fc_op) return matchFailure();
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr cst_tmp;
    if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
    if (!bias->getType().isa<NoneType>() &&
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&cst_tmp)))
      return matchFailure();
    if (fc_op.fused_activation_function().equals("None")) return matchFailure();
@ -262,7 +275,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
    // filter input. We only support broadcasting the operand along the depth
    // dimension, when the operand's depth is 1.
    Value new_const_val = constant_val;
    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
      auto original_shape = cst.getType().getShape();
      llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
                                                     original_shape.end());
@ -270,7 +283,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
      auto new_cst = cst.reshape(RankedTensorType::get(
          normalized_shape, cst.getType().getElementType()));
      Type new_type = new_cst.getType();
      if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
      if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
        return matchFailure();
      }
      auto new_op =
@ -285,7 +298,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
    auto new_filter =
        rewriter.create<TF::MulOp>(loc, filter, new_const_val).z();
    // If bias isn't None, it needs to be multiplied as well.
    if (!bias->getType().isa<NoneType>()) {
    if (!bias.getType().isa<NoneType>()) {
      bias = rewriter.create<TF::MulOp>(loc, bias, constant_val).z();
    }

@ -311,7 +324,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
  PatternMatchResult matchAndRewrite(AffineOpType fc_op,
                                     PatternRewriter &rewriter) const override {
    // Binary op.
    Operation *binary_op = fc_op.input()->getDefiningOp();
    Operation *binary_op = fc_op.input().getDefiningOp();
    if (!binary_op || binary_op->getNumOperands() != 2)
      return this->matchFailure();
    // We only handle the case where the RHS is a scalar.
@ -330,15 +343,15 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
    DenseFPElementsAttr filter_cst, bias_cst;
    if (!matchPattern(filter, m_Constant(&filter_cst))) {
      // The filter may be quantized, so we should set it to the real constant.
      auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter->getDefiningOp());
      auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
      if (!dq) return this->matchFailure();
      auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input()->getDefiningOp());
      auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
      if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
        return this->matchFailure();
      }
      filter = q.input();
    }
    if (!bias->getType().isa<NoneType>() &&
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&bias_cst)))
      return this->matchFailure();
    ShapedType filter_type = filter_cst.getType();
@ -362,7 +375,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
    // The new bias should be a 1-D tensor with length equal to the bias
    // dimension of the weight.
    SmallVector<APFloat, 4> new_bias_values;
    if (bias->getType().isa<NoneType>()) {  // none bias, a list of zeros
    if (bias.getType().isa<NoneType>()) {  // none bias, a list of zeros
      new_bias_values.resize(bias_size, APFloat(0.0));
    } else if (bias_cst.getNumElements() == 1) {  // scalar bias, broadcast it
      new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
@ -401,12 +414,12 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
    // We recreate the constant op in case it is shared by the other ops. This
    // might increase the model size.
    auto new_filter_op = rewriter.create<ConstOp>(
        fc_op.getLoc(), filter->getType(), new_filter);
        fc_op.getLoc(), filter.getType(), new_filter);
    fc_op.setOperand(0, binary_op->getOperand(0));
    if (fc_op.filter() != filter) {
      // This filter goes through quantize and dequantize ops. Then we just
      // need to update the weight to the quantize op.
      filter->replaceAllUsesWith(new_filter_op);
      filter.replaceAllUsesWith(new_filter_op);
    } else {
      // This filter doesn't go through quantize and dequantize ops. Then
      // we update the weight of the affine op directly.

@ -17,15 +17,15 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {

@ -55,7 +55,7 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
  defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;

// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0->hasOneUse()">>;
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;

// If we see a binary op (add, sub) op adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
@ -161,7 +161,7 @@ def EqualOperands : Constraint<CPred<"$0 == $1">>;

// Checks if the operand has rank == n
class OperandHasRank<int n> : Constraint<
    CPred<"$0->getType().cast<ShapedType>().getRank() == " # n>>;
    CPred<"$0.getType().cast<ShapedType>().getRank() == " # n>>;

// Matching HardSwish
def : Pat<
@ -255,8 +255,16 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
  in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;

//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
def AreBroadcastableTypes : Constraint<CPred<
    "TFL::IsBroadcastableElementsAttrAndType($0->getType(), $1->getType())">>;
    "TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;

def IsTailOfShape : Constraint<CPred<
    "TFL::IsTailOfShape($0->getType(), $1->getType())">>;

def HaveSameType : Constraint<CPred<"$0->getType(), $1->getType()">>;

// Pattern for skipping Tile if it is mainly for broadcasting and the
// Op is already supporting broadcasting.
@ -272,13 +280,72 @@ multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
      [(AreBroadcastableTypes $operand, $input)]>;
}

foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
  in defm : FuseTileBroadcastIntoFollowingBinary<BroadcastingOp>;
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                       [TFL_Relu6Op, TFL_AF_Relu6],
                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
              (BinaryOp $lhs, $rhs, actFnPair[1])>;
  }
}

foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
  defm : FuseTileBroadcastIntoFollowingBinary<BinaryOp>;

  // Instantiated FusedBinary patterns for the from-to pairs of ops.
  defm : FusedBinaryActivationFuncOpPat<BinaryOp>;

  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant and the shape is
  // the tail of the other operand and the intermediate result isn't used by
  // other ops.
  // $rhs is required to be the tail shape of $lhs, so after transformation the
  // shape of the binary op result is valid. For example, assume the shapes of
  // $input, $lhs and $rhs are [1600], [1,40,40] and [40x1]. After the
  // transformation, the shape of the binary op result is [40x1600], which
  // couldn't be reshaped to [1,40,40]. The `IsTailOfShape` constraint is added
  // to make sure $rhs is the tail shape of $lhs.
  def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
                      (ConstantOp:$rhs $a), TFL_AF_None),
            (TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
            // The broadcasting of "BinaryOp" only happens in the lower
            // dimensions, and the higher dimensions are the same.
            [(IsTailOfShape $rhs, $lhs),
             (HasOneUse $lhs),
             // The two operands of the binary op are broadcastable.
             (AreBroadcastableTypes $rhs, $input)]>;
}

foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
                    TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
                    TFL_GreaterEqualOp] in {
  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant and the shape is
  // the tail of the other operand and the intermediate result isn't used by
  // other ops.
  // $rhs is required to be the tail shape of $lhs, so after transformation the
  // shape of the binary op result is valid. For example, assume the shapes of
  // $input, $lhs and $rhs are [1600], [1,40,40] and [40x1]. After the
  // transformation, the shape of the binary op result is [40x1600], which
  // couldn't be reshaped to [1,40,40]. The `IsTailOfShape` constraint is added
  // to make sure $rhs is the tail shape of $lhs.
  def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
                      (ConstantOp:$rhs $a)),
            (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
            // The broadcasting of "BinaryOp" only happens in the lower
            // dimensions, and the higher dimensions are the same.
            [(IsTailOfShape $rhs, $lhs),
             (HasOneUse $lhs),
             // The two operands of the binary op are broadcastable.
             (AreBroadcastableTypes $rhs, $input)]>;
}

// Returns shape of a ranked tensor.
// If called without a ranked tensor it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;

// Convert squeeze to reshape
def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
          (TFL_ReshapeOp $input,
           (ConstantOp (GetShape $squeeze_op))),
@ -300,21 +367,6 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
          (TFL_Relu1Op $input),
          [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;

// Multi-pattern consisting of matching stand-alone op or op followed by relu.
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                       [TFL_Relu6Op, TFL_AF_Relu6],
                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
              (BinaryOp $lhs, $rhs, actFnPair[1])>;
  }
}

// Instantiated FusedBinary patterns for the from-to pairs of ops.
foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
                     TFL_MulOp, TFL_SubOp] in
  defm : FusedBinaryActivationFuncOpPat<BinaryOps>;

// The constant folding in this pass might produce constant in the tf dialect.
// This rule is to legalize these constants to the tfl dialect.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

@ -73,6 +73,10 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass();
std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass();

std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass();

// Creates an instance pass to add default quantization parameters.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
    double default_min, double default_max);
}  // namespace TFL

}  // namespace mlir

@ -16,8 +16,8 @@ limitations under the License.
// This transformation pass applies some clean up steps after quantization.

#include "llvm/Support/Casting.h"
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -71,29 +71,29 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {

    auto remove_quantize_op = [&](QuantizeOp quantize_op) {
      auto quantize_output = quantize_op.output();
      auto quantize_type = quantize_output->getType();
      auto quantize_type = quantize_output.getType();
      input_types.push_back(quantize_type);
      auto new_arg = bb.addArgument(quantize_type);
      quantize_output->replaceAllUsesWith(new_arg);
      quantize_output.replaceAllUsesWith(new_arg);
      quantize_op.erase();
      arg->dropAllUses();
      arg.dropAllUses();
      bb.eraseArgument(0);
    };

    // This is looking for a pattern: arg -> tfl.quantize
    if (arg->hasOneUse() && llvm::isa<QuantizeOp>(*arg->user_begin())) {
      auto quantize_op = llvm::cast<QuantizeOp>(*arg->user_begin());
    if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
      auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
      remove_quantize_op(quantize_op);
      continue;
    }

    // Make a copy of current argument and append it to the end of the list if
    // the pattern isn't found.
    Type arg_type = arg->getType();
    Type arg_type = arg.getType();
    input_types.push_back(arg_type);
    auto new_arg = bb.addArgument(arg_type);
    arg->replaceAllUsesWith(new_arg);
    arg->dropAllUses();
    arg.replaceAllUsesWith(new_arg);
    arg.dropAllUses();
    bb.eraseArgument(0);
  }
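
The argument loop above relies on a rotation invariant: every iteration appends the replacement argument at the back and erases index 0, so unprocessed arguments keep sliding to the front and the block's argument list is rebuilt in the original order. A self-contained analogue of that rotation:

#include <deque>
#include <string>

int main() {
  // Stand-ins for block arguments; each iteration mimics
  // bb.addArgument(arg_type) followed by bb.eraseArgument(0).
  std::deque<std::string> args = {"f32", "!quant.uniform", "f32"};
  const int n = static_cast<int>(args.size());
  for (int i = 0; i < n; ++i) {
    args.push_back(args.front());  // append the (possibly retyped) copy
    args.pop_front();              // erase the processed argument 0
  }
  return args.size() == 3 ? 0 : 1;  // same count, rebuilt in order
}
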

@ -103,15 +103,15 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
  output_types.reserve(num_return_operands);
  for (int i = 0; i != num_return_operands; ++i) {
    auto returned_value = terminator->getOperand(i);
    Operation* returned_op = returned_value->getDefiningOp();
    Operation* returned_op = returned_value.getDefiningOp();
    if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
      auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
      Value dequantized_result = dequantize_op.input();
      output_types.push_back(dequantized_result->getType());
      output_types.push_back(dequantized_result.getType());
      terminator->setOperand(i, dequantized_result);
      returned_op->erase();
    } else {
      output_types.push_back(returned_value->getType());
      output_types.push_back(returned_value.getType());
    }
  }
  auto new_func_type = builder.getFunctionType(input_types, output_types);

@ -22,19 +22,19 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
#include "mlir/IR/Location.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/Identifier.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"

@ -135,10 +135,10 @@ def : Pat<(TF_ReshapeOp
// Casts result type of $1 to a quantized type by using the quantization
// parameters from the type in $0.
class UpdateShapeWithAxis<int i> : NativeCodeCall<
    "CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1->getType(), " # i # ")">;
    "CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">;

class UsedBy<string op> : Constraint<
    CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0->getUsers().begin())">>;
    CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0.getUsers().begin())">>;

// When the op is passing-through, the output types of the quantized ops need
// to be updated as well. Since the quantize op manages its own type by the

@ -21,10 +21,10 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -153,7 +153,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
                                         params);
      auto dq_op =
          builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
      arg->replaceAllUsesWith(dq_op.output());
      arg.replaceAllUsesWith(dq_op.output());
      q_op.setOperand(arg);
    }
  }
@ -161,8 +161,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {

  for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
    BlockArgument arg = func.getArgument(i);
    auto* arg_block = arg->getOwner();
    add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
    auto* arg_block = arg.getOwner();
    add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
                    std::next(arg_block->begin(), i), arg, i);
  }

@ -38,17 +38,17 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h"  // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/UniformSupport.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Support/Functional.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -115,7 +115,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
                                     PatternRewriter &rewriter) const override {
    // We don't want to insert quantize/dequantize if the quantize op exists.
    auto res = tf_op.outputs();
    if (!res->hasOneUse() || isa<QuantizeOp>(*res->user_begin()))
    if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin()))
      return this->matchFailure();

    // Extract the min/max constant values from the operands. We also consider
@ -123,9 +123,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
    // constants and the tf.FakeQuantWithMinMaxVarsOp.
    Value min = tf_op.min(), max = tf_op.max();
    DenseFPElementsAttr min_value, max_value;
    if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min->getDefiningOp()))
    if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
      min = id1.input();
    if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max->getDefiningOp()))
    if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
      max = id2.input();
    if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
    if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
@ -133,7 +133,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
    int quant_dim = -1;
    if (PerAxis) {
      // This is a special case where the quant_dim is the last dimension.
      quant_dim = res->getType().template cast<ShapedType>().getRank() - 1;
      quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
    }
    // Use the min/max from the operands and the num_bits and narrow_range
    // attribute to create the quantization parameter for the new quantize op.
@ -155,7 +155,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
        tf_op.getLoc(), qtype.getValue(), value, qtype);
    auto dequantize = rewriter.create<TFL::DequantizeOp>(
        tf_op.getLoc(), res_type, quantize.output());
    value->replaceAllUsesWith(dequantize);
    value.replaceAllUsesWith(dequantize);
    quantize.getOperation()->replaceUsesOfWith(dequantize, value);
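
The two calls above are a replace-then-patch idiom: `value.replaceAllUsesWith(dequantize)` rewires every use of `value`, including the freshly created quantize op's own operand, which would make the quantize/dequantize pair feed itself; `replaceUsesOfWith(dequantize, value)` then restores `value` as the quantize op's input. A self-contained analogue of the fix-up:

#include <vector>

// Toy def-use edges: each Use records which producer feeds it.
struct Use {
  int producer;
};

int main() {
  const int kValue = 0, kDequantize = 1;
  std::vector<Use> downstream = {{kValue}, {kValue}};  // old consumers of value
  Use quantize_operand{kValue};                        // quantize also reads value

  // value.replaceAllUsesWith(dequantize): rewires *all* uses, even quantize's.
  for (Use& u : downstream) u.producer = kDequantize;
  quantize_operand.producer = kDequantize;  // accidental self-feeding cycle

  // quantize.getOperation()->replaceUsesOfWith(dequantize, value):
  quantize_operand.producer = kValue;  // patch the cycle back out

  return quantize_operand.producer == kValue ? 0 : 1;
}
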

    return this->matchSuccess();
@ -240,7 +240,7 @@ struct ConvertTFConvOp : public RewritePattern {
    // that we can extract info from the shape (e.g., for constructing bias
    // tensor, for setting depth_multiplier attribute, etc.).
    auto filter_type =
        tf_op.filter()->getType().template dyn_cast<RankedTensorType>();
        tf_op.filter().getType().template dyn_cast<RankedTensorType>();
    if (filter_type && filter_type.getRank() == 4)
      return matchSuccess(std::move(state));

@ -262,7 +262,7 @@ struct ConvertTFConvOp : public RewritePattern {

    // Get a splat zero tensor with the expected dimension for the bias tensor
    auto filter = tf_op.filter();
    auto filter_type = filter->getType().template cast<RankedTensorType>();
    auto filter_type = filter.getType().template cast<RankedTensorType>();
    auto elem_type = filter_type.getElementType();
    auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
        filter_type.getShape());
@ -323,7 +323,7 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
    auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);

    // Create tensor type for the transpose result.
    auto filter_type = filter->getType().cast<RankedTensorType>();
    auto filter_type = filter.getType().cast<RankedTensorType>();
    auto result_shape = functional::map(
        [filter_type](int64_t dim) { return filter_type.getDimSize(dim); },
        perm);
@ -356,7 +356,7 @@ class ConvertTFDepthwiseConv2dNative
    // have a corresponding 'depth_multiplier' attribute; the multiplier is the
    // fourth dimension in the 4-D filter tensor. We query the multiplier from
    // tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
    auto multiplier = filter->getType().cast<RankedTensorType>().getDimSize(3);
    auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);

    filter = legalizeFilter(rewriter, loc, filter);
    return rewriter.create<TFL::DepthwiseConv2DOp>(
@ -380,7 +380,7 @@ class ConvertTFDepthwiseConv2dNative
  /// RankedTensorType.
  Value legalizeFilter(PatternRewriter &rewriter, Location loc,
                       Value filter) const {
    auto filter_type = filter->getType().cast<RankedTensorType>();
    auto filter_type = filter.getType().cast<RankedTensorType>();
    auto filterShape = filter_type.getShape();
    SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
                                            filterShape[2] * filterShape[3]};
@ -432,11 +432,11 @@ struct ConvertTFStridedSlice : public RewritePattern {
    // Insert a new reshape op.
    Value original_input = strided_slice_op.input();
    RankedTensorType original_input_type =
        original_input->getType().cast<RankedTensorType>();
        original_input.getType().cast<RankedTensorType>();
    const ArrayRef<int64_t> &original_input_shape =
        original_input_type.getShape();
    RankedTensorType begin_type =
        strided_slice_op.begin()->getType().cast<RankedTensorType>();
        strided_slice_op.begin().getType().cast<RankedTensorType>();
    const int dim_size = begin_type.getShape()[0];
    SmallVector<int64_t, 4> new_shape;
    int mask = 1;

@ -19,17 +19,17 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Support/Functional.h"  // TF:local_config_mlir
#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
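The trailing `// TF:...` comments on these includes record which external Bazel repository serves each header. This patch retags every MLIR header from `@local_config_mlir` to `@llvm-project`, reflecting MLIR's move into the LLVM monorepo; the include paths themselves do not change. The hunk above and every remaining include hunk in this diff reduce to this one-line pattern (purely illustrative, since the full pairs are shown in the hunks):

-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Builders.h"  // TF:llvm-project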
@@ -18,24 +18,24 @@ limitations under the License.
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/Block.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Function.h"  // TF:local_config_mlir
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
-#include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "mlir/IR/Operation.h"  // TF:local_config_mlir
-#include "mlir/IR/OperationSupport.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
-#include "mlir/IR/Types.h"  // TF:local_config_mlir
-#include "mlir/IR/Value.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
-#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
-#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
-#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/Block.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Function.h"  // TF:llvm-project
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/Matchers.h"  // TF:llvm-project
+#include "mlir/IR/Module.h"  // TF:llvm-project
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
+#include "mlir/IR/Types.h"  // TF:llvm-project
+#include "mlir/IR/Value.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
+#include "mlir/Support/LLVM.h"  // TF:llvm-project
+#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"

@@ -83,7 +83,7 @@ LogicalResult DuplicateValueIfNeeded(Operation* op,
   // We can only clone the constant op at this point.
   // Since all ops have been legalized to tflite ops, so we only care about
   // ConstOp or QConstOp or mlir constant op/
-  Operation* input_op = operand->getDefiningOp();
+  Operation* input_op = operand.getDefiningOp();
   if (input_op == nullptr) return failure();

   Attribute attr;
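The failure() guard above exists because not every Value has a defining op: a block argument (for example a function parameter) is produced by no operation, and getDefiningOp() returns null for it. A minimal sketch of the idiom (standalone and illustrative, not code from this patch):

#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/Value.h"      // TF:llvm-project

// Illustrative: distinguish op results from block arguments.
static bool HasClonableProducer(mlir::Value operand) {
  if (mlir::Operation* def = operand.getDefiningOp()) {
    // `operand` is an op result; `def` can now be matched (e.g. as a
    // constant) and cloned next to its user if needed.
    return true;
  }
  // `operand` is a block argument; there is no producer to clone.
  return false;
}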
@@ -20,13 +20,13 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
-#include "mlir/IR/Location.h"  // TF:local_config_mlir
-#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
-#include "mlir/IR/SymbolTable.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Identifier.h"  // TF:llvm-project
+#include "mlir/IR/Location.h"  // TF:llvm-project
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"

 // The cmd line flag to specify the whitelist of functions. Rest are trimmed
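The trailing context line above introduces the pass's command-line whitelist, built on llvm/Support/CommandLine.h, which this hunk keeps in the include list. Such a flag is conventionally declared as an llvm::cl::list; the sketch below is a generic illustration, and the flag name is assumed rather than taken from this pass:

#include <string>

#include "llvm/Support/CommandLine.h"

// Illustrative flag: functions named here survive trimming; all others are
// removed from the module.
static llvm::cl::list<std::string> trim_funcs_whitelist(
    "illustrative-trim-funcs-whitelist",
    llvm::cl::desc("Functions to keep; the rest are trimmed"),
    llvm::cl::ZeroOrMore);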
@@ -24,17 +24,17 @@ limitations under the License.
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
-#include "mlir/Analysis/LoopAnalysis.h"  // TF:local_config_mlir
-#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:local_config_mlir
-#include "mlir/Dialect/QuantOps/UniformSupport.h"  // TF:local_config_mlir
-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/OpImplementation.h"  // TF:local_config_mlir
-#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
-#include "mlir/Support/Functional.h"  // TF:local_config_mlir
-#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
-#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
+#include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
+#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
+#include "mlir/Dialect/QuantOps/UniformSupport.h"  // TF:llvm-project
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/OpImplementation.h"  // TF:llvm-project
+#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Support/Functional.h"  // TF:llvm-project
+#include "mlir/Support/LLVM.h"  // TF:llvm-project
+#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@@ -83,7 +83,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
 template <typename BatchMatMulOpType>
 std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
     Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
-  RankedTensorType tensorType = value->getType().cast<RankedTensorType>();
+  RankedTensorType tensorType = value.getType().cast<RankedTensorType>();
   Type element_type = tensorType.getElementType();

   int rank = tensorType.getShape().size();
@@ -127,7 +127,7 @@ std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
 template <typename BatchMatMulOpType>
 TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
     Value value, Location loc, PatternRewriter& rewriter) {
-  auto value_type = value->getType().cast<RankedTensorType>();
+  auto value_type = value.getType().cast<RankedTensorType>();
   auto shape = value_type.getShape();
   int dims = shape.size();

@@ -197,17 +197,17 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
   Value input_lhs = op.x();
   Value input_rhs = op.y();

-  if (!input_lhs->getType().isa<RankedTensorType>()) {
+  if (!input_lhs.getType().isa<RankedTensorType>()) {
     // LHS must be a ranked tensor type
     return this->matchFailure();
   }
-  if (!input_rhs->getType().isa<RankedTensorType>()) {
+  if (!input_rhs.getType().isa<RankedTensorType>()) {
     // RHS must be a ranked tensor type
     return this->matchFailure();
   }

-  auto lhs_type = input_lhs->getType().cast<RankedTensorType>();
-  auto rhs_type = input_rhs->getType().cast<RankedTensorType>();
+  auto lhs_type = input_lhs.getType().cast<RankedTensorType>();
+  auto rhs_type = input_rhs.getType().cast<RankedTensorType>();

   auto element_type = lhs_type.getElementType();

@@ -233,7 +233,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
   if (op.adj_x()) {
     input_lhs = createTransposeOp(input_lhs, loc, rewriter);

-    lhs_type = input_lhs->getType().cast<RankedTensorType>();
+    lhs_type = input_lhs.getType().cast<RankedTensorType>();
     lhs_shape = lhs_type.getShape();
   }

@@ -241,7 +241,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
   if (op.adj_y()) {
     input_rhs = createTransposeOp(input_rhs, loc, rewriter);

-    rhs_type = input_rhs->getType().cast<RankedTensorType>();
+    rhs_type = input_rhs.getType().cast<RankedTensorType>();
     rhs_shape = rhs_type.getShape();
   }
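For context on the adj_x/adj_y hunks above: BatchMatMul treats an adjointed operand as transposed in its trailing two dimensions, and createTransposeOp materializes that transpose before the per-batch multiplies, after which the cached type and shape must be refreshed (hence the re-casts). The permutation involved is the identity except for the last two axes; a minimal standalone sketch (helper name assumed):

#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Illustrative: permutation that swaps only the trailing matrix dimensions,
// leaving all batch dimensions in place.
static std::vector<int64_t> AdjointPerm(int dims) {
  std::vector<int64_t> perm(dims);
  std::iota(perm.begin(), perm.end(), 0);  // 0, 1, ..., dims-1
  if (dims >= 2) std::swap(perm[dims - 1], perm[dims - 2]);
  return perm;
}

For a rank-3 operand this yields {0, 2, 1}: the batch axis stays put and the matrix is transposed.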
@@ -17,9 +17,9 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_

 #include "llvm/ADT/ArrayRef.h"
-#include "mlir/IR/Location.h"  // TF:local_config_mlir
-#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
-#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
+#include "mlir/IR/Location.h"  // TF:llvm-project
+#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
+#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/core/util/matmul_bcast.h"
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project

 namespace mlir {
 namespace TFL {
@@ -19,7 +19,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_

-#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project

 namespace mlir {
 namespace TFL {
@@ -15,9 +15,9 @@ limitations under the License.

 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"

-#include "mlir/IR/Builders.h"  // TF:local_config_mlir
-#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
-#include "mlir/IR/Types.h"  // TF:local_config_mlir
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/Types.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/framework/types.h"
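This last file maps TensorFlow dtypes onto MLIR types via the Builder header retagged above. A hedged sketch of the shape of that mapping, covering two cases only rather than the file's full converter (the dtype enumerators come from tensorflow/core/framework/types.h, also in the include list):

#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Types.h"     // TF:llvm-project
#include "tensorflow/core/framework/types.h"

// Illustrative two-case mapping from a TF dtype to an MLIR builtin type.
static mlir::Type ConvertElementType(tensorflow::DataType dtype,
                                     mlir::Builder& builder) {
  switch (dtype) {
    case tensorflow::DT_FLOAT:
      return builder.getF32Type();
    case tensorflow::DT_INT32:
      return builder.getIntegerType(32);
    default:
      return {};  // null Type; the real converter handles many more cases
  }
}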